Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: change mips reconnect logic & add mips test case #641

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 65 additions & 31 deletions custom_components/xiaomi_home/miot/miot_mips.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,9 @@ class _MipsClient(ABC):
_ca_file: Optional[str]
_cert_file: Optional[str]
_key_file: Optional[str]
_tls_done: bool

_mqtt_logger: Optional[logging.Logger]
_mqtt: Client
_mqtt: Optional[Client]
_mqtt_fd: int
_mqtt_timer: Optional[asyncio.TimerHandle]
_mqtt_state: bool
Expand Down Expand Up @@ -272,16 +271,12 @@ def __init__(
self._ca_file = ca_file
self._cert_file = cert_file
self._key_file = key_file
self._tls_done = False

self._mqtt_logger = None
self._mqtt_fd = -1
self._mqtt_timer = None
self._mqtt_state = False
# mqtt init for API_VERSION2,
# callback_api_version=CallbackAPIVersion.VERSION2,
self._mqtt = Client(client_id=self._client_id, protocol=MQTTv5)
self._mqtt.enable_logger(logger=self._mqtt_logger)
self._mqtt = None

# Mips init
self._event_connect = asyncio.Event()
Expand Down Expand Up @@ -316,7 +311,9 @@ def mips_state(self) -> bool:
Returns:
bool: True: connected, False: disconnected
"""
return self._mqtt and self._mqtt.is_connected()
if self._mqtt:
return self._mqtt.is_connected()
return False

def connect(self, thread_name: Optional[str] = None) -> None:
"""mips connect."""
Expand Down Expand Up @@ -359,7 +356,22 @@ def deinit(self) -> None:
self._ca_file = None
self._cert_file = None
self._key_file = None
self._tls_done = False
self._mqtt_logger = None
with self._mips_state_sub_map_lock:
self._mips_state_sub_map.clear()
self._mips_sub_pending_map.clear()
self._mips_sub_pending_timer = None

@final
async def deinit_async(self) -> None:
await self.disconnect_async()

self._logger = None
self._username = None
self._password = None
self._ca_file = None
self._cert_file = None
self._key_file = None
self._mqtt_logger = None
with self._mips_state_sub_map_lock:
self._mips_state_sub_map.clear()
Expand All @@ -368,8 +380,9 @@ def deinit(self) -> None:

def update_mqtt_password(self, password: str) -> None:
self._password = password
self._mqtt.username_pw_set(
username=self._username, password=self._password)
if self._mqtt:
self._mqtt.username_pw_set(
username=self._username, password=self._password)

def log_debug(self, msg, *args, **kwargs) -> None:
if self._logger:
Expand All @@ -389,10 +402,12 @@ def enable_logger(self, logger: Optional[logging.Logger] = None) -> None:
def enable_mqtt_logger(
self, logger: Optional[logging.Logger] = None
) -> None:
if logger:
self._mqtt.enable_logger(logger=logger)
else:
self._mqtt.disable_logger()
self._mqtt_logger = logger
if self._mqtt:
if logger:
self._mqtt.enable_logger(logger=logger)
else:
self._mqtt.disable_logger()

@final
def sub_mips_state(
Expand Down Expand Up @@ -587,25 +602,27 @@ def __mqtt_loop_handler(self) -> None:

def __mips_loop_thread(self) -> None:
self.log_info('mips_loop_thread start')
# mqtt init for API_VERSION2,
# callback_api_version=CallbackAPIVersion.VERSION2,
self._mqtt = Client(client_id=self._client_id, protocol=MQTTv5)
self._mqtt.enable_logger(logger=self._mqtt_logger)
# Set mqtt config
if self._username:
self._mqtt.username_pw_set(
username=self._username, password=self._password)
if not self._tls_done:
if (
self._ca_file
and self._cert_file
and self._key_file
):
self._mqtt.tls_set(
tls_version=ssl.PROTOCOL_TLS_CLIENT,
ca_certs=self._ca_file,
certfile=self._cert_file,
keyfile=self._key_file)
else:
self._mqtt.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT)
self._mqtt.tls_insecure_set(True)
self._tls_done = True
if (
self._ca_file
and self._cert_file
and self._key_file
):
self._mqtt.tls_set(
tls_version=ssl.PROTOCOL_TLS_CLIENT,
ca_certs=self._ca_file,
certfile=self._cert_file,
keyfile=self._key_file)
else:
self._mqtt.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT)
self._mqtt.tls_insecure_set(True)
self._mqtt.on_connect = self.__on_connect
self._mqtt.on_connect_fail = self.__on_connect_failed
self._mqtt.on_disconnect = self.__on_disconnect
Expand All @@ -617,6 +634,9 @@ def __mips_loop_thread(self) -> None:
self.log_info('mips_loop_thread exit!')

def __on_connect(self, client, user_data, flags, rc, props) -> None:
if not self._mqtt:
_LOGGER.error('__on_connect, but mqtt is None')
return
if not self._mqtt.is_connected():
return
self.log_info(f'mips connect, {flags}, {rc}, {props}')
Expand Down Expand Up @@ -685,6 +705,10 @@ def __on_message(
self._on_mips_message(topic=msg.topic, payload=msg.payload)

def __mips_sub_internal_pending_handler(self, ctx: Any) -> None:
if not self._mqtt or not self._mqtt.is_connected():
_LOGGER.error(
'mips sub internal pending, but mqtt is None or disconnected')
return
subbed_count = 1
for topic in list(self._mips_sub_pending_map.keys()):
if subbed_count > self.MIPS_SUB_PATCH:
Expand Down Expand Up @@ -712,6 +736,9 @@ def __mips_sub_internal_pending_handler(self, ctx: Any) -> None:
self._mips_sub_pending_timer = None

def __mips_connect(self) -> None:
if not self._mqtt:
_LOGGER.error('__mips_connect, but mqtt is None')
return
result = MQTT_ERR_UNKNOWN
if self._mips_reconnect_timer:
self._mips_reconnect_timer.cancel()
Expand Down Expand Up @@ -782,7 +809,14 @@ def __mips_disconnect(self) -> None:
self._internal_loop.remove_reader(self._mqtt_fd)
self._internal_loop.remove_writer(self._mqtt_fd)
self._mqtt_fd = -1
self._mqtt.disconnect()
# Clear retry sub
if self._mips_sub_pending_timer:
self._mips_sub_pending_timer.cancel()
self._mips_sub_pending_timer = None
self._mips_sub_pending_map = {}
if self._mqtt:
self._mqtt.disconnect()
self._mqtt = None
self._internal_loop.stop()

def __get_next_reconnect_time(self) -> float:
Expand Down
9 changes: 6 additions & 3 deletions test/test_mdns.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""Unit test for miot_mdns.py."""
import asyncio
import logging
import pytest
from zeroconf import IPVersion
Expand All @@ -12,19 +13,21 @@

@pytest.mark.asyncio
async def test_service_loop_async():
from miot.miot_mdns import MipsService, MipsServiceData, MipsServiceState
from miot.miot_mdns import MipsService, MipsServiceState

async def on_service_state_change(
group_id: str, state: MipsServiceState, data: MipsServiceData):
group_id: str, state: MipsServiceState, data: dict):
_LOGGER.info(
'on_service_state_change, %s, %s, %s', group_id, state, data)

async with AsyncZeroconf(ip_version=IPVersion.V4Only) as aiozc:
mips_service = MipsService(aiozc)
mips_service.sub_service_change('test', '*', on_service_state_change)
await mips_service.init_async()
# Wait for service to discover
await asyncio.sleep(3)
services_detail = mips_service.get_services()
_LOGGER.info('get all service, %s', services_detail.keys())
_LOGGER.info('get all service, %s', list(services_detail.keys()))
for name, data in services_detail.items():
_LOGGER.info(
'\tinfo, %s, %s, %s, %s',
Expand Down
Loading
Loading