Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 43 additions & 21 deletions homeassistant/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from collections import OrderedDict
from typing import List, Awaitable

import jwt

from homeassistant import data_entry_flow
from homeassistant.core import callback, HomeAssistant
from homeassistant.util import dt as dt_util

This comment was marked as resolved.


from . import models
from . import auth_store
from .providers import auth_provider_from_config

Expand Down Expand Up @@ -54,7 +56,6 @@ def __init__(self, hass, store, providers):
self.login_flow = data_entry_flow.FlowManager(
hass, self._async_create_login_flow,
self._async_finish_login_flow)
self._access_tokens = OrderedDict()

@property
def active(self):
Expand Down Expand Up @@ -181,35 +182,56 @@ async def async_create_refresh_token(self, user, client_id=None):

return await self._store.async_create_refresh_token(user, client_id)

async def async_get_refresh_token(self, token):
async def async_get_refresh_token(self, token_id):
"""Get refresh token by id."""
return await self._store.async_get_refresh_token(token_id)

async def async_get_refresh_token_by_token(self, token):
"""Get refresh token by token."""
return await self._store.async_get_refresh_token(token)
return await self._store.async_get_refresh_token_by_token(token)

@callback
def async_create_access_token(self, refresh_token):
"""Create a new access token."""
access_token = models.AccessToken(refresh_token=refresh_token)
self._access_tokens[access_token.token] = access_token
return access_token

@callback
def async_get_access_token(self, token):
"""Get an access token."""
tkn = self._access_tokens.get(token)
# pylint: disable=no-self-use
return jwt.encode({
'iss': refresh_token.id,
'iat': dt_util.utcnow(),
'exp': dt_util.utcnow() + refresh_token.access_token_expiration,
}, refresh_token.jwt_key, algorithm='HS256').decode()

async def async_validate_access_token(self, token):
"""Return if an access token is valid."""
try:
unverif_claims = jwt.decode(token, verify=False)
except jwt.InvalidTokenError:
return None

if tkn is None:
_LOGGER.debug('Attempt to get non-existing access token')
refresh_token = await self.async_get_refresh_token(
unverif_claims.get('iss'))

if refresh_token is None:
jwt_key = ''
issuer = ''
else:
jwt_key = refresh_token.jwt_key
issuer = refresh_token.id

try:
jwt.decode(
token,
jwt_key,
leeway=10,
issuer=issuer,
algorithms=['HS256']
)
except jwt.InvalidTokenError:
return None

if tkn.expired or not tkn.refresh_token.user.is_active:
if tkn.expired:
_LOGGER.debug('Attempt to get expired access token')
else:
_LOGGER.debug('Attempt to get access token for inactive user')
self._access_tokens.pop(token)
if not refresh_token.user.is_active:
return None

return tkn
return refresh_token

