Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion homeassistant/components/cloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ async def async_setup(hass, config):
)

cloud = hass.data[DOMAIN] = Cloud(hass, **kwargs)
await auth_api.async_setup(hass, cloud)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, cloud.async_start)
await http_api.async_setup(hass)
return True
Expand Down Expand Up @@ -263,7 +264,7 @@ def load_config():
self.access_token = info['access_token']
self.refresh_token = info['refresh_token']

self.hass.add_job(self.iot.connect())
self.hass.async_create_task(self.iot.connect())

def _decode_claims(self, token): # pylint: disable=no-self-use
"""Decode the claims in a token."""
Expand Down
73 changes: 66 additions & 7 deletions homeassistant/components/cloud/auth_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
"""Package to communicate with the authentication API."""
import asyncio
import logging
import random


_LOGGER = logging.getLogger(__name__)


class CloudError(Exception):
Expand Down Expand Up @@ -39,6 +45,40 @@ class UnknownError(CloudError):
}


async def async_setup(hass, cloud):
"""Configure the auth api."""
refresh_task = None

async def handle_token_refresh():
"""Handle Cloud access token refresh."""
sleep_time = 5
sleep_time = random.randint(2400, 3600)
while True:
try:
await asyncio.sleep(sleep_time)
await hass.async_add_executor_job(renew_access_token, cloud)
except CloudError as err:
_LOGGER.error("Can't refresh cloud token: %s", err)
except asyncio.CancelledError:
# Task is canceled, stop it.
break

sleep_time = random.randint(3100, 3600)

async def on_connect():
"""When the instance is connected."""
nonlocal refresh_task
refresh_task = hass.async_create_task(handle_token_refresh())

async def on_disconnect():
"""When the instance is disconnected."""
nonlocal refresh_task
refresh_task.cancel()

cloud.iot.register_on_connect(on_connect)
cloud.iot.register_on_disconnect(on_disconnect)


def _map_aws_exception(err):
"""Map AWS exception to our exceptions."""
ex = AWS_EXCEPTIONS.get(err.response['Error']['Code'], UnknownError)
Expand All @@ -47,21 +87,24 @@ def _map_aws_exception(err):

def register(cloud, email, password):
"""Register a new account."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError

cognito = _cognito(cloud)
# Workaround for bug in Warrant. PR with fix:
# https://github.com/capless/warrant/pull/82
cognito.add_base_attributes()
try:
cognito.register(email, password)

except ClientError as err:
raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()


def resend_email_confirm(cloud, email):
"""Resend email confirmation."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError

cognito = _cognito(cloud, username=email)

Expand All @@ -72,18 +115,23 @@ def resend_email_confirm(cloud, email):
)
except ClientError as err:
raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()


def forgot_password(cloud, email):
"""Initialize forgotten password flow."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError

cognito = _cognito(cloud, username=email)

try:
cognito.initiate_forgot_password()

except ClientError as err:
raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()


def login(cloud, email, password):
Expand All @@ -97,7 +145,7 @@ def login(cloud, email, password):

def check_token(cloud):
"""Check that the token is valid and verify if needed."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError

