diff --git a/custom_components/alexa_media/config_flow.py b/custom_components/alexa_media/config_flow.py index d95aa265..9d13275a 100644 --- a/custom_components/alexa_media/config_flow.py +++ b/custom_components/alexa_media/config_flow.py @@ -14,7 +14,8 @@ from functools import reduce import logging import re -from typing import Any, Optional, Text +from typing import Any, Dict, List, Optional, Text +from yarl import URL from aiohttp import ClientConnectionError, ClientSession, web, web_response from aiohttp.web_exceptions import HTTPBadRequest @@ -85,6 +86,8 @@ class AlexaMediaFlowHandler(config_entries.ConfigFlow): VERSION = 1 CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_POLL + proxy: AlexaProxy = None + proxy_view: "AlexaMediaAuthorizationProxyView" = None def _update_ord_dict(self, old_dict: OrderedDict, new_dict: dict) -> OrderedDict: result: OrderedDict = OrderedDict() @@ -176,7 +179,6 @@ def __init__(self): self.totp_register = OrderedDict( [(vol.Optional(CONF_TOTP_REGISTER, default=False), bool)] ) - self.proxy = None async def async_step_import(self, import_config): """Import a config entry from configuration.yaml.""" @@ -323,7 +325,12 @@ async def async_step_user(self, user_input=None): errors={"base": "hass_url_invalid"}, description_placeholders={"message": ""}, ) - self.proxy = AlexaProxy(self.login, f"{hass_url}{AUTH_PROXY_PATH}") + if not self.proxy: + self.proxy = AlexaProxy( + self.login, str(URL(hass_url).with_path(AUTH_PROXY_PATH)) + ) + # Swap the login object + self.proxy.change_login(self.login) if ( user_input and user_input.get(CONF_OTPSECRET) @@ -349,18 +356,35 @@ async def async_step_start_proxy(self, user_input=None): _LOGGER.debug( "Starting proxy for %s - %s", hide_email(self.login.email), self.login.url, ) + if not self.proxy_view: + self.proxy_view = AlexaMediaAuthorizationProxyView(self.proxy.all_handler) + else: + _LOGGER.debug("Found existing proxy_view") + self.proxy_view.handler = self.proxy.all_handler self.hass.http.register_view(AlexaMediaAuthorizationCallbackView()) - self.hass.http.register_view( - AlexaMediaAuthorizationProxyView(self.proxy.all_handler) - ) + self.hass.http.register_view(self.proxy_view) callback_url = ( - f"{self.config['hass_url']}{AUTH_CALLBACK_PATH}?flow_id={self.flow_id}" + URL(self.config["hass_url"]) + .with_path(AUTH_CALLBACK_PATH) + .with_query({"flow_id": self.flow_id}) + ) + + proxy_url = self.proxy.access_url().with_query( + {"config_flow_id": self.flow_id, "callback_url": str(callback_url)} ) - proxy_url = f"{self.proxy.access_url()}?config_flow_id={self.flow_id}&callback_url={callback_url}" if self.login.lastreq: self.proxy.last_resp = self.login.lastreq - proxy_url = f"{self.proxy.access_url()}/resume?config_flow_id={self.flow_id}&callback_url={callback_url}" - return self.async_external_step(step_id="check_proxy", url=proxy_url) + self.proxy.session.cookie_jar.update_cookies( + self.login._session.cookie_jar.filter_cookies( + self.proxy._host_url.with_path("/") + ) + ) + proxy_url = ( + self.proxy.access_url().with_path(AUTH_PROXY_PATH) / "resume" + ).with_query( + {"config_flow_id": self.flow_id, "callback_url": str(callback_url)} + ) + return self.async_external_step(step_id="check_proxy", url=str(proxy_url)) async def async_step_check_proxy(self, user_input=None): """Check status of proxy for login.""" @@ -369,6 +393,7 @@ async def async_step_check_proxy(self, user_input=None): hide_email(self.login.email), self.login.url, ) + self.proxy_view.reset() return self.async_external_step_done(next_step_id="finish_proxy") async def async_step_finish_proxy(self, user_input=None): @@ -1021,10 +1046,13 @@ async def get(self, request: web.Request): class AlexaMediaAuthorizationProxyView(HomeAssistantView): """Handle proxy connections.""" - url = AUTH_PROXY_PATH - extra_urls = [f"{AUTH_PROXY_PATH}/{{tail:.*}}"] - name = AUTH_PROXY_NAME - requires_auth = False + url: Text = AUTH_PROXY_PATH + extra_urls: List[Text] = [f"{AUTH_PROXY_PATH}/{{tail:.*}}"] + name: Text = AUTH_PROXY_NAME + requires_auth: bool = False + handler: web.RequestHandler = None + known_ips: Dict[Text, datetime.datetime] = {} + auth_seconds: int = 300 def __init__(self, handler: web.RequestHandler): """Initialize routes for view. @@ -1033,25 +1061,22 @@ def __init__(self, handler: web.RequestHandler): handler (web.RequestHandler): Handler to apply to all method types """ - self.handler = handler + AlexaMediaAuthorizationProxyView.handler = handler for method in ("get", "post", "delete", "put", "patch", "head", "options"): - setattr(self, method, self.handler_wrapper(handler)) - self.known_ips = {} + setattr(self, method, self.check_auth()) - def handler_wrapper(self, handler): + @classmethod + def check_auth(cls): """Wrap authentication into the handler.""" async def wrapped(request, **kwargs): """Notify that the API is running.""" hass = request.app["hass"] - _LOGGER.debug("request %s", request.url.query) success = False if ( - request.remote not in self.known_ips - or ( - datetime.datetime.now() - self.known_ips.get(request.remote) - ).seconds - > 300 + request.remote not in cls.known_ips + or (datetime.datetime.now() - cls.known_ips.get(request.remote)).seconds + > cls.auth_seconds ): try: flow_id = request.url.query["config_flow_id"] @@ -1059,11 +1084,20 @@ async def wrapped(request, **kwargs): raise Unauthorized() from ex for flow in hass.config_entries.flow.async_progress(): if flow["flow_id"] == flow_id: - _LOGGER.debug("Found flow_id") + _LOGGER.debug( + "Found flow_id; adding %s to known_ips for %s seconds", + request.remote, + cls.auth_seconds, + ) success = True if not success: raise Unauthorized() - self.known_ips[request.remote] = datetime.datetime.now() - return await handler(request, **kwargs) + cls.known_ips[request.remote] = datetime.datetime.now() + return await cls.handler(request, **kwargs) return wrapped + + @classmethod + def reset(cls) -> None: + """Reset the view.""" + cls.known_ips = {} diff --git a/custom_components/alexa_media/manifest.json b/custom_components/alexa_media/manifest.json index b2c3c4e6..8e6c689f 100644 --- a/custom_components/alexa_media/manifest.json +++ b/custom_components/alexa_media/manifest.json @@ -7,5 +7,5 @@ "issue_tracker": "https://github.com/custom-components/alexa_media_player/issues", "dependencies": ["persistent_notification", "http"], "codeowners": ["@keatontaylor", "@alandtse"], - "requirements": ["alexapy==1.23.1", "packaging~=20.3", "wrapt~=1.12.1"] + "requirements": ["alexapy==1.24.0", "packaging~=20.3", "wrapt~=1.12.1"] }