diff --git a/homeassistant/components/wiz/__init__.py b/homeassistant/components/wiz/__init__.py index eae5657589a94..52b6996318587 100644 --- a/homeassistant/components/wiz/__init__.py +++ b/homeassistant/components/wiz/__init__.py @@ -4,7 +4,6 @@ from typing import Any from pywizlight import PilotParser, wizlight -from pywizlight.bulb import PIR_SOURCE from homeassistant.const import CONF_HOST, EVENT_HOMEASSISTANT_STOP, Platform from homeassistant.core import Event, HomeAssistant, callback @@ -18,6 +17,7 @@ DISCOVER_SCAN_TIMEOUT, DISCOVERY_INTERVAL, DOMAIN, + OCCUPANCY_SOURCES, SIGNAL_WIZ_PIR, WIZ_CONNECT_EXCEPTIONS, ) @@ -99,7 +99,7 @@ def _async_push_update(state: PilotParser) -> None: """Receive a push update.""" _LOGGER.debug("%s: Got push update: %s", bulb.mac, state.pilotResult) coordinator.async_set_updated_data(coordinator.data) - if state.get_source() == PIR_SOURCE: + if state.get_source() in OCCUPANCY_SOURCES: async_dispatcher_send(hass, SIGNAL_WIZ_PIR.format(bulb.mac)) await bulb.start_push(_async_push_update) diff --git a/homeassistant/components/wiz/binary_sensor.py b/homeassistant/components/wiz/binary_sensor.py index 8f1c5ff53a255..bcbf6419696e5 100644 --- a/homeassistant/components/wiz/binary_sensor.py +++ b/homeassistant/components/wiz/binary_sensor.py @@ -2,8 +2,6 @@ from collections.abc import Callable -from pywizlight.bulb import PIR_SOURCE - from homeassistant.components.binary_sensor import ( BinarySensorDeviceClass, BinarySensorEntity, @@ -14,7 +12,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback -from .const import DOMAIN, SIGNAL_WIZ_PIR +from .const import DOMAIN, OCCUPANCY_SOURCES, SIGNAL_WIZ_PIR from .coordinator import WizConfigEntry, WizData from .entity import WizEntity @@ -73,5 +71,5 @@ def __init__(self, wiz_data: WizData, name: str) -> None: @callback def _async_update_attrs(self) -> None: """Handle updating _attr values.""" - if self._device.state.get_source() == PIR_SOURCE: + if self._device.state.get_source() in OCCUPANCY_SOURCES: self._attr_is_on = self._device.status diff --git a/homeassistant/components/wiz/const.py b/homeassistant/components/wiz/const.py index 78074a3d5fbc5..59cc7788a74fc 100644 --- a/homeassistant/components/wiz/const.py +++ b/homeassistant/components/wiz/const.py @@ -2,6 +2,7 @@ from datetime import timedelta +from pywizlight.bulb import PIR_SOURCE from pywizlight.exceptions import ( WizLightConnectionError, WizLightNotKnownBulb, @@ -24,3 +25,4 @@ WIZ_CONNECT_EXCEPTIONS = (WizLightNotKnownBulb, *WIZ_EXCEPTIONS) SIGNAL_WIZ_PIR = "wiz_pir_{}" +OCCUPANCY_SOURCES = frozenset({PIR_SOURCE, "wfsens"}) diff --git a/tests/components/wiz/test_binary_sensor.py b/tests/components/wiz/test_binary_sensor.py index c7e5541d91ec7..85705dadfd6b4 100644 --- a/tests/components/wiz/test_binary_sensor.py +++ b/tests/components/wiz/test_binary_sensor.py @@ -1,5 +1,7 @@ """Tests for WiZ binary_sensor platform.""" +import pytest + from homeassistant.components import wiz from homeassistant.components.wiz.binary_sensor import OCCUPANCY_UNIQUE_ID from homeassistant.config_entries import ConfigEntryState @@ -21,20 +23,27 @@ from tests.common import MockConfigEntry +@pytest.mark.parametrize("occupancy_source", ["pir", "wfsens"]) async def test_binary_sensor_created_from_push_updates( - hass: HomeAssistant, entity_registry: er.EntityRegistry + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + occupancy_source: str, ) -> None: """Test a binary sensor created from push updates.""" bulb, _ = await async_setup_integration(hass) - await async_push_update(hass, bulb, {"mac": FAKE_MAC, "src": "pir", "state": True}) + await async_push_update( + hass, bulb, {"mac": FAKE_MAC, "src": occupancy_source, "state": True} + ) entity_id = "binary_sensor.mock_title_occupancy" assert entity_registry.async_get(entity_id).unique_id == f"{FAKE_MAC}_occupancy" state = hass.states.get(entity_id) assert state.state == STATE_ON - await async_push_update(hass, bulb, {"mac": FAKE_MAC, "src": "pir", "state": False}) + await async_push_update( + hass, bulb, {"mac": FAKE_MAC, "src": occupancy_source, "state": False} + ) state = hass.states.get(entity_id) assert state.state == STATE_OFF