diff --git a/homeassistant/components/generic/config_flow.py b/homeassistant/components/generic/config_flow.py index f9ef75a3ea3e9b..436b4e2b95e3d4 100644 --- a/homeassistant/components/generic/config_flow.py +++ b/homeassistant/components/generic/config_flow.py @@ -5,7 +5,6 @@ import asyncio from collections.abc import Mapping import contextlib -from copy import deepcopy from datetime import datetime, timedelta from errno import EHOSTUNREACH, EIO import io @@ -21,11 +20,9 @@ from homeassistant.components import websocket_api from homeassistant.components.camera import ( CAMERA_IMAGE_TIMEOUT, - DATA_CAMERA_PREFS, DOMAIN as CAMERA_DOMAIN, DynamicStreamSettings, _async_get_image, - init_camera_prefs, ) from homeassistant.components.http.view import HomeAssistantView from homeassistant.components.stream import ( @@ -34,6 +31,7 @@ HLS_PROVIDER, RTSP_TRANSPORTS, SOURCE_TIMEOUT, + Stream, create_stream, ) from homeassistant.config_entries import ( @@ -239,10 +237,14 @@ def slug( return None -async def async_test_stream( +async def async_test_and_preview_stream( hass: HomeAssistant, info: Mapping[str, Any] -) -> dict[str, str]: - """Verify that the stream is valid before we create an entity.""" +) -> dict[str, str] | PreviewStream: + """Verify that the stream is valid before we create an entity. + + Returns a dict with errors if any, or the stream object if valid. + The stream object is used to preview the video in the UI. + """ if not (stream_source := info.get(CONF_STREAM_SOURCE)): return {} # Import from stream.worker as stream cannot reexport from worker @@ -276,19 +278,20 @@ async def async_test_stream( url = url.with_user(username).with_password(password) stream_source = str(url) try: - stream = create_stream( - hass, - stream_source, - stream_options, - DynamicStreamSettings(), - "test_stream", + stream = PreviewStream( + create_stream( + hass, + stream_source, + stream_options, + DynamicStreamSettings(), + "test_stream", + ) ) hls_provider = stream.add_provider(HLS_PROVIDER) await stream.start() if not await hls_provider.part_recv(timeout=SOURCE_TIMEOUT): hass.async_create_task(stream.stop()) return {CONF_STREAM_SOURCE: "timeout"} - await stream.stop() except StreamWorkerError as err: return {CONF_STREAM_SOURCE: str(err)} except PermissionError: @@ -303,7 +306,7 @@ async def async_test_stream( if "Stream integration is not set up" in str(err): return {CONF_STREAM_SOURCE: "stream_not_set_up"} raise - return {} + return stream def register_preview(hass: HomeAssistant) -> None: @@ -316,38 +319,6 @@ def register_preview(hass: HomeAssistant) -> None: hass.data[DOMAIN][IMAGE_PREVIEWS_ACTIVE] = True -async def register_stream_preview(hass: HomeAssistant, config) -> str: - """Set up preview for camera stream during config flow.""" - hass.data.setdefault("camera", {}) - - # Need to load the camera prefs early to avoid errors generating the stream - # if the user does not already have the stream component loaded. - if hass.data.get(DATA_CAMERA_PREFS) is None: - await init_camera_prefs(hass) - - # Create a camera but don't add it to the hass object. - cam = GenericCamera(hass, config, "stream_preview", "Camera Preview Stream") - cam.entity_id = DOMAIN + ".stream_preview" - cam.platform = EntityPlatform( - hass=hass, - logger=_LOGGER, - domain=DOMAIN, - platform_name="camera", - platform=None, - scan_interval=timedelta(seconds=1), - entity_namespace=None, - ) - - stream = await cam.async_create_stream() - if not stream: - raise HomeAssistantError("Failed to create preview stream") - stream.add_provider(HLS_PROVIDER) - url = stream.endpoint_url(HLS_PROVIDER) - _LOGGER.debug("Registered preview stream URL: %s", url) - - return url - - class GenericIPCamConfigFlow(ConfigFlow, domain=DOMAIN): """Config flow for generic IP camera.""" @@ -387,7 +358,12 @@ async def async_step_user( errors["base"] = "no_still_image_or_stream_url" else: errors, still_format = await async_test_still(hass, user_input) - errors = errors | await async_test_stream(hass, user_input) + result = await async_test_and_preview_stream(hass, user_input) + if isinstance(result, dict): + errors = errors | result + self.context.pop("preview_stream", None) + else: + self.context["preview_stream"] = result if not errors: user_input[CONF_CONTENT_TYPE] = still_format still_url = user_input.get(CONF_STILL_IMAGE_URL) @@ -404,7 +380,7 @@ async def async_step_user( self.title = name # temporary preview for user to check the image self.context["preview_cam"] = user_input - return await self.async_step_user_confirm_still() + return await self.async_step_user_confirm() elif self.user_input: user_input = self.user_input else: @@ -415,11 +391,14 @@ async def async_step_user( errors=errors, ) - async def async_step_user_confirm_still( + async def async_step_user_confirm( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Handle user clicking confirm after still preview.""" if user_input: + if ha_stream := self.context.get("preview_stream"): + # Kill off the temp stream we created. + await ha_stream.stop() if not user_input.get(CONF_CONFIRMED_OK): return await self.async_step_user() return self.async_create_entry( @@ -428,7 +407,7 @@ async def async_step_user_confirm_still( register_preview(self.hass) preview_url = f"/api/generic/preview_flow_image/{self.flow_id}?t={datetime.now().isoformat()}" return self.async_show_form( - step_id="user_confirm_still", + step_id="user_confirm", data_schema=vol.Schema( { vol.Required(CONF_CONFIRMED_OK, default=False): bool, @@ -464,7 +443,10 @@ async def async_step_init( errors, still_format = await async_test_still( hass, self.config_entry.options | user_input ) - errors = errors | await async_test_stream(hass, user_input) + + result = await async_test_and_preview_stream(hass, user_input) + if isinstance(result, dict): + errors = errors | result still_url = user_input.get(CONF_STILL_IMAGE_URL) if not errors: if still_url is None: @@ -553,6 +535,41 @@ async def get(self, request: web.Request, flow_id: str) -> web.Response: return web.Response(body=image.content, content_type=image.content_type) +class PreviewStream: + """A wrapper around the stream object to automatically close unused streams.""" + + def __init__(self, stream: Stream) -> None: + """Initialize the object.""" + self.stream = stream + self._deferred_stop = None + + async def start(self, timeout=600): + """Start the stream with a timeout.""" + + async def _timeout() -> None: + _LOGGER.debug("Starting preview stream with timeout %ss", timeout) + await asyncio.sleep(timeout) + _LOGGER.info("Preview stream stopping due to timeout") + await self.stream.stop() + + await self.stream.start() + self._deferred_stop = self.stream.hass.async_create_task(_timeout()) + + async def stop(self): + """Stop the stream.""" + if not self._deferred_stop.done(): + self._deferred_stop.cancel() + await self.stream.stop() + + def add_provider(self, provider): + """Add a provider to the stream.""" + return self.stream.add_provider(provider) + + def endpoint_url(self, fmt: str) -> str: + """Return the endpoint URL.""" + return self.stream.endpoint_url(fmt) + + @websocket_api.websocket_command( { vol.Required("type"): "generic_camera/start_preview", @@ -568,12 +585,12 @@ async def ws_start_preview( msg: dict[str, Any], ) -> None: """Generate websocket handler for the camera still/stream preview.""" - errors: dict[str, str] = {} _LOGGER.debug("Generating websocket handler for generic camera preview") + ha_still_url = None + ha_stream_url = None flow = hass.config_entries.flow.async_get(msg["flow_id"]) - user_input = deepcopy(flow["context"]["preview_cam"]) - del user_input[CONF_CONTENT_TYPE] # The schema doesn't like this generated field. + user_input = flow["context"]["preview_cam"] # Create an EntityPlatform, needed for name translations platform = await async_prepare_setup_platform(hass, {}, CAMERA_DOMAIN, DOMAIN) @@ -588,34 +605,13 @@ async def ws_start_preview( ) await entity_platform.async_load_translations() - user_input[CONF_LIMIT_REFETCH_TO_URL_CHANGE] = False - - ext_still_url = user_input.get(CONF_STILL_IMAGE_URL) - ext_stream_url = user_input.get(CONF_STREAM_SOURCE) - - if ext_still_url: - errors, still_format = await async_test_still(hass, user_input) - user_input[CONF_CONTENT_TYPE] = still_format - register_preview(hass) - + if user_input.get(CONF_STILL_IMAGE_URL): ha_still_url = f"/api/generic/preview_flow_image/{msg['flow_id']}?t={datetime.now().isoformat()}" - _LOGGER.debug("Preview still URL: %s", ha_still_url) - else: - # If user didn't specify a still image URL, - # The automatically generated still image that stream generates - # is always jpeg - user_input[CONF_CONTENT_TYPE] = "image/jpeg" - ha_still_url = None + _LOGGER.debug("Got preview still URL: %s", ha_still_url) - ha_stream_url = None - if ext_stream_url: - errors = errors | await async_test_stream(hass, user_input) - if not errors: - preview_entity = GenericCamera( - hass, user_input, msg["flow_id"] + "stream_preview", "PreviewStream" - ) - preview_entity.platform = entity_platform - ha_stream_url = await register_stream_preview(hass, user_input) + if ha_stream := flow["context"].get("preview_stream"): + ha_stream_url = ha_stream.endpoint_url(HLS_PROVIDER) + _LOGGER.debug("Got preview stream URL: %s", ha_stream_url) connection.send_result(msg["id"]) connection.send_message( diff --git a/homeassistant/components/generic/strings.json b/homeassistant/components/generic/strings.json index cf2c5cc545c839..16480bc256d754 100644 --- a/homeassistant/components/generic/strings.json +++ b/homeassistant/components/generic/strings.json @@ -38,7 +38,7 @@ "verify_ssl": "[%key:common::config_flow::data::verify_ssl%]" } }, - "user_confirm_still": { + "user_confirm": { "title": "Confirmation", "description": "Please wait for previews to load...", "data": { @@ -67,10 +67,10 @@ } }, "confirm_still": { - "title": "[%key:component::generic::config::step::user_confirm_still::title%]", - "description": "[%key:component::generic::config::step::user_confirm_still::description%]", + "title": "Preview", + "description": "![Camera Still Image Preview]({preview_url})", "data": { - "confirmed_ok": "[%key:component::generic::config::step::user_confirm_still::data::confirmed_ok%]" + "confirmed_ok": "This image looks good." } } }, diff --git a/tests/components/generic/test_config_flow.py b/tests/components/generic/test_config_flow.py index 64e8871c1c8d87..e4c782b9223b56 100644 --- a/tests/components/generic/test_config_flow.py +++ b/tests/components/generic/test_config_flow.py @@ -90,7 +90,7 @@ async def test_form( TESTDATA, ) assert result1["type"] is FlowResultType.FORM - assert result1["step_id"] == "user_confirm_still" + assert result1["step_id"] == "user_confirm" client = await hass_client() preview_id = result1["flow_id"] # Check the preview image works. @@ -144,7 +144,7 @@ async def test_form_only_stillimage( ) await hass.async_block_till_done() assert result1["type"] is FlowResultType.FORM - assert result1["step_id"] == "user_confirm_still" + assert result1["step_id"] == "user_confirm" result2 = await hass.config_entries.flow.async_configure( result1["flow_id"], user_input={CONF_CONFIRMED_OK: True}, @@ -165,13 +165,13 @@ async def test_form_only_stillimage( @respx.mock -async def test_form_reject_still_preview( +async def test_form_reject_preview( hass: HomeAssistant, fakeimgbytes_png: bytes, mock_create_stream: _patch[MagicMock], user_flow: ConfigFlowResult, ) -> None: - """Test we go back to the config screen if the user rejects the still preview.""" + """Test we go back to the config screen if the user rejects the preview.""" respx.get("http://127.0.0.1/testurl/1").respond(stream=fakeimgbytes_png) with mock_create_stream: result1 = await hass.config_entries.flow.async_configure( @@ -179,7 +179,7 @@ async def test_form_reject_still_preview( TESTDATA, ) assert result1["type"] is FlowResultType.FORM - assert result1["step_id"] == "user_confirm_still" + assert result1["step_id"] == "user_confirm" result2 = await hass.config_entries.flow.async_configure( result1["flow_id"], user_input={CONF_CONFIRMED_OK: False}, @@ -209,7 +209,7 @@ async def test_form_still_preview_cam_off( TESTDATA, ) assert result1["type"] is FlowResultType.FORM - assert result1["step_id"] == "user_confirm_still" + assert result1["step_id"] == "user_confirm" preview_id = result1["flow_id"] # Try to view the image, should be unavailable. client = await hass_client() @@ -231,7 +231,7 @@ async def test_form_only_stillimage_gif( data, ) assert result1["type"] is FlowResultType.FORM - assert result1["step_id"] == "user_confirm_still" + assert result1["step_id"] == "user_confirm" result2 = await hass.config_entries.flow.async_configure( result1["flow_id"], user_input={CONF_CONFIRMED_OK: True}, @@ -256,7 +256,7 @@ async def test_form_only_svg_whitespace( data, ) assert result1["type"] is FlowResultType.FORM - assert result1["step_id"] == "user_confirm_still" + assert result1["step_id"] == "user_confirm" result2 = await hass.config_entries.flow.async_configure( result1["flow_id"], user_input={CONF_CONFIRMED_OK: True}, @@ -291,7 +291,7 @@ async def test_form_only_still_sample( data, ) assert result1["type"] is FlowResultType.FORM - assert result1["step_id"] == "user_confirm_still" + assert result1["step_id"] == "user_confirm" result2 = await hass.config_entries.flow.async_configure( result1["flow_id"], user_input={CONF_CONFIRMED_OK: True}, @@ -308,13 +308,13 @@ async def test_form_only_still_sample( ( "http://localhost:812{{3}}/static/icons/favicon-apple-180x180.png", "http://localhost:8123/static/icons/favicon-apple-180x180.png", - "user_confirm_still", + "user_confirm", None, ), ( "{% if 1 %}https://bla{% else %}https://yo{% endif %}", "https://bla/", - "user_confirm_still", + "user_confirm", None, ), ( @@ -383,7 +383,7 @@ async def test_form_rtsp_mode( user_flow["flow_id"], data ) assert result1["type"] is FlowResultType.FORM - assert result1["step_id"] == "user_confirm_still" + assert result1["step_id"] == "user_confirm" result2 = await hass.config_entries.flow.async_configure( result1["flow_id"], user_input={CONF_CONFIRMED_OK: True}, @@ -588,6 +588,8 @@ async def test_form_stream_timeout( "homeassistant.components.generic.config_flow.create_stream" ) as create_stream: create_stream.return_value.start = AsyncMock() + create_stream.return_value.stop = AsyncMock() + create_stream.return_value.hass = hass create_stream.return_value.add_provider.return_value.part_recv = AsyncMock() create_stream.return_value.add_provider.return_value.part_recv.return_value = ( False