Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,9 @@ _build
.vscode
*~

# tox local cache
# tox-specific files
.tox
build

# coverage-specific files
.coverage
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ repos:
types: [python]
files: "^tests/"
args:
- --disable=missing-docstring,consider-using-f-string,duplicate-code
- --disable=missing-docstring,invalid-name,consider-using-f-string,duplicate-code
147 changes: 22 additions & 125 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,21 @@
* Adafruit CircuitPython firmware for the supported boards:
https://github.com/adafruit/circuitpython/releases

* Adafruit's Connection Manager library:
https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager

"""
import errno
import struct
import time
from random import randint

from adafruit_connection_manager import (
get_connection_manager,
SocketGetOSError,
SocketConnectMemoryError,
)

try:
from typing import List, Optional, Tuple, Type, Union
except ImportError:
Expand Down Expand Up @@ -78,68 +87,19 @@
_default_sock = None # pylint: disable=invalid-name
_fake_context = None # pylint: disable=invalid-name

TemporaryError = (SocketGetOSError, SocketConnectMemoryError)


class MMQTTException(Exception):
"""MiniMQTT Exception class."""

# pylint: disable=unnecessary-pass
# pass


class TemporaryError(Exception):
"""Temporary error class used for handling reconnects."""


# Legacy ESP32SPI Socket API
def set_socket(sock, iface=None) -> None:
"""Legacy API for setting the socket and network interface.

:param sock: socket object.
:param iface: internet interface object

"""
global _default_sock # pylint: disable=invalid-name, global-statement
global _fake_context # pylint: disable=invalid-name, global-statement
_default_sock = sock
if iface:
_default_sock.set_interface(iface)
_fake_context = _FakeSSLContext(iface)


class _FakeSSLSocket:
def __init__(self, socket, tls_mode) -> None:
self._socket = socket
self._mode = tls_mode
self.settimeout = socket.settimeout
self.send = socket.send
self.recv = socket.recv
self.close = socket.close

def connect(self, address):
"""connect wrapper to add non-standard mode parameter"""
try:
return self._socket.connect(address, self._mode)
except RuntimeError as error:
raise OSError(errno.ENOMEM) from error


class _FakeSSLContext:
def __init__(self, iface) -> None:
self._iface = iface

def wrap_socket(self, socket, server_hostname=None) -> _FakeSSLSocket:
"""Return the same socket"""
# pylint: disable=unused-argument
return _FakeSSLSocket(socket, self._iface.TLS_MODE)


class NullLogger:
"""Fake logger class that does not do anything"""

# pylint: disable=unused-argument
def nothing(self, msg: str, *args) -> None:
"""no action"""
pass

def __init__(self) -> None:
for log_level in ["debug", "info", "warning", "error", "critical"]:
Expand Down Expand Up @@ -194,6 +154,7 @@ def __init__(
user_data=None,
use_imprecise_time: Optional[bool] = None,
) -> None:
self._connection_manager = get_connection_manager(socket_pool)
self._socket_pool = socket_pool
self._ssl_context = ssl_context
self._sock = None
Expand Down Expand Up @@ -300,77 +261,6 @@ def get_monotonic_time(self) -> float:

return time.monotonic()

# pylint: disable=too-many-branches
def _get_connect_socket(self, host: str, port: int, *, timeout: int = 1):
"""Obtains a new socket and connects to a broker.

