From af6d207c7ea0899f8b0a90e2bbea781ed80de8bd Mon Sep 17 00:00:00 2001 From: Samuel Sarle Date: Wed, 4 Dec 2024 22:06:41 +0200 Subject: [PATCH] Shortcut sending data to network if input port is reachable locally --- lib/protoflow-zeromq/src/input_port.rs | 20 ++++ lib/protoflow-zeromq/src/lib.rs | 2 +- lib/protoflow-zeromq/src/output_port.rs | 24 ++++- lib/protoflow-zeromq/src/socket.rs | 124 ++++++++---------------- 4 files changed, 86 insertions(+), 84 deletions(-) diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs index e84f01bb..18544c1e 100644 --- a/lib/protoflow-zeromq/src/input_port.rs +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -76,6 +76,14 @@ impl ZmqInputPortState { Closed => PortState::Closed, } } + + pub async fn event_sender(&self) -> Option> { + use ZmqInputPortState::*; + match self { + Open(_, sender) | Connected(.., sender, _) => Some(sender.clone()), + Closed => None, + } + } } fn input_topics(id: InputPortID) -> Vec { @@ -86,6 +94,18 @@ fn input_topics(id: InputPortID) -> Vec { ] } +pub async fn input_port_event_sender( + inputs: &RwLock>>, + id: InputPortID, +) -> Option> { + if let Some(input_state) = inputs.read().await.get(&id) { + let input_state = input_state.read().await; + input_state.event_sender().await + } else { + None + } +} + pub fn start_input_worker( transport: &ZmqTransport, input_port_id: InputPortID, diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 34f7cac7..105a5fa7 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -100,7 +100,7 @@ impl ZmqTransport { inputs, }; - start_pub_socket_worker(psock, pub_queue_recv); + start_pub_socket_worker(&transport, psock, pub_queue_recv); start_sub_socket_worker(&transport, ssock, sub_queue_recv); transport diff --git a/lib/protoflow-zeromq/src/output_port.rs b/lib/protoflow-zeromq/src/output_port.rs index 2de704f6..843f6ad0 100644 --- a/lib/protoflow-zeromq/src/output_port.rs +++ b/lib/protoflow-zeromq/src/output_port.rs @@ -2,7 +2,7 @@ use crate::{subscribe_topics, unsubscribe_topics, ZmqTransport, ZmqTransportEvent}; use protoflow_core::{ - prelude::{fmt, format, vec, Bytes, String, ToString, Vec}, + prelude::{fmt, format, vec, BTreeMap, Bytes, String, ToString, Vec}, InputPortID, OutputPortID, PortError, PortState, }; use tokio::sync::{ @@ -19,7 +19,7 @@ pub enum ZmqOutputPortRequest { Send(Bytes), } -const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(500); +const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(200); const DEFAULT_MAX_RETRIES: u64 = 10; #[derive(Clone, Debug)] @@ -65,6 +65,14 @@ impl ZmqOutputPortState { Closed => PortState::Closed, } } + + pub async fn event_sender(&self) -> Option> { + use ZmqOutputPortState::*; + match self { + Open(.., sender) | Connected(.., sender, _) => Some(sender.clone()), + Closed => None, + } + } } fn output_topics(source: OutputPortID, target: InputPortID) -> Vec { @@ -75,6 +83,18 @@ fn output_topics(source: OutputPortID, target: InputPortID) -> Vec { ] } +pub async fn output_port_event_sender( + outputs: &RwLock>>, + id: OutputPortID, +) -> Option> { + if let Some(output_state) = outputs.read().await.get(&id) { + let output_state = output_state.read().await; + output_state.event_sender().await + } else { + None + } +} + pub fn start_output_worker( transport: &ZmqTransport, output_port_id: OutputPortID, diff --git a/lib/protoflow-zeromq/src/socket.rs b/lib/protoflow-zeromq/src/socket.rs index 6691e652..cef3d2f1 100644 --- a/lib/protoflow-zeromq/src/socket.rs +++ b/lib/protoflow-zeromq/src/socket.rs @@ -1,6 +1,9 @@ // This is free and unencumbered software released into the public domain. -use crate::{ZmqInputPortState, ZmqOutputPortState, ZmqTransport, ZmqTransportEvent}; +use crate::{ + input_port_event_sender, output_port_event_sender, ZmqInputPortState, ZmqOutputPortState, + ZmqTransport, ZmqTransportEvent, +}; use protoflow_core::{ prelude::{BTreeMap, String, Vec}, InputPortID, OutputPortID, PortError, @@ -18,11 +21,17 @@ pub enum ZmqSubscriptionRequest { } #[cfg(feature = "tracing")] -use tracing::{error, trace, trace_span}; +use tracing::{debug, error, trace, trace_span, warn}; -pub fn start_pub_socket_worker(psock: zeromq::PubSocket, pub_queue: Receiver) { +pub fn start_pub_socket_worker( + transport: &ZmqTransport, + psock: zeromq::PubSocket, + pub_queue: Receiver, +) { #[cfg(feature = "tracing")] let span = trace_span!("ZmqTransport::pub_socket"); + let outputs = transport.outputs.clone(); + let inputs = transport.inputs.clone(); let mut psock = psock; let mut pub_queue = pub_queue; tokio::task::spawn(async move { @@ -30,6 +39,27 @@ pub fn start_pub_socket_worker(psock: zeromq::PubSocket, pub_queue: Receiver { + input_port_event_sender(&inputs, id).await + } + AckConnection(id, _) | AckMessage(id, ..) => { + output_port_event_sender(&outputs, id).await + } + CloseInput(..) => None, + }; + + if let Some(sender) = shortcut_sender { + #[cfg(feature = "tracing")] + span.in_scope(|| debug!("attempting to shortcut send directly to target port")); + if sender.send(event.clone()).await.is_ok() { + continue; + } + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("failed to send message with shortcut, sending to socket")); + } + if let Err(err) = psock.send(event.into()).await { #[cfg(feature = "tracing")] span.in_scope(|| error!(?err, "failed to send message")); @@ -126,89 +156,21 @@ async fn handle_zmq_msg( use ZmqTransportEvent::*; match event { // input ports - Connect(_, input_port_id) => { - let sender = { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - return Err(PortError::Invalid(input_port_id.into())); - }; - let input = input.read().await; - - use ZmqInputPortState::*; - match &*input { - Closed => return Err(PortError::Invalid(input_port_id.into())), - Open(.., sender) | Connected(.., sender, _) => sender.clone(), - } - }; - - sender.send(event).await.map_err(|_| PortError::Closed) - } - Message(_, input_port_id, _, _) => { - let sender = { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - return Err(PortError::Invalid(input_port_id.into())); - }; - - let input = input.read().await; - let ZmqInputPortState::Connected(_, _, _, sender, _) = &*input else { - return Err(PortError::Invalid(input_port_id.into())); - }; - - sender.clone() - }; - - sender.send(event).await.map_err(|_| PortError::Closed) - } - CloseOutput(_, input_port_id) => { - let sender = { - let inputs = inputs.read().await; - let Some(input) = inputs.get(&input_port_id) else { - return Err(PortError::Invalid(input_port_id.into())); - }; - let input = input.read().await; - - use ZmqInputPortState::*; - match &*input { - Closed => return Err(PortError::Invalid(input_port_id.into())), - Open(.., sender) | Connected(.., sender, _) => sender.clone(), - } - }; + Connect(_, input_port_id) + | Message(_, input_port_id, _, _) + | CloseOutput(_, input_port_id) => { + let sender = input_port_event_sender(inputs, input_port_id) + .await + .ok_or_else(|| PortError::Invalid(input_port_id.into()))?; sender.send(event).await.map_err(|_| PortError::Closed) } // output ports - AckConnection(output_port_id, _) => { - let sender = { - let outputs = outputs.read().await; - let Some(output) = outputs.get(&output_port_id) else { - return Err(PortError::Invalid(output_port_id.into())); - }; - let output = output.read().await; - - let ZmqOutputPortState::Open(.., sender) = &*output else { - return Err(PortError::Invalid(output_port_id.into())); - }; - - sender.clone() - }; - - sender.send(event).await.map_err(|_| PortError::Closed) - } - AckMessage(output_port_id, _, _) => { - let sender = { - let outputs = outputs.read().await; - let Some(output) = outputs.get(&output_port_id) else { - return Err(PortError::Invalid(output_port_id.into())); - }; - let output = output.read().await; - let ZmqOutputPortState::Connected(_, sender, _) = &*output else { - return Err(PortError::Invalid(output_port_id.into())); - }; - - sender.clone() - }; + AckConnection(output_port_id, _) | AckMessage(output_port_id, _, _) => { + let sender = output_port_event_sender(outputs, output_port_id) + .await + .ok_or_else(|| PortError::Invalid(output_port_id.into()))?; sender.send(event).await.map_err(|_| PortError::Closed) }