async def _async_create_login_flow(self, handler, *, context, data):
"""Create a login flow."""
Expand Down
56 changes: 26 additions & 30 deletions homeassistant/auth/auth_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Storage for auth models."""
from collections import OrderedDict
from datetime import timedelta
import hmac

from homeassistant.util import dt as dt_util

Expand Down Expand Up @@ -110,22 +111,36 @@ async def async_remove_credentials(self, credentials):
async def async_create_refresh_token(self, user, client_id=None):
"""Create a new token for a user."""
refresh_token = models.RefreshToken(user=user, client_id=client_id)
user.refresh_tokens[refresh_token.token] = refresh_token
user.refresh_tokens[refresh_token.id] = refresh_token
await self.async_save()
return refresh_token

async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
async def async_get_refresh_token(self, token_id):
"""Get refresh token by id."""
if self._users is None:
await self.async_load()

for user in self._users.values():
refresh_token = user.refresh_tokens.get(token)
refresh_token = user.refresh_tokens.get(token_id)
if refresh_token is not None:
return refresh_token

return None

async def async_get_refresh_token_by_token(self, token):
"""Get refresh token by token."""
if self._users is None:
await self.async_load()

found = None

for user in self._users.values():
for refresh_token in user.refresh_tokens.values():
if hmac.compare_digest(refresh_token.token, token):
found = refresh_token

return found

async def async_load(self):
"""Load the users."""
data = await self._store.async_load()
Expand Down Expand Up @@ -153,9 +168,11 @@ async def async_load(self):
data=cred_dict['data'],
))

refresh_tokens = OrderedDict()

for rt_dict in data['refresh_tokens']:
# Filter out the old keys that don't have jwt_key (pre-0.76)
if 'jwt_key' not in rt_dict:
continue

token = models.RefreshToken(
id=rt_dict['id'],
user=users[rt_dict['user_id']],
Expand All @@ -164,18 +181,9 @@ async def async_load(self):
access_token_expiration=timedelta(
seconds=rt_dict['access_token_expiration']),
token=rt_dict['token'],
jwt_key=rt_dict['jwt_key']
)
refresh_tokens[token.id] = token
users[rt_dict['user_id']].refresh_tokens[token.token] = token

for ac_dict in data['access_tokens']:
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
token = models.AccessToken(
refresh_token=refresh_token,
created_at=dt_util.parse_datetime(ac_dict['created_at']),
token=ac_dict['token'],
)
refresh_token.access_tokens.append(token)
users[rt_dict['user_id']].refresh_tokens[token.id] = token

self._users = users

Expand Down Expand Up @@ -213,27 +221,15 @@ async def async_save(self):
'access_token_expiration':
refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token,
'jwt_key': refresh_token.jwt_key,
}
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
]

access_tokens = [
{
'id': user.id,
'refresh_token_id': refresh_token.id,
'created_at': access_token.created_at.isoformat(),
'token': access_token.token,
}
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
for access_token in refresh_token.access_tokens
]

data = {
'users': users,
'credentials': credentials,
'access_tokens': access_tokens,
'refresh_tokens': refresh_tokens,
}

Expand Down
22 changes: 2 additions & 20 deletions homeassistant/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,26 +39,8 @@ class RefreshToken:
default=ACCESS_TOKEN_EXPIRATION)
token = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64)))
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)


@attr.s(slots=True)
class AccessToken:
"""Access token to access the API.