:param str host: Desired broker hostname
:param int port: Desired broker port
:param int timeout: Desired socket timeout, in seconds
"""
# For reconnections - check if we're using a socket already and close it
if self._sock:
self._sock.close()
self._sock = None

# Legacy API - use the interface's socket instead of a passed socket pool
if self._socket_pool is None:
self._socket_pool = _default_sock

# Legacy API - fake the ssl context
if self._ssl_context is None:
self._ssl_context = _fake_context

if not isinstance(port, int):
raise RuntimeError("Port must be an integer")

if self._is_ssl and not self._ssl_context:
raise RuntimeError(
"ssl_context must be set before using adafruit_mqtt for secure MQTT."
)

if self._is_ssl:
self.logger.info(f"Establishing a SECURE SSL connection to {host}:{port}")
else:
self.logger.info(f"Establishing an INSECURE connection to {host}:{port}")

addr_info = self._socket_pool.getaddrinfo(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]

try:
sock = self._socket_pool.socket(addr_info[0], addr_info[1])
except OSError as exc:
# Do not consider this for back-off.
self.logger.warning(
f"Failed to create socket for host {addr_info[0]} and port {addr_info[1]}"
)
raise TemporaryError from exc

connect_host = addr_info[-1][0]
if self._is_ssl:
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
connect_host = host
sock.settimeout(timeout)

last_exception = None
try:
sock.connect((connect_host, port))
except MemoryError as exc:
sock.close()
self.logger.warning(f"Failed to allocate memory for connect: {exc}")
# Do not consider this for back-off.
raise TemporaryError from exc
except OSError as exc:
sock.close()
last_exception = exc

if last_exception:
raise last_exception

self._backwards_compatible_sock = not hasattr(sock, "recv_into")
return sock

def __enter__(self):
return self

Expand Down Expand Up @@ -593,8 +483,15 @@ def _connect(
time.sleep(self._reconnect_timeout)

# Get a new socket
self._sock = self._get_connect_socket(
self.broker, self.port, timeout=self._socket_timeout
self._sock = self._connection_manager.get_socket(
self.broker,
self.port,
"mqtt:",
timeout=self._socket_timeout,
is_ssl=self._is_ssl,
ssl_context=self._ssl_context,
max_retries=1, # setting to 1 since we want to handle backoff internally
exception_passthrough=True,
)

# Fixed Header
Expand Down Expand Up @@ -689,7 +586,7 @@ def disconnect(self) -> None:
except RuntimeError as e:
self.logger.warning(f"Unable to send DISCONNECT packet: {e}")
self.logger.debug("Closing socket")
self._sock.close()
self._connection_manager.free_socket(self._sock)
self._is_connected = False
self._subscribed_topics = []
if self.on_disconnect is not None:
Expand Down
17 changes: 17 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: 2023 Justin Myers for Adafruit Industries
#
# SPDX-License-Identifier: Unlicense

""" PyTest Setup """

import pytest
import adafruit_connection_manager


@pytest.fixture(autouse=True)
def reset_connection_manager(monkeypatch):
"""Reset the ConnectionManager, since it's a singlton and will hold data"""
monkeypatch.setattr(
"adafruit_minimqtt.adafruit_minimqtt.get_connection_manager",
adafruit_connection_manager.ConnectionManager,
)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
# SPDX-License-Identifier: Unlicense

Adafruit-Blinka
Adafruit-Circuitpython-ConnectionManager@git+https://github.com/justmobilize/Adafruit_CircuitPython_ConnectionManager@connection-manager
37 changes: 33 additions & 4 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,38 @@
# SPDX-License-Identifier: MIT

[tox]
envlist = py39
envlist = py311

[testenv]
changedir = {toxinidir}/tests
deps = pytest==6.2.5
commands = pytest -v
description = run tests
deps =
pytest==7.4.3
pytest-subtests==0.11.0
commands = pytest

[testenv:coverage]
description = run coverage
deps =
pytest==7.4.3
pytest-cov==4.1.0
pytest-subtests==0.11.0
package = editable
commands =
coverage run --source=. --omit=tests/* --branch {posargs} -m pytest
coverage report
coverage html

[testenv:lint]
description = run linters
deps =
pre-commit==3.6.0
skip_install = true
commands = pre-commit run {posargs}

[testenv:docs]
description = build docs
deps =
-r requirements.txt
-r docs/requirements.txt
skip_install = true
commands = sphinx-build -E -W -b html docs/. _build/html