Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions homeassistant/components/device_tracker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
https://home-assistant.io/components/device_tracker/
"""
import asyncio
from collections import namedtuple
from datetime import timedelta
import logging
from typing import Any, List, Sequence, Callable
Expand Down Expand Up @@ -159,11 +160,14 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType):

async def async_setup_platform(p_type, p_config, disc_info=None):
"""Set up a device tracker platform."""
manager = hass.data[DOMAIN]
platform = await async_prepare_setup_platform(
hass, config, DOMAIN, p_type)
if platform is None:
return

manager.add_platform(p_type, platform)

_LOGGER.info("Setting up %s.%s", DOMAIN, p_type)
try:
scanner = None
Expand Down Expand Up @@ -194,12 +198,14 @@ async def async_setup_platform(p_type, p_config, disc_info=None):

if not setup:
_LOGGER.error("Error setting up platform %s", p_type)
manager.remove_platform(p_type)
return

except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error setting up platform %s", p_type)
manager.remove_platform(p_type)

hass.data[DOMAIN] = async_setup_platform
hass.data[DOMAIN] = PlatformManager(async_setup_platform)

setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config
in config_per_platform(config, DOMAIN)]
Expand Down Expand Up @@ -236,10 +242,82 @@ async def async_see_service(call):

async def async_setup_entry(hass, entry):
"""Set up an entry."""
await hass.data[DOMAIN](entry.domain, entry)
await hass.data[DOMAIN].async_setup_platform(entry.domain, entry)
return True


async def async_unload_entry(hass, entry):
"""Unload an entry."""
manager = hass.data[DOMAIN]
platform_name = entry.domain
platform = manager.get_platform(platform_name)
if platform and hasattr(platform, 'async_unload_entry'):
await platform.async_unload_entry(hass, entry)
manager.remove_platform(platform_name)
return True
return False


Platform = namedtuple('Platform', 'platform, count')


class PlatformManager:
"""Store data needed to unload entries from loaded platforms."""

def __init__(self, async_setup_platform: callable):
"""Initialize a platform manager."""
self._platforms = {}
self.async_setup_platform = async_setup_platform

def add_platform(self, platform_name: str, platform):
"""
Add a platform that has been loaded to the platform manager.

If an instance of this platform has already been added, increment
the count of it.
"""
_LOGGER.debug(
'Adding instance of %s to PlatformManager',
platform_name
)
if platform_name is None:
return
platform = self._platforms.setdefault(
platform_name,
Platform(platform, 0)
)
self._platforms[platform_name] = platform._replace(
count=platform.count + 1
)

def get_platform(self, platform_name: str):
Comment thread
rohankapoorcom marked this conversation as resolved.
"""Fetch a platform that has been loaded from the platform manager."""
if platform_name in self._platforms:
return self._platforms[platform_name].platform
return None

def remove_platform(self, platform_name: str):
"""
Decrement the count of a platform when an entity is unloaded.

If this is the last instance of the platform, remove it from the
platform manager.
"""
_LOGGER.debug(
'Removing instance of %s from PlatformManager',
platform_name
)
if platform_name not in self._platforms:
return
platform = self._platforms.get(platform_name)
if platform.count > 1:
self._platforms[platform_name] = platform._replace(
count=platform.count - 1
)
return
self._platforms.pop(platform_name)


class DeviceTracker:
"""Representation of a device tracker."""

Expand Down
24 changes: 18 additions & 6 deletions homeassistant/components/device_tracker/gpslogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
"""
import logging

from homeassistant.components.gpslogger import TRACKER_UPDATE
from homeassistant.components.device_tracker import DOMAIN as \
DEVICE_TRACKER_DOMAIN
from homeassistant.components.gpslogger import DOMAIN as GPSLOGGER_DOMAIN, \
TRACKER_UPDATE
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.typing import HomeAssistantType, ConfigType
from homeassistant.helpers.typing import HomeAssistantType

_LOGGER = logging.getLogger(__name__)