These will only ever be stored in memory and not be persisted.
"""

refresh_token = attr.ib(type=RefreshToken)
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
token = attr.ib(type=str,
default=attr.Factory(generate_secret))

@property
def expired(self):
"""Return if this token has expired."""
expires = self.created_at + self.refresh_token.access_token_expiration
return dt_util.utcnow() > expires
jwt_key = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64)))


@attr.s(slots=True)
Expand Down
6 changes: 3 additions & 3 deletions homeassistant/components/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def _async_handle_auth_code(self, hass, data):
access_token = hass.auth.async_create_access_token(refresh_token)

return self.json({
'access_token': access_token.token,
'access_token': access_token,
'token_type': 'Bearer',
'refresh_token': refresh_token.token,
'expires_in':
Expand All @@ -178,7 +178,7 @@ async def _async_handle_refresh_token(self, hass, data):
'error': 'invalid_request',
}, status_code=400)

refresh_token = await hass.auth.async_get_refresh_token(token)
refresh_token = await hass.auth.async_get_refresh_token_by_token(token)

if refresh_token is None:
return self.json({
Expand All @@ -193,7 +193,7 @@ async def _async_handle_refresh_token(self, hass, data):
access_token = hass.auth.async_create_access_token(refresh_token)

return self.json({
'access_token': access_token.token,
'access_token': access_token,
'token_type': 'Bearer',
'expires_in':
int(refresh_token.access_token_expiration.total_seconds()),
Expand Down
6 changes: 3 additions & 3 deletions homeassistant/components/http/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ async def async_validate_auth_header(request, api_password=None):

if auth_type == 'Bearer':
hass = request.app['hass']
access_token = hass.auth.async_get_access_token(auth_val)
if access_token is None:
refresh_token = await hass.auth.async_validate_access_token(auth_val)
if refresh_token is None:
return False

request['hass_user'] = access_token.refresh_token.user
request['hass_user'] = refresh_token.user
return True

if auth_type == 'Basic' and api_password is not None:
Expand Down
9 changes: 5 additions & 4 deletions homeassistant/components/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,12 @@ def handle_hass_stop(event):

if self.hass.auth.active and 'access_token' in msg:
self.debug("Received access_token")
token = self.hass.auth.async_get_access_token(
msg['access_token'])
authenticated = token is not None
refresh_token = \
await self.hass.auth.async_validate_access_token(
msg['access_token'])
authenticated = refresh_token is not None
if authenticated:
request['hass_user'] = token.refresh_token.user
request['hass_user'] = refresh_token.user

elif ((not self.hass.auth.active or
self.hass.auth.support_legacy) and
Expand Down
1 change: 1 addition & 0 deletions homeassistant/package_constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ async_timeout==3.0.0
attrs==18.1.0
certifi>=2018.04.16
jinja2>=2.10
PyJWT==1.6.4
pip>=8.0.3
pytz>=2018.04
pyyaml>=3.13,<4
Expand Down
1 change: 1 addition & 0 deletions requirements_all.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ async_timeout==3.0.0
attrs==18.1.0
certifi>=2018.04.16
jinja2>=2.10
PyJWT==1.6.4
pip>=8.0.3
pytz>=2018.04
pyyaml>=3.13,<4
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
'attrs==18.1.0',
'certifi>=2018.04.16',
'jinja2>=2.10',
'PyJWT==1.6.4',
'pip>=8.0.3',
'pytz>=2018.04',
'pyyaml>=3.13,<4',
Expand Down
45 changes: 12 additions & 33 deletions tests/auth/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,7 @@ async def test_saving_loading(hass, hass_storage):
})
user = await manager.async_get_or_create_user(step['result'])
await manager.async_activate_user(user)
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)

manager.async_create_access_token(refresh_token)
await manager.async_create_refresh_token(user, CLIENT_ID)

await flush_store(manager._store._store)

Expand All @@ -211,30 +209,6 @@ async def test_saving_loading(hass, hass_storage):
assert users[0] == user


def test_access_token_expired():
"""Test that the expired property on access tokens work."""
refresh_token = auth_models.RefreshToken(
user=None,
client_id='bla'
)

access_token = auth_models.AccessToken(
refresh_token=refresh_token
)

assert access_token.expired is False

with patch('homeassistant.util.dt.utcnow',
return_value=dt_util.utcnow() +
auth_const.ACCESS_TOKEN_EXPIRATION):
assert access_token.expired is True

almost_exp = \
dt_util.utcnow() + auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(1)
with patch('homeassistant.util.dt.utcnow', return_value=almost_exp):
assert access_token.expired is False


async def test_cannot_retrieve_expired_access_token(hass):
"""Test that we cannot retrieve expired access tokens."""
manager = await auth.auth_manager_from_config(hass, [])
Expand All @@ -244,15 +218,20 @@ async def test_cannot_retrieve_expired_access_token(hass):
assert refresh_token.client_id == CLIENT_ID

access_token = manager.async_create_access_token(refresh_token)
assert manager.async_get_access_token(access_token.token) is access_token
assert (
await manager.async_validate_access_token(access_token)
is refresh_token
)

with patch('homeassistant.util.dt.utcnow',
return_value=dt_util.utcnow() +
auth_const.ACCESS_TOKEN_EXPIRATION):
assert manager.async_get_access_token(access_token.token) is None
return_value=dt_util.utcnow() -
auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(seconds=11)):
access_token = manager.async_create_access_token(refresh_token)

# Even with unpatched time, it should have been removed from manager
assert manager.async_get_access_token(access_token.token) is None
assert (
await manager.async_validate_access_token(access_token)
is None
)


async def test_generating_system_user(hass):
Expand Down
Loading