Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 91 additions & 64 deletions homeassistant/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ def name(self):

async def async_credentials(self):
"""Return all credentials of this provider."""
return await self.store.credentials_for_provider(self.type, self.id)
users = await self.store.async_get_users()
return [
credentials
for user in users
for credentials in user.credentials
if (credentials.auth_provider_type == self.type and
credentials.auth_provider_id == self.id)
]

@callback
def async_create_credentials(self, data):
Expand Down Expand Up @@ -118,10 +125,11 @@ async def async_user_meta_for_credentials(self, credentials):
class User:
"""A user."""

name = attr.ib(type=str)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
is_owner = attr.ib(type=bool, default=False)
is_active = attr.ib(type=bool, default=False)
name = attr.ib(type=str, default=None)
system_generated = attr.ib(type=bool, default=False)

# List of credentials of a user.
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
Expand Down Expand Up @@ -300,10 +308,45 @@ async def async_get_user(self, user_id):
"""Retrieve a user."""
return await self._store.async_get_user(user_id)

async def async_create_system_user(self, name):
"""Create a system user."""
return await self._store.async_create_user(
name=name,
system_generated=True,
is_active=True,
)

async def async_get_or_create_user(self, credentials):
"""Get or create a user."""
return await self._store.async_get_or_create_user(
credentials, self._async_get_auth_provider(credentials))
if not credentials.is_new:
for user in await self._store.async_get_users():
for creds in user.credentials:
if (creds.auth_provider_type ==
credentials.auth_provider_type
and creds.auth_provider_id ==
credentials.auth_provider_id):
return user

raise ValueError('Unable to find the user.')

auth_provider = self._async_get_auth_provider(credentials)
info = await auth_provider.async_user_meta_for_credentials(
credentials)

kwargs = {
'credentials': credentials,
'name': info.get('name')
}

# Make owner and activate user if it's the first user.
if await self._store.async_get_users():
kwargs['is_owner'] = False
kwargs['is_active'] = False
else:
kwargs['is_owner'] = True
kwargs['is_active'] = True

return await self._store.async_create_user(**kwargs)

async def async_link_user(self, user, credentials):
"""Link credentials to an existing user."""
Expand All @@ -313,9 +356,20 @@ async def async_remove_user(self, user):
"""Remove a user."""
await self._store.async_remove_user(user)

async def async_create_refresh_token(self, user, client_id):
async def async_create_refresh_token(self, user, client=None):
"""Create a new refresh token for a user."""
return await self._store.async_create_refresh_token(user, client_id)
if not user.is_active:
raise ValueError('User is not active')

if user.system_generated and client is not None:
raise ValueError(
'System generated users cannot have refresh tokens connected '
'to a client.')

if not user.system_generated and client is None:
raise ValueError('Client is required to generate a refresh token.')

return await self._store.async_create_refresh_token(user, client)

async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
Expand All @@ -324,7 +378,7 @@ async def async_get_refresh_token(self, token):
@callback
def async_create_access_token(self, refresh_token):
"""Create a new access token."""
access_token = AccessToken(refresh_token)
access_token = AccessToken(refresh_token=refresh_token)
self._access_tokens[access_token.token] = access_token
return access_token

Expand Down Expand Up @@ -405,19 +459,6 @@ def __init__(self, hass):
self._clients = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)

async def credentials_for_provider(self, provider_type, provider_id):
"""Return credentials for specific auth provider type and id."""
if self._users is None:
await self.async_load()

return [
credentials
for user in self._users.values()
for credentials in user.credentials
if (credentials.auth_provider_type == provider_type and
credentials.auth_provider_id == provider_id)
]

async def async_get_users(self):
"""Retrieve all users."""
if self._users is None:
Expand All @@ -426,50 +467,42 @@ async def async_get_users(self):
return list(self._users.values())

async def async_get_user(self, user_id):
"""Retrieve a user."""
"""Retrieve a user by id."""
if self._users is None:
await self.async_load()

return self._users.get(user_id)

async def async_get_or_create_user(self, credentials, auth_provider):
"""Get or create a new user for given credentials.

If link_user is passed in, the credentials will be linked to the passed
in user if the credentials are new.
"""
async def async_create_user(self, name, is_owner=None, is_active=None,
system_generated=None, credentials=None):
"""Create a new user."""
if self._users is None:
await self.async_load()

# New credentials, store in user
if credentials.is_new:
info = await auth_provider.async_user_meta_for_credentials(
credentials)
# Make owner and activate user if it's the first user.
if self._users:
is_owner = False
is_active = False
else:
is_owner = True
is_active = True

new_user = User(
is_owner=is_owner,
is_active=is_active,
name=info.get('name'),
)
self._users[new_user.id] = new_user
await self.async_link_user(new_user, credentials)
return new_user
kwargs = {
'name': name
}

for user in self._users.values():
for creds in user.credentials:
if (creds.auth_provider_type == credentials.auth_provider_type
and creds.auth_provider_id ==
credentials.auth_provider_id):
return user
if is_owner is not None:
kwargs['is_owner'] = is_owner

if is_active is not None:
kwargs['is_active'] = is_active

if system_generated is not None:
kwargs['system_generated'] = system_generated

new_user = User(**kwargs)

raise ValueError('We got credentials with ID but found no user')
self._users[new_user.id] = new_user

if credentials is None:
await self.async_save()
return new_user

# Saving is done inside the link.
await self.async_link_user(new_user, credentials)
return new_user

