Skip to content

Commit

Permalink
Set up MQTT websocket_api and dump, publish actions from async_setup (
Browse files Browse the repository at this point in the history
#131170)

* Set up MQTT websocket_api and dump, publish actions from `async_setup`

* Follow up comments
  • Loading branch information
jbouwh authored Nov 21, 2024
1 parent 3d499ab commit 3474642
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 73 deletions.
149 changes: 80 additions & 69 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,77 +225,27 @@ async def async_check_config_schema(
) from exc


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry."""
conf: dict[str, Any]
mqtt_data: MqttData
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the actions and websocket API for the MQTT component."""

async def _setup_client(
client_available: asyncio.Future[bool],
) -> tuple[MqttData, dict[str, Any]]:
"""Set up the MQTT client."""
# Fetch configuration
conf = dict(entry.data)
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_yaml = CONFIG_SCHEMA(hass_config).get(DOMAIN, [])
await async_create_certificate_temp_files(hass, conf)
client = MQTT(hass, entry, conf)
if DOMAIN in hass.data:
mqtt_data = hass.data[DATA_MQTT]
mqtt_data.config = mqtt_yaml
mqtt_data.client = client
else:
# Initial setup
websocket_api.async_register_command(hass, websocket_subscribe)
websocket_api.async_register_command(hass, websocket_mqtt_info)
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
await client.async_start(mqtt_data)

# Restore saved subscriptions
if mqtt_data.subscriptions_to_restore:
mqtt_data.client.async_restore_tracked_subscriptions(
mqtt_data.subscriptions_to_restore
)
mqtt_data.subscriptions_to_restore = set()
mqtt_data.reload_dispatchers.append(
entry.add_update_listener(_async_config_entry_updated)
)

return (mqtt_data, conf)

client_available: asyncio.Future[bool]
if DATA_MQTT_AVAILABLE not in hass.data:
client_available = hass.data[DATA_MQTT_AVAILABLE] = hass.loop.create_future()
else:
client_available = hass.data[DATA_MQTT_AVAILABLE]

mqtt_data, conf = await _setup_client(client_available)
platforms_used = platforms_from_config(mqtt_data.config)
platforms_used.update(
entry.domain
for entry in er.async_entries_for_config_entry(
er.async_get(hass), entry.entry_id
)
)
integration = async_get_loaded_integration(hass, DOMAIN)
# Preload platforms we know we are going to use so
# discovery can setup each platform synchronously
# and avoid creating a flood of tasks at startup
# while waiting for the the imports to complete
if not integration.platforms_are_loaded(platforms_used):
with async_pause_setup(hass, SetupPhases.WAIT_IMPORT_PLATFORMS):
await integration.async_get_platforms(platforms_used)

# Wait to connect until the platforms are loaded so
# we can be sure discovery does not have to wait for
# each platform to load when we get the flood of retained
# messages on connect
await mqtt_data.client.async_connect(client_available)
websocket_api.async_register_command(hass, websocket_subscribe)
websocket_api.async_register_command(hass, websocket_mqtt_info)

async def async_publish_service(call: ServiceCall) -> None:
"""Handle MQTT publish service calls."""
msg_topic: str | None = call.data.get(ATTR_TOPIC)
msg_topic_template: str | None = call.data.get(ATTR_TOPIC_TEMPLATE)

if not mqtt_config_entry_enabled(hass):
raise ServiceValidationError(
translation_key="mqtt_not_setup_cannot_publish",
translation_domain=DOMAIN,
translation_placeholders={
"topic": str(msg_topic or msg_topic_template)
},
)

mqtt_data = hass.data[DATA_MQTT]
payload: PublishPayloadType = call.data.get(ATTR_PAYLOAD)
evaluate_payload: bool = call.data.get(ATTR_EVALUATE_PAYLOAD, False)
payload_template: str | None = call.data.get(ATTR_PAYLOAD_TEMPLATE)
Expand Down Expand Up @@ -402,6 +352,71 @@ async def finish_dump(_: datetime) -> None:
}
),
)
return True


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry."""
conf: dict[str, Any]
mqtt_data: MqttData

async def _setup_client() -> tuple[MqttData, dict[str, Any]]:
"""Set up the MQTT client."""
# Fetch configuration
conf = dict(entry.data)
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_yaml = CONFIG_SCHEMA(hass_config).get(DOMAIN, [])
await async_create_certificate_temp_files(hass, conf)
client = MQTT(hass, entry, conf)
if DOMAIN in hass.data:
mqtt_data = hass.data[DATA_MQTT]
mqtt_data.config = mqtt_yaml
mqtt_data.client = client
else:
# Initial setup
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
await client.async_start(mqtt_data)

# Restore saved subscriptions
if mqtt_data.subscriptions_to_restore:
mqtt_data.client.async_restore_tracked_subscriptions(
mqtt_data.subscriptions_to_restore
)
mqtt_data.subscriptions_to_restore = set()
mqtt_data.reload_dispatchers.append(
entry.add_update_listener(_async_config_entry_updated)
)

return (mqtt_data, conf)

client_available: asyncio.Future[bool]
if DATA_MQTT_AVAILABLE not in hass.data:
client_available = hass.data[DATA_MQTT_AVAILABLE] = hass.loop.create_future()
else:
client_available = hass.data[DATA_MQTT_AVAILABLE]

mqtt_data, conf = await _setup_client()
platforms_used = platforms_from_config(mqtt_data.config)
platforms_used.update(
entry.domain
for entry in er.async_entries_for_config_entry(
er.async_get(hass), entry.entry_id
)
)
integration = async_get_loaded_integration(hass, DOMAIN)
# Preload platforms we know we are going to use so
# discovery can setup each platform synchronously
# and avoid creating a flood of tasks at startup
# while waiting for the the imports to complete
if not integration.platforms_are_loaded(platforms_used):
with async_pause_setup(hass, SetupPhases.WAIT_IMPORT_PLATFORMS):
await integration.async_get_platforms(platforms_used)

# Wait to connect until the platforms are loaded so
# we can be sure discovery does not have to wait for
# each platform to load when we get the flood of retained
# messages on connect
await mqtt_data.client.async_connect(client_available)

# setup platforms and discovery
async def _reload_config(call: ServiceCall) -> None:
Expand Down Expand Up @@ -557,10 +572,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
mqtt_data = hass.data[DATA_MQTT]
mqtt_client = mqtt_data.client

# Unload publish and dump services.
hass.services.async_remove(DOMAIN, SERVICE_PUBLISH)
hass.services.async_remove(DOMAIN, SERVICE_DUMP)

# Stop the discovery
await discovery.async_stop(hass)
# Unload the platforms
Expand Down
37 changes: 33 additions & 4 deletions tests/components/mqtt/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,26 @@ async def test_service_call_without_topic_does_not_publish(
assert not mqtt_mock.async_publish.called


async def test_service_call_mqtt_entry_does_not_publish(
hass: HomeAssistant, mqtt_client_mock: MqttMockPahoClient
) -> None:
"""Test the service call if topic is missing."""
assert await async_setup_component(hass, mqtt.DOMAIN, {})
with pytest.raises(
ServiceValidationError,
match='Cannot publish to topic "test_topic", make sure MQTT is set up correctly',
):
await hass.services.async_call(
mqtt.DOMAIN,
mqtt.SERVICE_PUBLISH,
{
mqtt.ATTR_TOPIC: "test_topic",
mqtt.ATTR_PAYLOAD: "payload",
},
blocking=True,
)


# The use of a topic_template in an mqtt publish action call
# has been deprecated with HA Core 2024.8.0 and will be removed with HA Core 2025.2.0
async def test_mqtt_publish_action_call_with_topic_and_topic_template_does_not_publish(
Expand Down Expand Up @@ -1822,11 +1842,17 @@ async def async_mqtt_connected_async(status: bool) -> None:

async def test_unload_config_entry(
hass: HomeAssistant,
setup_with_birth_msg_client_mock: MqttMockPahoClient,
mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test unloading the MQTT entry."""
mqtt_client_mock = setup_with_birth_msg_client_mock
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={mqtt.CONF_BROKER: "test-broker"},
)
entry.add_to_hass(hass)

