Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix socket receive issues related to message buffer size #55

Merged
merged 1 commit into from
Jan 22, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 72 additions & 28 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,35 @@ def __enter__(self):
def __exit__(self, exception_type, exception_value, traceback):
self.deinit()

def _sock_exact_recv(self, bufsize):
brentru marked this conversation as resolved.
Show resolved Hide resolved
"""Reads _exact_ number of bytes from the connected socket. Will only return
string with the exact number of bytes requested.

The semantics of native socket receive is that it returns no more than the
specified number of bytes (i.e. max size). However, it makes no guarantees in
terms of the minimum size of the buffer, which could be 1 byte. This is a
wrapper for socket recv() to ensure that no less than the expected number of
bytes is returned or trigger a timeout exception.

:param int bufsize: number of bytes to receive
"""
stamp = time.monotonic()
rc = self._sock.recv(bufsize)
to_read = bufsize - len(rc)
assert to_read >= 0
read_timeout = self.keep_alive
while to_read > 0:
recv = self._sock.recv(to_read)
to_read -= len(recv)
rc += recv
if time.monotonic() - stamp > read_timeout:
raise MMQTTException(
"Unable to receive {} bytes within {} seconds.".format(
to_read, read_timeout
)
)
return rc

def deinit(self):
"""De-initializes the MQTT client and disconnects from the mqtt broker."""
self.disconnect()
Expand Down Expand Up @@ -351,7 +380,7 @@ def connect(self, clean_session=True):
while True:
op = self._wait_for_msg()
if op == 32:
rc = self._sock.recv(3)
rc = self._sock_exact_recv(3)
assert rc[0] == 0x02
if rc[2] != 0x00:
raise MMQTTException(CONNACK_ERRORS[rc[2]])
Expand All @@ -366,32 +395,38 @@ def disconnect(self):
self.is_connected()
if self.logger is not None:
self.logger.debug("Sending DISCONNECT packet to broker")
self._sock.send(MQTT_DISCONNECT)
try:
brentru marked this conversation as resolved.
Show resolved Hide resolved
self._sock.send(MQTT_DISCONNECT)
except RuntimeError as e:
if self.logger:
self.logger.warning("Unable to send DISCONNECT packet: {}".format(e))
if self.logger is not None:
self.logger.debug("Closing socket")
self._sock.close()
self._is_connected = False
self._subscribed_topics = None
self._subscribed_topics = []
brentru marked this conversation as resolved.
Show resolved Hide resolved
if self.on_disconnect is not None:
self.on_disconnect(self, self.user_data, 0)

def ping(self):
"""Pings the MQTT Broker to confirm if the broker is alive or if
there is an active network connection.
Returns response codes of any messages received while waiting for PINGRESP.
"""
self.is_connected()
if self.logger is not None:
if self.logger:
self.logger.debug("Sending PINGREQ")
self._sock.send(MQTT_PINGREQ)
if self.logger is not None:
self.logger.debug("Checking PINGRESP")
while True:
op = self._wait_for_msg(0.5)
if op == 208:
ping_resp = self._sock.recv(2)
if ping_resp[0] != 0x00:
raise MMQTTException("PINGRESP not returned from broker.")
return
ping_timeout = self.keep_alive
stamp = time.monotonic()
rc, rcs = None, []
while rc != MQTT_PINGRESP:
rc = self._wait_for_msg()
if rc:
rcs.append(rc)
if time.monotonic() - stamp > ping_timeout:
raise MMQTTException("PINGRESP not returned from broker.")
return rcs

# pylint: disable=too-many-branches, too-many-statements
def publish(self, topic, msg, retain=False, qos=0):
Expand Down Expand Up @@ -486,9 +521,9 @@ def publish(self, topic, msg, retain=False, qos=0):
while True:
op = self._wait_for_msg()
if op == 0x40:
sz = self._sock.recv(1)
sz = self._sock_exact_recv(1)
assert sz == b"\x02"
rcv_pid = self._sock.recv(2)
rcv_pid = self._sock_exact_recv(2)
rcv_pid = rcv_pid[0] << 0x08 | rcv_pid[1]
if pid == rcv_pid:
if self.on_publish is not None:
Expand Down Expand Up @@ -571,7 +606,7 @@ def subscribe(self, topic, qos=0):
while True:
op = self._wait_for_msg()
if op == 0x90:
rc = self._sock.recv(4)
rc = self._sock_exact_recv(4)
assert rc[1] == packet[2] and rc[2] == packet[3]
if rc[3] == 0x80:
raise MMQTTException("SUBACK Failure!")
Expand Down Expand Up @@ -634,7 +669,7 @@ def unsubscribe(self, topic):
while True:
op = self._wait_for_msg()
if op == 176:
return_code = self._sock.recv(3)
return_code = self._sock_exact_recv(3)
assert return_code[0] == 0x02
# [MQTT-3.32]
assert (
Expand Down Expand Up @@ -671,6 +706,7 @@ def reconnect(self, resub_topics=True):
def loop(self):
"""Non-blocking message loop. Use this method to
check incoming subscription messages.
Returns response codes of any messages received.
"""
if self._timestamp == 0:
self._timestamp = time.monotonic()
Expand All @@ -682,10 +718,12 @@ def loop(self):
"KeepAlive period elapsed - \
requesting a PINGRESP from the server..."
)
self.ping()
rcs = self.ping()
self._timestamp = 0
return rcs
self._sock.settimeout(0.1)
return self._wait_for_msg()
rc = self._wait_for_msg()
return [rc] if rc else None

def _wait_for_msg(self, timeout=30):
"""Reads and processes network events.
Expand All @@ -694,24 +732,30 @@ def _wait_for_msg(self, timeout=30):
res = self._sock.recv(1)
self._sock.settimeout(timeout)
if res in [None, b""]:
# If we get here, it means that there is nothing to be received
return None
if res == MQTT_PINGRESP:
sz = self._sock.recv(1)[0]
assert sz == 0
return None
if res[0] == MQTT_PINGRESP:
if self.logger:
self.logger.debug("Checking PINGRESP")
sz = self._sock_exact_recv(1)[0]
if sz != 0x00:
raise MMQTTException(
"Unexpected PINGRESP returned from broker: {}.".format(sz)
)
return MQTT_PINGRESP
if res[0] & 0xF0 != 0x30:
return res[0]
sz = self._recv_len()
topic_len = self._sock.recv(2)
topic_len = self._sock_exact_recv(2)
topic_len = (topic_len[0] << 8) | topic_len[1]
topic = self._sock.recv(topic_len)
topic = self._sock_exact_recv(topic_len)
topic = str(topic, "utf-8")
sz -= topic_len + 2
if res[0] & 0x06:
pid = self._sock.recv(2)
pid = self._sock_exact_recv(2)
pid = pid[0] << 0x08 | pid[1]
sz -= 0x02
msg = self._sock.recv(sz)
msg = self._sock_exact_recv(sz)
self._handle_on_message(self, topic, str(msg, "utf-8"))
if res[0] & 0x06 == 0x02:
pkt = bytearray(b"\x40\x02\0\0")
Expand All @@ -725,7 +769,7 @@ def _recv_len(self):
n = 0
sh = 0
while True:
b = self._sock.recv(1)[0]
b = self._sock_exact_recv(1)[0]
n |= (b & 0x7F) << sh
if not b & 0x80:
return n
Expand Down