From 8c52d0c71110d524072be4fbb490b7af08399977 Mon Sep 17 00:00:00 2001 From: Dave Townsend Date: Fri, 20 Dec 2024 12:54:49 +0000 Subject: [PATCH] Use a smart pipe that keeps connection status for passing data to the network writer --- Cargo.lock | 21 +++++++ Cargo.toml | 1 + src/io.rs | 51 +++++----------- src/lib.rs | 2 +- src/pipe.rs | 163 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/queue.rs | 80 ------------------------- src/topic.rs | 12 ++-- 7 files changed, 206 insertions(+), 124 deletions(-) create mode 100644 src/pipe.rs delete mode 100644 src/queue.rs diff --git a/Cargo.lock b/Cargo.lock index 9047f6b..4a2c73b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -344,6 +344,7 @@ dependencies = [ "log", "mqttrs", "once_cell", + "pin-project", "serde", "serde-json-core", ] @@ -382,6 +383,26 @@ dependencies = [ "portable-atomic", ] +[[package]] +name = "pin-project" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index b2bf9c0..7c5b35d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,3 +38,4 @@ once_cell = { version = "1.20.2", default-features = false, features = [ "critical-section", ] } hex = { version = "0.4.3", default-features = false } +pin-project = "1.1.7" diff --git a/src/io.rs b/src/io.rs index fcb3791..fb2b9f6 100644 --- a/src/io.rs +++ b/src/io.rs @@ -14,36 +14,16 @@ use embassy_sync::{ use embassy_time::Timer; use embedded_io_async::Write; use mqttrs::{ - decode_slice, - Connect, - ConnectReturnCode, - LastWill, - Packet, - Pid, - Protocol, - Publish, - QoS, - QosPid, + decode_slice, Connect, ConnectReturnCode, LastWill, Packet, Pid, Protocol, Publish, QoS, QosPid, }; use crate::{ - device_id, - fmt::Debug2Format, - queue::LossyQueue, - ControlMessage, - Error, - MqttMessage, - Payload, - Publishable, - Topic, - TopicString, - CONFIRMATION_TIMEOUT, - DATA_CHANNEL, - DEFAULT_BACKOFF, + device_id, fmt::Debug2Format, pipe::ConnectedPipe, ControlMessage, Error, MqttMessage, Payload, + Publishable, Topic, TopicString, CONFIRMATION_TIMEOUT, DATA_CHANNEL, DEFAULT_BACKOFF, RESET_BACKOFF, }; -static SEND_QUEUE: LossyQueue = LossyQueue::new(); +static SEND_QUEUE: ConnectedPipe = ConnectedPipe::new(); pub(crate) static CONTROL_CHANNEL: PubSubChannel = PubSubChannel::new(); @@ -88,14 +68,13 @@ mod atomic16 { } } -pub(crate) fn send_packet(packet: Packet<'_>) -> Result<(), Error> { +pub(crate) async fn send_packet(packet: Packet<'_>) -> Result<(), Error> { let mut buffer = Payload::new(); - trace!("Encoding packet"); match buffer.encode_packet(&packet) { Ok(()) => { - trace!("Sending packet"); - SEND_QUEUE.push(buffer); + trace!("Pushing new packet for broker"); + SEND_QUEUE.push(buffer).await; Ok(()) } Err(_) => { @@ -162,7 +141,7 @@ pub(crate) async fn publish( payload, }); - send_packet(packet)?; + send_packet(packet).await?; if let Some(expected_pid) = pid { wait_for_publish(subscriber, expected_pid).await @@ -313,10 +292,10 @@ where match publish.qospid { mqttrs::QosPid::AtMostOnce => {} mqttrs::QosPid::AtLeastOnce(pid) => { - send_packet(Packet::Puback(pid))?; + send_packet(Packet::Puback(pid)).await?; } mqttrs::QosPid::ExactlyOnce(pid) => { - send_packet(Packet::Pubrec(pid))?; + send_packet(Packet::Pubrec(pid)).await?; } } } @@ -325,9 +304,9 @@ where } Packet::Pubrec(pid) => { controller.publish_immediate(ControlMessage::Published(pid)); - send_packet(Packet::Pubrel(pid))?; + send_packet(Packet::Pubrel(pid)).await?; } - Packet::Pubrel(pid) => send_packet(Packet::Pubrel(pid))?, + Packet::Pubrel(pid) => send_packet(Packet::Pubrel(pid)).await?, Packet::Pubcomp(_) => {} Packet::Suback(suback) => { @@ -408,9 +387,11 @@ where return; } + let reader = SEND_QUEUE.reader(); + loop { trace!("Writer waiting for data"); - let buffer = SEND_QUEUE.pop().await; + let buffer = reader.receive().await; trace!("Writer sending data"); if let Err(e) = writer.write_all(&buffer).await { @@ -474,7 +455,7 @@ where loop { Timer::after_secs(45).await; - let _ = send_packet(Packet::Pingreq); + let _ = send_packet(Packet::Pingreq).await; } }; diff --git a/src/lib.rs b/src/lib.rs index caf5bcd..706a20a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,8 +24,8 @@ mod buffer; #[cfg(feature = "homeassistant")] pub mod homeassistant; mod io; +mod pipe; mod publish; -mod queue; mod topic; // This really needs to match that used by mqttrs. diff --git a/src/pipe.rs b/src/pipe.rs new file mode 100644 index 0000000..a8529bf --- /dev/null +++ b/src/pipe.rs @@ -0,0 +1,163 @@ +use core::{ + cell::RefCell, + future::Future, + pin::Pin, + task::{Context, Poll, Waker}, +}; + +use embassy_sync::blocking_mutex::{raw::RawMutex, Mutex}; +use pin_project::pin_project; + +struct PipeData { + connect_count: usize, + receiver_waker: Option, + sender_waker: Option, + pending: Option, +} + +fn swap_wakers(waker: &mut Option, new_waker: &Waker) { + if let Some(old_waker) = waker.take() { + if old_waker.will_wake(new_waker) { + *waker = Some(old_waker) + } else { + if !new_waker.will_wake(&old_waker) { + old_waker.wake(); + } + + *waker = Some(new_waker.clone()); + } + } else { + *waker = Some(new_waker.clone()) + } +} + +pub(crate) struct ReceiveFuture<'a, M: RawMutex, T, const N: usize> { + pipe: &'a ConnectedPipe, +} + +impl Future for ReceiveFuture<'_, M, T, N> { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.pipe.inner.lock(|cell| { + let mut inner = cell.borrow_mut(); + + if let Some(waker) = inner.sender_waker.take() { + waker.wake(); + } + + if let Some(item) = inner.pending.take() { + if let Some(old_waker) = inner.receiver_waker.take() { + old_waker.wake(); + } + + Poll::Ready(item) + } else { + swap_wakers(&mut inner.receiver_waker, cx.waker()); + Poll::Pending + } + }) + } +} + +pub(crate) struct PipeReader<'a, M: RawMutex, T, const N: usize> { + pipe: &'a ConnectedPipe, +} + +impl PipeReader<'_, M, T, N> { + #[must_use] + pub(crate) fn receive(&self) -> ReceiveFuture<'_, M, T, N> { + ReceiveFuture { pipe: self.pipe } + } +} + +impl Drop for PipeReader<'_, M, T, N> { + fn drop(&mut self) { + self.pipe.inner.lock(|cell| { + let mut inner = cell.borrow_mut(); + inner.connect_count -= 1; + + if inner.connect_count == 0 { + inner.pending = None; + } + + if let Some(waker) = inner.sender_waker.take() { + waker.wake(); + } + }) + } +} + +#[pin_project] +pub(crate) struct PushFuture<'a, M: RawMutex, T, const N: usize> { + data: Option, + pipe: &'a ConnectedPipe, +} + +impl Future for PushFuture<'_, M, T, N> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.pipe.inner.lock(|cell| { + let project = self.project(); + let mut inner = cell.borrow_mut(); + + if let Some(receiver) = inner.receiver_waker.take() { + receiver.wake(); + } + + if project.data.is_none() || inner.connect_count == 0 { + trace!("Dropping packet"); + Poll::Ready(()) + } else if inner.pending.is_some() { + swap_wakers(&mut inner.sender_waker, cx.waker()); + Poll::Pending + } else { + trace!("Pushed packet to receiver"); + inner.pending = project.data.take(); + + Poll::Ready(()) + } + }) + } +} + +/// A pipe that knows whether a receiver is connected. If so pushing to the +/// queue waits until there is space in the queue, otherwise data is simply +/// dropped. +pub(crate) struct ConnectedPipe { + inner: Mutex>>, +} + +impl ConnectedPipe { + pub(crate) const fn new() -> Self { + Self { + inner: Mutex::new(RefCell::new(PipeData { + connect_count: 0, + receiver_waker: None, + sender_waker: None, + pending: None, + })), + } + } + + /// A future that waits for a new item to be available. + pub(crate) fn reader(&self) -> PipeReader<'_, M, T, N> { + self.inner.lock(|cell| { + let mut inner = cell.borrow_mut(); + inner.connect_count += 1; + + PipeReader { pipe: self } + }) + } + + /// Pushes an item to the reader, waiting for a slot to become available if + /// connected. + #[must_use] + pub(crate) fn push(&self, data: T) -> PushFuture<'_, M, T, N> { + PushFuture { + data: Some(data), + pipe: self, + } + } +} diff --git a/src/queue.rs b/src/queue.rs deleted file mode 100644 index 0c9fb96..0000000 --- a/src/queue.rs +++ /dev/null @@ -1,80 +0,0 @@ -use core::{ - cell::RefCell, - future::Future, - pin::Pin, - task::{Context, Poll, Waker}, -}; - -use embassy_sync::blocking_mutex::{raw::RawMutex, Mutex}; -use heapless::Deque; - -struct LossyQueueData { - receiver_waker: Option, - queue: Deque, -} - -pub(crate) struct ReceiveFuture<'a, M: RawMutex, T, const N: usize> { - pipe: &'a LossyQueue, -} - -impl Future for ReceiveFuture<'_, M, T, N> { - type Output = T; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.pipe.inner.lock(|cell| { - let mut inner = cell.borrow_mut(); - - if let Some(waker) = inner.receiver_waker.take() { - waker.wake(); - } - - if let Some(item) = inner.queue.pop_front() { - Poll::Ready(item) - } else { - inner.receiver_waker = Some(cx.waker().clone()); - Poll::Pending - } - }) - } -} - -/// A FIFO queue holding a fixed number of items. Older items are dropped if the -/// queue is full when a new item is pushed. -pub(crate) struct LossyQueue { - inner: Mutex>>, -} - -impl LossyQueue { - pub(crate) const fn new() -> Self { - Self { - inner: Mutex::new(RefCell::new(LossyQueueData { - receiver_waker: None, - queue: Deque::new(), - })), - } - } - - /// A future that waits for a new item to be available. - pub(crate) fn pop(&self) -> ReceiveFuture<'_, M, T, N> { - ReceiveFuture { pipe: self } - } - - /// Pushes an item into the queue. If the queue is already full the oldest - /// item is dropped to make space. - pub(crate) fn push(&self, data: T) { - self.inner.lock(|cell| { - let mut inner = cell.borrow_mut(); - - if inner.queue.is_full() { - inner.queue.pop_front(); - } - - // As we pop above the queue cannot be full now. - let _ = inner.queue.push_back(data); - - if let Some(waker) = inner.receiver_waker.take() { - waker.wake(); - } - }) - } -} diff --git a/src/topic.rs b/src/topic.rs index 9d95b9c..935f336 100644 --- a/src/topic.rs +++ b/src/topic.rs @@ -9,14 +9,10 @@ use mqttrs::{Packet, QoS, Subscribe, SubscribeReturnCodes, SubscribeTopic, Unsub #[cfg(feature = "serde")] use crate::publish::PublishJson; use crate::{ - device_id, - device_type, + device_id, device_type, io::{assign_pid, send_packet, subscribe}, publish::{PublishBytes, PublishDisplay}, - ControlMessage, - Error, - TopicString, - CONFIRMATION_TIMEOUT, + ControlMessage, Error, TopicString, CONFIRMATION_TIMEOUT, }; /// An MQTT topic that is optionally prefixed with the device type and unique ID. @@ -189,7 +185,7 @@ impl> Topic { let packet = Packet::Subscribe(Subscribe { pid, topics }); - send_packet(packet)?; + send_packet(packet).await?; if wait_for_ack { match select( @@ -247,7 +243,7 @@ impl> Topic { let packet = Packet::Unsubscribe(Unsubscribe { pid, topics }); - send_packet(packet)?; + send_packet(packet).await?; if wait_for_ack { match select(