DEPENDENCIES = ['gpslogger']

DATA_KEY = '{}.{}'.format(GPSLOGGER_DOMAIN, DEVICE_TRACKER_DOMAIN)

async def async_setup_scanner(hass: HomeAssistantType, config: ConfigType,
async_see, discovery_info=None):
"""Set up an endpoint for the GPSLogger device tracker."""

async def async_setup_entry(hass: HomeAssistantType, entry, async_see):
"""Configure a dispatcher connection based on a config entry."""
async def _set_location(device, gps_location, battery, accuracy, attrs):
"""Fire HA event to set location."""
await async_see(
Expand All @@ -28,5 +32,13 @@ async def _set_location(device, gps_location, battery, accuracy, attrs):
attributes=attrs
)

async_dispatcher_connect(hass, TRACKER_UPDATE, _set_location)
hass.data[DATA_KEY] = async_dispatcher_connect(
hass, TRACKER_UPDATE, _set_location
)
return True


async def async_unload_entry(hass: HomeAssistantType, entry):
"""Unload the config entry and remove the dispatcher connection."""
hass.data[DATA_KEY]()
return True
13 changes: 9 additions & 4 deletions homeassistant/components/gpslogger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from homeassistant.const import HTTP_UNPROCESSABLE_ENTITY, \
HTTP_OK, ATTR_LATITUDE, ATTR_LONGITUDE, CONF_WEBHOOK_ID
from homeassistant.helpers import config_entry_flow
from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.components.device_tracker import DOMAIN as DEVICE_TRACKER

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,9 +57,6 @@ def _id(value: str) -> str:

async def async_setup(hass, hass_config):
"""Set up the GPSLogger component."""
hass.async_create_task(
async_load_platform(hass, 'device_tracker', DOMAIN, {}, hass_config)
)
return True


Expand Down Expand Up @@ -103,12 +100,20 @@ async def async_setup_entry(hass, entry):
"""Configure based on config entry."""
hass.components.webhook.async_register(
DOMAIN, 'GPSLogger', entry.data[CONF_WEBHOOK_ID], handle_webhook)

hass.async_create_task(
hass.config_entries.async_forward_entry_setup(entry, DEVICE_TRACKER)
)
return True


async def async_unload_entry(hass, entry):
"""Unload a config entry."""
hass.components.webhook.async_unregister(entry.data[CONF_WEBHOOK_ID])

hass.async_create_task(
hass.config_entries.async_forward_entry_unload(entry, DEVICE_TRACKER)
)
return True

config_entry_flow.register_webhook_flow(
Expand Down
113 changes: 94 additions & 19 deletions tests/components/device_tracker/test_init.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,30 @@
"""The tests for the device tracker component."""
# pylint: disable=protected-access
import asyncio
import json
import logging
from unittest.mock import call
from datetime import datetime, timedelta
import os
from asynctest import patch
from datetime import datetime, timedelta
from unittest.mock import call

import pytest
from asynctest import patch

from homeassistant.components import zone
from homeassistant.core import callback, State
from homeassistant.setup import async_setup_component
from homeassistant.helpers import discovery
from homeassistant.loader import get_component
import homeassistant.components.device_tracker as device_tracker
import homeassistant.util.dt as dt_util
from homeassistant.components import zone
from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_PICTURE, ATTR_FRIENDLY_NAME, ATTR_HIDDEN,
STATE_HOME, STATE_NOT_HOME, CONF_PLATFORM, ATTR_ICON)
import homeassistant.components.device_tracker as device_tracker
from tests.components.device_tracker import common
from homeassistant.core import callback, State
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import discovery
from homeassistant.helpers.json import JSONEncoder

from homeassistant.loader import get_component
from homeassistant.setup import async_setup_component
from tests.common import (
async_fire_time_changed, patch_yaml_files, assert_setup_component,
mock_restore_cache)
mock_restore_cache, MockPlatform)
from tests.components.device_tracker import common

