Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ async def async_step_discovery(info):
import os
import uuid

from . import data_entry_flow
from .core import callback
from .exceptions import HomeAssistantError
from .data_entry_flow import FlowManager
from .setup import async_setup_component, async_process_deps_reqs
from .util.json import load_json, save_json
from .util.decorator import Registry
Expand Down Expand Up @@ -255,8 +255,8 @@ class ConfigEntries:
def __init__(self, hass, hass_config):
"""Initialize the entry manager."""
self.hass = hass
self.flow = FlowManager(hass, HANDLERS, self._async_missing_handler,
self._async_save_entry)
self.flow = data_entry_flow.FlowManager(
hass, self._async_create_flow, self._async_save_entry)
self._hass_config = hass_config
self._entries = None
self._sched_save = None
Expand Down Expand Up @@ -345,7 +345,7 @@ async def _async_save_entry(self, result):
"""Add an entry."""
entry = ConfigEntry(
version=result['version'],
domain=result['domain'],
domain=result['handler'],
title=result['title'],
data=result['data'],
source=result['source'],
Expand All @@ -362,17 +362,22 @@ async def _async_save_entry(self, result):
await async_setup_component(
self.hass, entry.domain, self._hass_config)

async def _async_missing_handler(self, domain):
"""Called when a flow handler is not loaded."""
# This will load the component and thus register the handler
component = getattr(self.hass.components, domain)
async def _async_create_flow(self, handler):
"""Create a flow for specified handler.

if domain not in HANDLERS:
return
Handler key is the domain of the component that we want to setup.
"""
component = getattr(self.hass.components, handler)
handler = HANDLERS.get(handler)

if handler is None:
raise data_entry_flow.UnknownHandler

# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(
self.hass, self._hass_config, domain, component)
self.hass, self._hass_config, handler, component)

return handler()

@callback
def _async_schedule_save(self):
Expand Down
35 changes: 12 additions & 23 deletions homeassistant/data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,41 +34,30 @@ class UnknownStep(FlowError):
class FlowManager:
"""Manage all the flows that are in progress."""

def __init__(self, hass, handlers, async_missing_handler,
async_save_entry):
def __init__(self, hass, async_create_flow, async_save_entry):
"""Initialize the flow manager."""
self.hass = hass
self._handlers = handlers
self._progress = {}
self._async_missing_handler = async_missing_handler
self._async_create_flow = async_create_flow
self._async_save_entry = async_save_entry

@callback
def async_progress(self):
"""Return the flows in progress."""
return [{
'flow_id': flow.flow_id,
'domain': flow.domain,
'handler': flow.handler,
'source': flow.source,
} for flow in self._progress.values()]

async def async_init(self, domain, *, source=SOURCE_USER, data=None):
async def async_init(self, handler, *, source=SOURCE_USER, data=None):
"""Start a configuration flow."""
handler = self._handlers.get(domain)

if handler is None:
await self._async_missing_handler(domain)
handler = self._handlers.get(domain)

if handler is None:
raise UnknownHandler

flow_id = uuid.uuid4().hex
flow = self._progress[flow_id] = handler()
flow = await self._async_create_flow(handler)
flow.hass = self.hass
flow.domain = domain
flow.flow_id = flow_id
flow.handler = handler
flow.flow_id = uuid.uuid4().hex
flow.source = source
self._progress[flow.flow_id] = flow

if source == SOURCE_USER:
step = 'init'
Expand Down Expand Up @@ -137,7 +126,7 @@ class FlowHandler:
# Set by flow manager
flow_id = None
hass = None
domain = None
handler = None
source = SOURCE_USER
cur_step = None

