Skip to content

Commit

Permalink
Fix socket receive issues related to message buffer size
Browse files Browse the repository at this point in the history
This change addresses a few issues in the handling of the MQTT
messages that caused the library to become unstable:

- Add wapper for socket.recv() so that an exact number of bytes
  are read into the buffer before attempting to parse the MQTT
  message;

- Fix handling of ping response packets as part of _wait_for_msg(),
  together with all other MQTT messages;

- Fix disconnect so it can gracefully handle cases when socket writes
  are not possible. Also re-init _subscribed_topics as an empty list
  instead of None.

Related-to adafruit/Adafruit_CircuitPython_ESP32SPI#102
Fixes adafruit/Adafruit_CircuitPython_PyPortal#98
Fixes #54
Signed-off-by: Flavio Fernandes <[email protected]>
  • Loading branch information
flavio-fernandes committed Jan 13, 2021
1 parent e02d658 commit e353adb
Showing 1 changed file with 63 additions and 24 deletions.
87 changes: 63 additions & 24 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def __init__(
self.logger.setLevel(logging.INFO)
self._sock = None
self._is_connected = False
self._pending_ping_response = False
self._msg_size_lim = MQTT_MSG_SZ_LIM
self._pid = 0
self._timestamp = 0
Expand All @@ -175,6 +176,35 @@ def __enter__(self):
def __exit__(self, exception_type, exception_value, traceback):
self.deinit()

def _sock_exact_recv(self, bufsize):
"""Reads _exact_ number of bytes from the connected socket. Will only return
string with the exact number of bytes requested.
The semantics of native socket receive is that it returns no more than the
specified number of bytes (i.e. max size). However, it makes no guarantees in
terms of the minimum size of the buffer, which could be 1 byte. This is a
wrapper for socket recv() to ensure that no less than the expected number of
bytes is returned or trigger a timeout exception.
:param int bufsize: number of bytes to receive
"""
stamp = time.monotonic()
rc = self._sock.recv(bufsize)
to_read = bufsize - len(rc)
assert to_read >= 0
read_timeout = min(self.keep_alive, self._sock._timeout)
while to_read > 0:
recv = self._sock.recv(to_read)
to_read -= len(recv)
rc += recv
if time.monotonic() - stamp > read_timeout:
raise MMQTTException(
"Unable to receive {} bytes within {} seconds.".format(
to_read, read_timeout
)
)
return rc

def deinit(self):
"""De-initializes the MQTT client and disconnects from the mqtt broker."""
self.disconnect()
Expand Down Expand Up @@ -351,7 +381,7 @@ def connect(self, clean_session=True):
while True:
op = self._wait_for_msg()
if op == 32:
rc = self._sock.recv(3)
rc = self._sock_exact_recv(3)
assert rc[0] == 0x02
if rc[2] != 0x00:
raise MMQTTException(CONNACK_ERRORS[rc[2]])
Expand All @@ -366,12 +396,16 @@ def disconnect(self):
self.is_connected()
if self.logger is not None:
self.logger.debug("Sending DISCONNECT packet to broker")
self._sock.send(MQTT_DISCONNECT)
try:
self._sock.send(MQTT_DISCONNECT)
except RuntimeError as e:
if self.logger:
self.logger.warning("Unable to send DISCONNECT packet: {}".format(e))
if self.logger is not None:
self.logger.debug("Closing socket")
self._sock.close()
self._is_connected = False
self._subscribed_topics = None
self._subscribed_topics = []
if self.on_disconnect is not None:
self.on_disconnect(self, self.user_data, 0)

Expand All @@ -380,18 +414,15 @@ def ping(self):
there is an active network connection.
"""
self.is_connected()
if self._pending_ping_response:
self._pending_ping_response = False
raise MMQTTException("Ping response was pending from previous MQTT_PINGREQ")
if self.logger is not None:
self.logger.debug("Sending PINGREQ")
self._sock.send(MQTT_PINGREQ)
if self.logger is not None:
self.logger.debug("Checking PINGRESP")
while True:
op = self._wait_for_msg(0.5)
if op == 208:
ping_resp = self._sock.recv(2)
if ping_resp[0] != 0x00:
raise MMQTTException("PINGRESP not returned from broker.")
return
# Set pending ping response. It will be checked upon next ping and
# assumed to be cleared via _wait_for_msg()
self._pending_ping_response = True

# pylint: disable=too-many-branches, too-many-statements
def publish(self, topic, msg, retain=False, qos=0):
Expand Down Expand Up @@ -486,9 +517,9 @@ def publish(self, topic, msg, retain=False, qos=0):
while True:
op = self._wait_for_msg()
if op == 0x40:
sz = self._sock.recv(1)
sz = self._sock_exact_recv(1)
assert sz == b"\x02"
rcv_pid = self._sock.recv(2)
rcv_pid = self._sock_exact_recv(2)
rcv_pid = rcv_pid[0] << 0x08 | rcv_pid[1]
if pid == rcv_pid:
if self.on_publish is not None:
Expand Down Expand Up @@ -571,7 +602,7 @@ def subscribe(self, topic, qos=0):
while True:
op = self._wait_for_msg()
if op == 0x90:
rc = self._sock.recv(4)
rc = self._sock_exact_recv(4)
assert rc[1] == packet[2] and rc[2] == packet[3]
if rc[3] == 0x80:
raise MMQTTException("SUBACK Failure!")
Expand Down Expand Up @@ -634,7 +665,7 @@ def unsubscribe(self, topic):
while True:
op = self._wait_for_msg()
if op == 176:
return_code = self._sock.recv(3)
return_code = self._sock_exact_recv(3)
assert return_code[0] == 0x02
# [MQTT-3.32]
assert (
Expand Down Expand Up @@ -694,24 +725,32 @@ def _wait_for_msg(self, timeout=30):
res = self._sock.recv(1)
self._sock.settimeout(timeout)
if res in [None, b""]:
# If we get here, it means that there is nothing to be received
return None
if res == MQTT_PINGRESP:
sz = self._sock.recv(1)[0]
assert sz == 0
if res[0] == MQTT_PINGRESP:
if self.logger:
self.logger.debug("Checking PINGRESP")
sz = self._sock_exact_recv(1)[0]
if sz != 0x00:
raise MMQTTException(
"Unexpected PINGRESP returned from broker: {}.".format(sz)
)
# Ping response is no longer pending
self._pending_ping_response = False
return None
if res[0] & 0xF0 != 0x30:
return res[0]
sz = self._recv_len()
topic_len = self._sock.recv(2)
topic_len = self._sock_exact_recv(2)
topic_len = (topic_len[0] << 8) | topic_len[1]
topic = self._sock.recv(topic_len)
topic = self._sock_exact_recv(topic_len)
topic = str(topic, "utf-8")
sz -= topic_len + 2
if res[0] & 0x06:
pid = self._sock.recv(2)
pid = self._sock_exact_recv(2)
pid = pid[0] << 0x08 | pid[1]
sz -= 0x02
msg = self._sock.recv(sz)
msg = self._sock_exact_recv(sz)
self._handle_on_message(self, topic, str(msg, "utf-8"))
if res[0] & 0x06 == 0x02:
pkt = bytearray(b"\x40\x02\0\0")
Expand All @@ -725,7 +764,7 @@ def _recv_len(self):
n = 0
sh = 0
while True:
b = self._sock.recv(1)[0]
b = self._sock_exact_recv(1)[0]
n |= (b & 0x7F) << sh
if not b & 0x80:
return n
Expand Down

0 comments on commit e353adb

Please sign in to comment.