Skip to content

Commit

Permalink
Use a smart pipe that keeps connection status for passing data to the…
Browse files Browse the repository at this point in the history
… network writer
  • Loading branch information
Mossop committed Dec 20, 2024
1 parent 9a0abf1 commit 8c52d0c
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 124 deletions.
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ once_cell = { version = "1.20.2", default-features = false, features = [
"critical-section",
] }
hex = { version = "0.4.3", default-features = false }
pin-project = "1.1.7"
51 changes: 16 additions & 35 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,16 @@ use embassy_sync::{
use embassy_time::Timer;
use embedded_io_async::Write;
use mqttrs::{
decode_slice,
Connect,
ConnectReturnCode,
LastWill,
Packet,
Pid,
Protocol,
Publish,
QoS,
QosPid,
decode_slice, Connect, ConnectReturnCode, LastWill, Packet, Pid, Protocol, Publish, QoS, QosPid,
};

use crate::{
device_id,
fmt::Debug2Format,
queue::LossyQueue,
ControlMessage,
Error,
MqttMessage,
Payload,
Publishable,
Topic,
TopicString,
CONFIRMATION_TIMEOUT,
DATA_CHANNEL,
DEFAULT_BACKOFF,
device_id, fmt::Debug2Format, pipe::ConnectedPipe, ControlMessage, Error, MqttMessage, Payload,
Publishable, Topic, TopicString, CONFIRMATION_TIMEOUT, DATA_CHANNEL, DEFAULT_BACKOFF,
RESET_BACKOFF,
};

static SEND_QUEUE: LossyQueue<CriticalSectionRawMutex, Payload, 10> = LossyQueue::new();
static SEND_QUEUE: ConnectedPipe<CriticalSectionRawMutex, Payload, 10> = ConnectedPipe::new();

pub(crate) static CONTROL_CHANNEL: PubSubChannel<CriticalSectionRawMutex, ControlMessage, 2, 5, 0> =
PubSubChannel::new();
Expand Down Expand Up @@ -88,14 +68,13 @@ mod atomic16 {
}
}

pub(crate) fn send_packet(packet: Packet<'_>) -> Result<(), Error> {
pub(crate) async fn send_packet(packet: Packet<'_>) -> Result<(), Error> {
let mut buffer = Payload::new();
trace!("Encoding packet");

match buffer.encode_packet(&packet) {
Ok(()) => {
trace!("Sending packet");
SEND_QUEUE.push(buffer);
trace!("Pushing new packet for broker");
SEND_QUEUE.push(buffer).await;
Ok(())
}
Err(_) => {
Expand Down Expand Up @@ -162,7 +141,7 @@ pub(crate) async fn publish(
payload,
});

send_packet(packet)?;
send_packet(packet).await?;

if let Some(expected_pid) = pid {
wait_for_publish(subscriber, expected_pid).await
Expand Down Expand Up @@ -313,10 +292,10 @@ where
match publish.qospid {
mqttrs::QosPid::AtMostOnce => {}
mqttrs::QosPid::AtLeastOnce(pid) => {
send_packet(Packet::Puback(pid))?;
send_packet(Packet::Puback(pid)).await?;
}
mqttrs::QosPid::ExactlyOnce(pid) => {
send_packet(Packet::Pubrec(pid))?;
send_packet(Packet::Pubrec(pid)).await?;
}
}
}
Expand All @@ -325,9 +304,9 @@ where
}
Packet::Pubrec(pid) => {
controller.publish_immediate(ControlMessage::Published(pid));
send_packet(Packet::Pubrel(pid))?;
send_packet(Packet::Pubrel(pid)).await?;
}
Packet::Pubrel(pid) => send_packet(Packet::Pubrel(pid))?,
Packet::Pubrel(pid) => send_packet(Packet::Pubrel(pid)).await?,
Packet::Pubcomp(_) => {}

Packet::Suback(suback) => {
Expand Down Expand Up @@ -408,9 +387,11 @@ where
return;
}

let reader = SEND_QUEUE.reader();

loop {
trace!("Writer waiting for data");
let buffer = SEND_QUEUE.pop().await;
let buffer = reader.receive().await;

trace!("Writer sending data");
if let Err(e) = writer.write_all(&buffer).await {
Expand Down Expand Up @@ -474,7 +455,7 @@ where
loop {
Timer::after_secs(45).await;

let _ = send_packet(Packet::Pingreq);
let _ = send_packet(Packet::Pingreq).await;
}
};

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ mod buffer;
#[cfg(feature = "homeassistant")]
pub mod homeassistant;
mod io;
mod pipe;
mod publish;
mod queue;
mod topic;

// This really needs to match that used by mqttrs.
Expand Down
163 changes: 163 additions & 0 deletions src/pipe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
use core::{
cell::RefCell,
future::Future,
pin::Pin,
task::{Context, Poll, Waker},
};

use embassy_sync::blocking_mutex::{raw::RawMutex, Mutex};
use pin_project::pin_project;

struct PipeData<T, const N: usize> {
connect_count: usize,
receiver_waker: Option<Waker>,
sender_waker: Option<Waker>,
pending: Option<T>,
}

fn swap_wakers(waker: &mut Option<Waker>, new_waker: &Waker) {
if let Some(old_waker) = waker.take() {
if old_waker.will_wake(new_waker) {
*waker = Some(old_waker)
} else {
if !new_waker.will_wake(&old_waker) {
old_waker.wake();
}

*waker = Some(new_waker.clone());
}
} else {
*waker = Some(new_waker.clone())
}
}

pub(crate) struct ReceiveFuture<'a, M: RawMutex, T, const N: usize> {
pipe: &'a ConnectedPipe<M, T, N>,
}

impl<M: RawMutex, T, const N: usize> Future for ReceiveFuture<'_, M, T, N> {
type Output = T;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.pipe.inner.lock(|cell| {
let mut inner = cell.borrow_mut();

if let Some(waker) = inner.sender_waker.take() {
waker.wake();
}

if let Some(item) = inner.pending.take() {
if let Some(old_waker) = inner.receiver_waker.take() {
old_waker.wake();
}

Poll::Ready(item)
} else {
swap_wakers(&mut inner.receiver_waker, cx.waker());
Poll::Pending
}
})
}
}