Expand All @@ -150,7 +139,7 @@ def async_show_form(self, *, step_id, data_schema=None, errors=None):
return {
'type': RESULT_TYPE_FORM,
'flow_id': self.flow_id,
'domain': self.domain,
'handler': self.handler,
'step_id': step_id,
'data_schema': data_schema,
'errors': errors,
Expand All @@ -163,7 +152,7 @@ def async_create_entry(self, *, title, data):
'version': self.VERSION,
'type': RESULT_TYPE_CREATE_ENTRY,
'flow_id': self.flow_id,
'domain': self.domain,
'handler': self.handler,
'title': title,
'data': data,
'source': self.source,
Expand All @@ -175,6 +164,6 @@ def async_abort(self, *, reason):
return {
'type': RESULT_TYPE_ABORT,
'flow_id': self.flow_id,
'domain': self.domain,
'handler': self.handler,
'reason': reason
}
12 changes: 6 additions & 6 deletions tests/components/config/test_config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def async_step_init(self, user_input=None):

assert data == {
'type': 'form',
'domain': 'test',
'handler': 'test',
'step_id': 'init',
'data_schema': [
{
Expand Down Expand Up @@ -156,7 +156,7 @@ def async_step_init(self, user_input=None):
data = yield from resp.json()
data.pop('flow_id')
assert data == {
'domain': 'test',
'handler': 'test',
'reason': 'bla',
'type': 'abort'
}
Expand Down Expand Up @@ -186,7 +186,7 @@ def async_step_init(self, user_input=None):
data = yield from resp.json()
data.pop('flow_id')
assert data == {
'domain': 'test',
'handler': 'test',
'title': 'Test Entry',
'type': 'create_entry',
'source': 'user',
Expand Down Expand Up @@ -226,7 +226,7 @@ def async_step_account(self, user_input=None):
flow_id = data.pop('flow_id')
assert data == {
'type': 'form',
'domain': 'test',
'handler': 'test',
'step_id': 'account',
'data_schema': [
{
Expand All @@ -245,7 +245,7 @@ def async_step_account(self, user_input=None):
data = yield from resp.json()
data.pop('flow_id')
assert data == {
'domain': 'test',
'handler': 'test',
'type': 'create_entry',
'title': 'user-title',
'version': 1,
Expand Down Expand Up @@ -279,7 +279,7 @@ def async_step_account(self, user_input=None):
assert data == [
{
'flow_id': form['flow_id'],
'domain': 'test',
'handler': 'test',
'source': 'hassio'
}
]
Expand Down
2 changes: 1 addition & 1 deletion tests/components/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def discover(netdisco):

with patch.dict(discovery.CONFIG_ENTRY_HANDLERS, {
'mock-service': 'mock-component'}), patch(
'homeassistant.config_entries.FlowManager.async_init') as m_init:
'homeassistant.data_entry_flow.FlowManager.async_init') as m_init:
await mock_discovery(hass, discover)

assert len(m_init.mock_calls) == 1
Expand Down
18 changes: 12 additions & 6 deletions tests/test_data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,26 @@
from homeassistant import data_entry_flow
from homeassistant.util.decorator import Registry

from tests.common import mock_coro


@pytest.fixture
def manager():
"""Return a flow manager."""
handlers = Registry()
entries = []

async def async_create_flow(handler_name):
handler = handlers.get(handler_name)

if handler is None:
raise data_entry_flow.UnknownHandler

return handler()

async def async_add_entry(result):
entries.append(result)

manager = data_entry_flow.FlowManager(
None, handlers, mock_coro, async_add_entry)
None, async_create_flow, async_add_entry)
manager.mock_created_entries = entries
manager.mock_reg_handler = handlers.register
return manager
Expand Down Expand Up @@ -84,7 +90,7 @@ async def async_step_second(self, user_input=None):
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1
result = manager.mock_created_entries[0]
assert result['domain'] == 'test'
assert result['handler'] == 'test'
assert result['data'] == ['INIT-DATA', 'SECOND-DATA']


Expand Down Expand Up @@ -153,7 +159,7 @@ async def async_step_init(self, user_input=None):

entry = manager.mock_created_entries[0]
assert entry['version'] == 5
assert entry['domain'] == 'test'
assert entry['handler'] == 'test'
assert entry['title'] == 'Test Title'
assert entry['data'] == 'Test Data'
assert entry['source'] == data_entry_flow.SOURCE_USER
Expand All @@ -180,7 +186,7 @@ async def async_step_discovery(self, info):

entry = manager.mock_created_entries[0]
assert entry['version'] == 5
assert entry['domain'] == 'test'
assert entry['handler'] == 'test'
assert entry['title'] == 'hello'
assert entry['data'] == data
assert entry['source'] == data_entry_flow.SOURCE_DISCOVERY