async def async_link_user(self, user, credentials):
"""Add credentials to an existing user."""
Expand All @@ -482,17 +515,10 @@ async def async_remove_user(self, user):
self._users.pop(user.id)
await self.async_save()

async def async_create_refresh_token(self, user, client_id):
async def async_create_refresh_token(self, user, client=None):
"""Create a new token for a user."""
local_user = await self.async_get_user(user.id)
if local_user is None:
raise ValueError('Invalid user')

local_client = await self.async_get_client(client_id)
if local_client is None:
raise ValueError('Invalid client_id')

refresh_token = RefreshToken(user, client_id)
client_id = client.id if client is not None else None
refresh_token = RefreshToken(user=user, client_id=client_id)
user.refresh_tokens[refresh_token.token] = refresh_token
await self.async_save()
return refresh_token
Expand Down Expand Up @@ -607,6 +633,7 @@ async def async_save(self):
'is_owner': user.is_owner,
'is_active': user.is_active,
'name': user.name,
'system_generated': user.system_generated,
}
for user in self._users.values()
]
Expand Down
16 changes: 7 additions & 9 deletions homeassistant/components/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,16 @@ async def post(self, request, client):
grant_type = data.get('grant_type')

if grant_type == 'authorization_code':
return await self._async_handle_auth_code(
hass, client.id, data)
return await self._async_handle_auth_code(hass, client, data)

elif grant_type == 'refresh_token':
return await self._async_handle_refresh_token(
hass, client.id, data)
return await self._async_handle_refresh_token(hass, client, data)

return self.json({
'error': 'unsupported_grant_type',
}, status_code=400)

async def _async_handle_auth_code(self, hass, client_id, data):
async def _async_handle_auth_code(self, hass, client, data):
"""Handle authorization code request."""
code = data.get('code')

Expand All @@ -256,7 +254,7 @@ async def _async_handle_auth_code(self, hass, client_id, data):
'error': 'invalid_request',
}, status_code=400)

credentials = self._retrieve_credentials(client_id, code)
credentials = self._retrieve_credentials(client.id, code)

if credentials is None:
return self.json({
Expand All @@ -265,7 +263,7 @@ async def _async_handle_auth_code(self, hass, client_id, data):

user = await hass.auth.async_get_or_create_user(credentials)
refresh_token = await hass.auth.async_create_refresh_token(user,
client_id)
client)
access_token = hass.auth.async_create_access_token(refresh_token)

return self.json({
Expand All @@ -276,7 +274,7 @@ async def _async_handle_auth_code(self, hass, client_id, data):
int(refresh_token.access_token_expiration.total_seconds()),
})

async def _async_handle_refresh_token(self, hass, client_id, data):
async def _async_handle_refresh_token(self, hass, client, data):
"""Handle authorization code request."""
token = data.get('refresh_token')

Expand All @@ -287,7 +285,7 @@ async def _async_handle_refresh_token(self, hass, client_id, data):

refresh_token = await hass.auth.async_get_refresh_token(token)

if refresh_token is None or refresh_token.client_id != client_id:
if refresh_token is None or refresh_token.client_id != client.id:
return self.json({
'error': 'invalid_grant',
}, status_code=400)
Expand Down
2 changes: 1 addition & 1 deletion tests/auth_providers/test_insecure_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def test_match_existing_credentials(store, provider):
},
is_new=False,
)
store.credentials_for_provider = Mock(return_value=mock_coro([existing]))
provider.async_credentials = Mock(return_value=mock_coro([existing]))
credentials = await provider.async_get_or_create_credentials({
'username': 'user-test',
'password': 'password-test',
Expand Down
16 changes: 12 additions & 4 deletions tests/auth_providers/test_legacy_api_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,28 @@ def provider(hass, store):
})


@pytest.fixture
def manager(hass, store, provider):
"""Mock manager."""
return auth.AuthManager(hass, store, {
(provider.type, provider.id): provider
})


async def test_create_new_credential(provider):
"""Test that we create a new credential."""
credentials = await provider.async_get_or_create_credentials({})
assert credentials.data["username"] is legacy_api_password.LEGACY_USER
assert credentials.is_new is True


async def test_only_one_credentials(store, provider):
async def test_only_one_credentials(manager, provider):
"""Call create twice will return same credential."""
credentials = await provider.async_get_or_create_credentials({})
await store.async_get_or_create_user(credentials, provider)
await manager.async_get_or_create_user(credentials)
credentials2 = await provider.async_get_or_create_credentials({})
assert credentials2.data["username"] is legacy_api_password.LEGACY_USER
assert credentials2.id is credentials.id
assert credentials2.data["username"] == legacy_api_password.LEGACY_USER
assert credentials2.id == credentials.id
assert credentials2.is_new is False


Expand Down
3 changes: 2 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ class MockUser(auth.User):
def __init__(self, id='mock-id', is_owner=True, is_active=True,
name='Mock User'):
"""Initialize mock user."""
super().__init__(id, is_owner, is_active, name)
super().__init__(
id=id, is_owner=is_owner, is_active=is_active, name=name)

def add_to_hass(self, hass):
"""Test helper to add entry to hass."""
Expand Down
2 changes: 1 addition & 1 deletion tests/components/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ def hass_access_token(hass):
no_secret=True,
))
refresh_token = hass.loop.run_until_complete(
hass.auth.async_create_refresh_token(user, client.id))
hass.auth.async_create_refresh_token(user, client))
yield hass.auth.async_create_access_token(refresh_token)
Loading