assert await async_setup_component(hass, mqtt.DOMAIN, {})
assert hass.services.has_service(mqtt.DOMAIN, "dump")
assert hass.services.has_service(mqtt.DOMAIN, "publish")

Expand All @@ -1843,15 +1869,18 @@ async def test_unload_config_entry(
mqtt_client_mock.publish.assert_any_call("just_in_time", "published", 0, False)
assert new_mqtt_config_entry.state is ConfigEntryState.NOT_LOADED
await hass.async_block_till_done(wait_background_tasks=True)
assert not hass.services.has_service(mqtt.DOMAIN, "dump")
assert not hass.services.has_service(mqtt.DOMAIN, "publish")
assert hass.services.has_service(mqtt.DOMAIN, "dump")
assert hass.services.has_service(mqtt.DOMAIN, "publish")
assert "No ACK from MQTT server" not in caplog.text


async def test_publish_or_subscribe_without_valid_config_entry(
hass: HomeAssistant, record_calls: MessageCallbackType
) -> None:
"""Test internal publish function with bad use cases."""
assert await async_setup_component(hass, mqtt.DOMAIN, {})
assert hass.services.has_service(mqtt.DOMAIN, "dump")
assert hass.services.has_service(mqtt.DOMAIN, "publish")
with pytest.raises(HomeAssistantError):
await mqtt.async_publish(
hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None
Expand Down

0 comments on commit 3474642

Please sign in to comment.