TEST_PLATFORM = {device_tracker.DOMAIN: {CONF_PLATFORM: 'test'}}

Expand Down Expand Up @@ -513,8 +512,7 @@ async def test_see_failures(mock_warning, hass, yaml_devices):
assert len(config) == 4


@asyncio.coroutine
def test_async_added_to_hass(hass):
async def test_async_added_to_hass(hass):
"""Test restoring state."""
attr = {
device_tracker.ATTR_LONGITUDE: 18,
Expand All @@ -532,7 +530,7 @@ def test_async_added_to_hass(hass):
path: 'jk:\n name: JK Phone\n track: True',
}
with patch_yaml_files(files):
yield from device_tracker.async_setup(hass, {})
await device_tracker.async_setup(hass, {})

state = hass.states.get('device_tracker.jk')
assert state
Expand All @@ -543,16 +541,15 @@ def test_async_added_to_hass(hass):
assert atr == val, "{}={} expected: {}".format(key, atr, val)


@asyncio.coroutine
def test_bad_platform(hass):
async def test_bad_platform(hass):
"""Test bad platform."""
config = {
'device_tracker': [{
'platform': 'bad_platform'
}]
}
with assert_setup_component(0, device_tracker.DOMAIN):
assert (yield from device_tracker.async_setup(hass, config))
assert await device_tracker.async_setup(hass, config)


async def test_adding_unknown_device_to_config(mock_device_tracker_conf, hass):
Expand Down Expand Up @@ -631,3 +628,81 @@ def test_see_schema_allowing_ios_calls():
"gps_accuracy": 300,
"hostname": 'beer',
})


@pytest.fixture
def platform_manager():
"""Create a new empty Platform Manager."""
return device_tracker.PlatformManager(None)


def test_platform_manager_add_platform_multiple(platform_manager):
"""Test that a platform can be added to the Platform Manager repeatedly."""
platform = MockPlatform()

platform_tuple = platform_manager._platforms.get('mock_platform')
assert None is platform_tuple

for i in range(1, 100):
platform_manager.add_platform('mock_platform', platform)
platform_tuple = platform_manager._platforms['mock_platform']
assert platform == platform_tuple.platform
assert i == platform_tuple.count


def test_platform_manager_get_nonexistant(platform_manager):
"""Test that a platform cannot be retrieved if never added."""
platform = platform_manager.get_platform('mock_platform')
assert None is platform


def test_platform_manager_get_after_add(platform_manager):
"""Test that a platform is retrievable after adding."""
platform = MockPlatform()
platform_manager.add_platform('mock_platform', platform)

platform = platform_manager.get_platform('mock_platform')
assert platform == platform


def test_platform_manager_remove_empty(platform_manager):
"""Test that removing before adding does nothing."""
platform_tuple = platform_manager._platforms.get('mock_platform')
assert None is platform_tuple

platform_manager.remove_platform('mock_platform')

platform_tuple = platform_manager._platforms.get('mock_platform')
assert None is platform_tuple


def test_platform_manager_remove_after_add_single(platform_manager):
"""Test that removing after adding leaves nothing behind."""
platform = MockPlatform()
platform_manager.add_platform('mock_platform', platform)

platform_manager.remove_platform('mock_platform')

platform_tuple = platform_manager._platforms.get('mock_platform')
assert None is platform_tuple


def test_platform_manager_remove_after_add_multiple(platform_manager):
"""Test that removing each instance decrements the count until the end."""
platform = MockPlatform()
starting_count = 100

platform_manager._platforms['mock_platform'] = device_tracker.Platform(
platform,
starting_count
)

for i in range(1, starting_count):
platform_manager.remove_platform('mock_platform')
platform_tuple = platform_manager._platforms['mock_platform']
assert platform == platform_tuple.platform
assert starting_count - i == platform_tuple.count

platform_manager.remove_platform('mock_platform')
platform_tuple = platform_manager._platforms.get('mock_platform')
assert None is platform_tuple
Loading