cognito = _cognito(
cloud,
Expand All @@ -109,13 +157,17 @@ def check_token(cloud):
cloud.id_token = cognito.id_token
cloud.access_token = cognito.access_token
cloud.write_user_info()

except ClientError as err:
raise _map_aws_exception(err)

except EndpointConnectionError:
raise UnknownError()


def renew_access_token(cloud):
"""Renew access token."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError

cognito = _cognito(
cloud,
Expand All @@ -127,13 +179,17 @@ def renew_access_token(cloud):
cloud.id_token = cognito.id_token
cloud.access_token = cognito.access_token
cloud.write_user_info()

except ClientError as err:
raise _map_aws_exception(err)

except EndpointConnectionError:
raise UnknownError()


def _authenticate(cloud, email, password):
"""Log in and return an authenticated Cognito instance."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError
from warrant.exceptions import ForceChangePasswordException

assert not cloud.is_logged_in, 'Cannot login if already logged in.'
Expand All @@ -145,11 +201,14 @@ def _authenticate(cloud, email, password):
return cognito

except ForceChangePasswordException:
raise PasswordChangeRequired
raise PasswordChangeRequired()

except ClientError as err:
raise _map_aws_exception(err)

except EndpointConnectionError:
raise UnknownError()


def _cognito(cloud, **kwargs):
"""Get the client credentials."""
Expand Down
30 changes: 24 additions & 6 deletions homeassistant/components/cloud/iot.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,18 @@ def __init__(self, cloud):
# Local code waiting for a response
self._response_handler = {}
self._on_connect = []
self._on_disconnect = []

@callback
def register_on_connect(self, on_connect_cb):
"""Register an async on_connect callback."""
self._on_connect.append(on_connect_cb)

@callback
def register_on_disconnect(self, on_disconnect_cb):
"""Register an async on_disconnect callback."""
self._on_disconnect.append(on_disconnect_cb)

@property
def connected(self):
"""Return if we're currently connected."""
Expand Down Expand Up @@ -102,6 +108,17 @@ def _handle_hass_stop(event):
# Still adding it here to make sure we can always reconnect
_LOGGER.exception("Unexpected error")

if self.state == STATE_CONNECTED and self._on_disconnect:
try:
yield from asyncio.wait([
cb() for cb in self._on_disconnect
])
except Exception: # pylint: disable=broad-except
# Safety net. This should never hit.
# Still adding it here to make sure we don't break the flow
_LOGGER.exception(
"Unexpected error in on_disconnect callbacks")

if self.close_requested:
break

Expand Down Expand Up @@ -192,7 +209,13 @@ def _handle_connection(self):
self.state = STATE_CONNECTED

if self._on_connect:
yield from asyncio.wait([cb() for cb in self._on_connect])
try:
yield from asyncio.wait([cb() for cb in self._on_connect])
except Exception: # pylint: disable=broad-except
# Safety net. This should never hit.
# Still adding it here to make sure we don't break the flow
_LOGGER.exception(
"Unexpected error in on_connect callbacks")

while not client.closed:
msg = yield from client.receive()
Expand Down Expand Up @@ -326,11 +349,6 @@ async def async_handle_cloud(hass, cloud, payload):
await cloud.logout()
_LOGGER.error("You have been logged out from Home Assistant cloud: %s",
payload['reason'])
elif action == 'refresh_auth':
# Refresh the auth token between now and payload['seconds']
hass.helpers.event.async_call_later(
random.randint(0, payload['seconds']),
lambda now: auth_api.check_token(cloud))
else:
_LOGGER.warning("Received unknown cloud action: %s", action)

Expand Down
29 changes: 29 additions & 0 deletions tests/components/cloud/test_auth_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for the tools to communicate with the cloud."""
import asyncio
from unittest.mock import MagicMock, patch

from botocore.exceptions import ClientError
Expand Down Expand Up @@ -165,3 +166,31 @@ def test_check_token_raises(mock_cognito):
assert cloud.id_token != mock_cognito.id_token
assert cloud.access_token != mock_cognito.access_token
assert len(cloud.write_user_info.mock_calls) == 0


async def test_async_setup(hass):
"""Test async setup."""
cloud = MagicMock()
await auth_api.async_setup(hass, cloud)
assert len(cloud.iot.mock_calls) == 2
on_connect = cloud.iot.mock_calls[0][1][0]
on_disconnect = cloud.iot.mock_calls[1][1][0]

with patch('random.randint', return_value=0), patch(
'homeassistant.components.cloud.auth_api.renew_access_token'
) as mock_renew:
await on_connect()
# Let handle token sleep once
await asyncio.sleep(0)
# Let handle token refresh token
await asyncio.sleep(0)

assert len(mock_renew.mock_calls) == 1
assert mock_renew.mock_calls[0][1][0] is cloud

await on_disconnect()

# Make sure task is no longer being called
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(mock_renew.mock_calls) == 1
23 changes: 1 addition & 22 deletions tests/components/cloud/test_iot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
Cloud, iot, auth_api, MODE_DEV)
from homeassistant.components.cloud.const import (
PREF_ENABLE_ALEXA, PREF_ENABLE_GOOGLE)
from homeassistant.util import dt as dt_util
from tests.components.alexa import test_smart_home as test_alexa
from tests.common import mock_coro, async_fire_time_changed
from tests.common import mock_coro

from . import mock_cloud_prefs

Expand Down Expand Up @@ -158,26 +157,6 @@ async def test_handling_core_messages_logout(hass, mock_cloud):
assert len(mock_cloud.logout.mock_calls) == 1


async def test_handling_core_messages_refresh_auth(hass, mock_cloud):
"""Test handling core messages."""
mock_cloud.hass = hass
with patch('random.randint', return_value=0) as mock_rand, patch(
'homeassistant.components.cloud.auth_api.check_token'
) as mock_check:
await iot.async_handle_cloud(hass, mock_cloud, {
'action': 'refresh_auth',
'seconds': 230,
})
async_fire_time_changed(hass, dt_util.utcnow())
await hass.async_block_till_done()

assert len(mock_rand.mock_calls) == 1
assert mock_rand.mock_calls[0][1] == (0, 230)

assert len(mock_check.mock_calls) == 1
assert mock_check.mock_calls[0][1][0] is mock_cloud


@asyncio.coroutine
def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud):
"""Test server disconnecting instance."""
Expand Down