Skip to content

Commit a994872

Browse files
committed
Handle message redelivery and re-acknowledgment in input worker
1 parent 2c9289c commit a994872

File tree

1 file changed

+63
-41
lines changed

1 file changed

+63
-41
lines changed

lib/protoflow-zeromq/src/input_port.rs

+63-41
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// This is free and unencumbered software released into the public domain.
22

33
use crate::{
4-
subscribe_topics, unsubscribe_topics, ZmqSubscriptionRequest, ZmqTransport, ZmqTransportEvent,
4+
subscribe_topics, unsubscribe_topics, SequenceID, ZmqSubscriptionRequest, ZmqTransport,
5+
ZmqTransportEvent,
56
};
67
use protoflow_core::{
78
prelude::{fmt, format, vec, Arc, BTreeMap, Bytes, String, ToString, Vec},
@@ -44,7 +45,7 @@ pub enum ZmqInputPortState {
4445
// channel used internally for events from socket
4546
Sender<ZmqTransportEvent>,
4647
// vec of the connected port ids
47-
Vec<OutputPortID>,
48+
BTreeMap<OutputPortID, SequenceID>,
4849
),
4950
Closed,
5051
}
@@ -58,7 +59,7 @@ impl fmt::Display for ZmqInputPortState {
5859
write!(
5960
f,
6061
"Connected({:?})",
61-
ids.iter().map(|id| isize::from(*id)).collect::<Vec<_>>()
62+
ids.keys().map(|id| isize::from(*id)).collect::<Vec<_>>()
6263
)
6364
}
6465
Closed => write!(f, "Closed"),
@@ -156,7 +157,7 @@ pub fn start_input_worker(
156157
match &*input_state {
157158
Open(..) => (),
158159
Connected(.., connected_ids) => {
159-
if connected_ids.iter().any(|&id| id == output_port_id) {
160+
if connected_ids.contains_key(&output_port_id) {
160161
#[cfg(feature = "tracing")]
161162
span.in_scope(|| trace!("output port is already connected"));
162163
return;
@@ -169,16 +170,18 @@ pub fn start_input_worker(
169170
Open(req_send, to_worker_send) => {
170171
let (msgs_send, msgs_recv) = channel(1);
171172
let msgs_recv = Arc::new(Mutex::new(msgs_recv));
173+
let mut connected_ids = BTreeMap::new();
174+
connected_ids.insert(output_port_id, 0);
172175
*input_state = Connected(
173176
req_send.clone(),
174177
msgs_send,
175178
msgs_recv,
176179
to_worker_send.clone(),
177-
vec![output_port_id],
180+
connected_ids,
178181
);
179182
}
180183
Connected(.., ids) => {
181-
ids.push(output_port_id);
184+
ids.insert(output_port_id, 0);
182185
}
183186
Closed => unreachable!(),
184187
};
@@ -204,9 +207,9 @@ pub fn start_input_worker(
204207
#[cfg(feature = "tracing")]
205208
span.in_scope(|| info!("Connected new port: {}", input_state));
206209
}
207-
Message(output_port_id, target_id, seq_id, bytes) => {
210+
Message(output_port_id, target_id, msg_seq_id, bytes) => {
208211
#[cfg(feature = "tracing")]
209-
let span = trace_span!(parent: &span, "Message", ?output_port_id, ?seq_id);
212+
let span = trace_span!(parent: &span, "Message", ?output_port_id, ?msg_seq_id);
210213

211214
debug_assert_eq!(input_port_id, target_id);
212215

@@ -216,43 +219,64 @@ pub fn start_input_worker(
216219
span.in_scope(|| error!("port state not found"));
217220
return;
218221
};
219-
let input_state = input_state.read().await;
222+
let mut input_state = input_state.write().await;
220223

221224
use ZmqInputPortState::*;
222-
match &*input_state {
223-
Connected(_, sender, _, _, connected_ids) => {
224-
if !connected_ids.iter().any(|id| *id == output_port_id) {
225+
match *input_state {
226+
Connected(_, ref sender, _, _, ref mut connected_ids) => {
227+
let Some(&last_seen_seq_id) = connected_ids.get(&output_port_id) else {
225228
#[cfg(feature = "tracing")]
226229
span.in_scope(|| trace!("got message from non-connected output port"));
227230
return;
228-
}
231+
};
229232

230-
if sender
231-
.send(ZmqInputPortEvent::Message(bytes))
232-
.await
233-
.is_err()
234-
{
233+
let send_ack = {
235234
#[cfg(feature = "tracing")]
236-
span.in_scope(|| warn!("receiver for input events has closed"));
237-
return;
235+
let span = span.clone();
236+
237+
|ack_id| async move {
238+
if pub_queue
239+
.send(ZmqTransportEvent::AckMessage(
240+
output_port_id,
241+
input_port_id,
242+
ack_id,
243+
))
244+
.await
245+
.is_err()
246+
{
247+
#[cfg(feature = "tracing")]
248+
span.in_scope(|| warn!("publish channel is closed"));
249+
}
250+
#[cfg(feature = "tracing")]
251+
span.in_scope(|| trace!(?ack_id, "sent msg-ack"));
252+
}
253+
};
254+
255+
use std::cmp::Ordering::*;
256+
match msg_seq_id.cmp(&last_seen_seq_id) {
257+
// seq_id for msg is greater than last seen seq_id by one
258+
Greater if (msg_seq_id - last_seen_seq_id == 1) => {
259+
if sender
260+
.send(ZmqInputPortEvent::Message(bytes))
261+
.await
262+
.is_err()
263+
{
264+
#[cfg(feature = "tracing")]
265+
span.in_scope(|| warn!("receiver for input events has closed"));
266+
return;
267+
}
268+
send_ack(msg_seq_id).await;
269+
let _ = connected_ids.insert(output_port_id, msg_seq_id);
270+
}
271+
Equal => {
272+
send_ack(last_seen_seq_id).await;
273+
}
274+
// either the seq_id is greater than the last seen seq_id by more than
275+
// one, or somehow less than the last seen seq_id:
276+
_ => {
277+
send_ack(last_seen_seq_id).await;
278+
}
238279
}
239-
240-
if pub_queue
241-
.send(ZmqTransportEvent::AckMessage(
242-
output_port_id,
243-
input_port_id,
244-
seq_id,
245-
))
246-
.await
247-
.is_err()
248-
{
249-
#[cfg(feature = "tracing")]
250-
span.in_scope(|| warn!("publish channel is closed"));
251-
return;
252-
}
253-
254-
#[cfg(feature = "tracing")]
255-
span.in_scope(|| trace!("sent msg-ack"));
256280
}
257281

258282
Open(..) | Closed => {
@@ -285,13 +309,11 @@ pub fn start_input_worker(
285309
return;
286310
};
287311

288-
let Some(idx) = connected_ids.iter().position(|&id| id == output_port_id) else {
312+
if connected_ids.remove(&output_port_id).is_none() {
289313
#[cfg(feature = "tracing")]
290314
span.in_scope(|| trace!("output port doesn't match any connected port"));
291315
return;
292-
};
293-
294-
connected_ids.swap_remove(idx);
316+
}
295317

296318
if !connected_ids.is_empty() {
297319
return;

0 commit comments

Comments
 (0)