pub(crate) struct PipeReader<'a, M: RawMutex, T, const N: usize> {
pipe: &'a ConnectedPipe<M, T, N>,
}

impl<M: RawMutex, T, const N: usize> PipeReader<'_, M, T, N> {
#[must_use]
pub(crate) fn receive(&self) -> ReceiveFuture<'_, M, T, N> {
ReceiveFuture { pipe: self.pipe }
}
}

impl<M: RawMutex, T, const N: usize> Drop for PipeReader<'_, M, T, N> {
fn drop(&mut self) {
self.pipe.inner.lock(|cell| {
let mut inner = cell.borrow_mut();
inner.connect_count -= 1;

if inner.connect_count == 0 {
inner.pending = None;
}

if let Some(waker) = inner.sender_waker.take() {
waker.wake();
}
})
}
}

#[pin_project]
pub(crate) struct PushFuture<'a, M: RawMutex, T, const N: usize> {
data: Option<T>,
pipe: &'a ConnectedPipe<M, T, N>,
}

impl<M: RawMutex, T, const N: usize> Future for PushFuture<'_, M, T, N> {
type Output = ();

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.pipe.inner.lock(|cell| {
let project = self.project();
let mut inner = cell.borrow_mut();

if let Some(receiver) = inner.receiver_waker.take() {
receiver.wake();
}

if project.data.is_none() || inner.connect_count == 0 {
trace!("Dropping packet");
Poll::Ready(())
} else if inner.pending.is_some() {
swap_wakers(&mut inner.sender_waker, cx.waker());
Poll::Pending
} else {
trace!("Pushed packet to receiver");
inner.pending = project.data.take();

Poll::Ready(())
}
})
}
}

/// A pipe that knows whether a receiver is connected. If so pushing to the
/// queue waits until there is space in the queue, otherwise data is simply
/// dropped.
pub(crate) struct ConnectedPipe<M: RawMutex, T, const N: usize> {
inner: Mutex<M, RefCell<PipeData<T, N>>>,
}

impl<M: RawMutex, T, const N: usize> ConnectedPipe<M, T, N> {
pub(crate) const fn new() -> Self {
Self {
inner: Mutex::new(RefCell::new(PipeData {
connect_count: 0,
receiver_waker: None,
sender_waker: None,
pending: None,
})),
}
}

/// A future that waits for a new item to be available.
pub(crate) fn reader(&self) -> PipeReader<'_, M, T, N> {
self.inner.lock(|cell| {
let mut inner = cell.borrow_mut();
inner.connect_count += 1;

PipeReader { pipe: self }
})
}

/// Pushes an item to the reader, waiting for a slot to become available if
/// connected.
#[must_use]
pub(crate) fn push(&self, data: T) -> PushFuture<'_, M, T, N> {
PushFuture {
data: Some(data),
pipe: self,
}
}
}
Loading

0 comments on commit 8c52d0c

Please sign in to comment.