Skip to content

Commit

Permalink
Merge pull request #101 from michieldwitte/multicast_receive_address
Browse files Browse the repository at this point in the history
multicast receive address
  • Loading branch information
Covertness authored Jun 22, 2024
2 parents c120146 + 759258b commit 231bb6e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 41 deletions.
76 changes: 42 additions & 34 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use alloc::vec::Vec;
use coap_lite::{
block_handler::{extending_splice, BlockValue},
error::HandlingError,
CoapOption, CoapRequest, CoapResponse, MessageClass, MessageType, ObserveOption, Packet,
CoapOption, CoapRequest, CoapResponse, MessageClass, MessageType, ObserveOption, Packet as Message,
RequestType as Method, ResponseType as Status,
};
use core::mem;
Expand Down Expand Up @@ -37,6 +37,12 @@ use tokio::{
use url::Url;
const DEFAULT_RECEIVE_TIMEOUT_SECONDS: u64 = 2; // 2s

#[derive(Debug, Clone)]
pub struct Packet {
pub address: Option<SocketAddr>,
pub message: Message,
}

#[derive(Debug)]
pub enum ObserveMessage {
Terminate,
Expand All @@ -49,7 +55,7 @@ use async_trait::async_trait;
/// timeouts and retries do not need to be implemented by the transport
/// if confirmable messages are sent
pub trait ClientTransport: Send + Sync {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize>;
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, Option<SocketAddr>)>;
async fn send(&self, buf: &[u8]) -> std::io::Result<usize>;
}

Expand All @@ -60,9 +66,11 @@ trait TransportExt {
impl<T: ClientTransport> TransportExt for T {
async fn receive_packet(&self) -> IoResult<Option<Packet>> {
let mut buf = [0; 1500];
let nread = self.recv(&mut buf).await?;
let parse_opt = Packet::from_bytes(&buf[..nread]).ok();
return Ok(parse_opt);
let (nread, address) = self.recv(&mut buf).await?;
return match Message::from_bytes(&buf[..nread]).ok() {
Some(message) => Ok(Some(Packet {address, message})),
None => Ok(None),
}
}
}

Expand Down Expand Up @@ -163,16 +171,16 @@ async fn receive_loop<T: ClientTransport + 'static>(
transport_instance.send(&ack).await?;
}

let MessageClass::Response(_) = packet.header.code else {
let MessageClass::Response(_) = packet.message.header.code else {
continue;
};

let token = packet.get_token();
let token = packet.message.get_token();
let Some(sender) = transport_sync.get_sender(token).await else {
info!("received unexpected response for token {:?}", &token);
continue;
};
match packet.header.code {
match packet.message.header.code {
MessageClass::Response(_) => {}
m => {
debug!("unknown message type {}", m);
Expand All @@ -191,16 +199,16 @@ async fn receive_loop<T: ClientTransport + 'static>(
}

pub fn parse_for_ack(packet: &Packet) -> Option<Vec<u8>> {
match (packet.header.get_type(), packet.header.code) {
match (packet.message.header.get_type(), packet.message.header.code) {
(MessageType::Confirmable, MessageClass::Response(_)) => Some(make_ack(packet)),
_ => None,
}
}

pub fn make_ack(packet: &Packet) -> Vec<u8> {
let mut ack = Packet::new();
let mut ack = Message::new();
ack.header.set_type(MessageType::Acknowledgement);
ack.header.message_id = packet.header.message_id;
ack.header.message_id = packet.message.header.message_id;
ack.header.code = MessageClass::Empty;
return ack.to_bytes().unwrap();
}
Expand All @@ -226,9 +234,9 @@ impl<T: ClientTransport> Clone for CoapClientTransport<T> {

impl<T: ClientTransport> CoapClientTransport<T> {
pub const DEFAULT_NUM_RETRIES: usize = 5;
async fn establish_receiver_for(&self, msg: &Packet) -> UnboundedReceiver<IoResult<Packet>> {
async fn establish_receiver_for(&self, packet: &Packet) -> UnboundedReceiver<IoResult<Packet>> {
let (tx, rx) = unbounded_channel();
let token = msg.get_token().to_owned();
let token = packet.message.get_token().to_owned();
self.synchronizer.set_sender(token, tx).await;
return rx;
}
Expand All @@ -249,8 +257,8 @@ impl<T: ClientTransport> CoapClientTransport<T> {
return res;
}

fn encode_packet(packet: &Packet) -> IoResult<Vec<u8>> {
packet
fn encode_message(message: &Message) -> IoResult<Vec<u8>> {
message
.to_bytes()
.map_err(|e| std::io::Error::new(ErrorKind::InvalidData, e.to_string()))
}
Expand All @@ -260,7 +268,7 @@ impl<T: ClientTransport> CoapClientTransport<T> {
msg: &Packet,
receiver: &mut UnboundedReceiver<IoResult<Packet>>,
) -> IoResult<Packet> {
let bytes = Self::encode_packet(msg)?;
let bytes = Self::encode_message(&msg.message)?;
self.transport.send(&bytes).await?;
let try_receive: Result<Option<Result<Packet, Error>>, tokio::time::error::Elapsed> =
timeout(self.timeout, receiver.recv()).await;
Expand All @@ -275,7 +283,7 @@ impl<T: ClientTransport> CoapClientTransport<T> {
packet: &Packet,
receiver: &mut UnboundedReceiver<IoResult<Packet>>,
) -> IoResult<Packet> {
if packet.header.get_type() == MessageType::Confirmable {
if packet.message.header.get_type() == MessageType::Confirmable {
return self.try_send_confirmable_message(&packet, receiver).await;
} else {
return self
Expand All @@ -289,7 +297,7 @@ impl<T: ClientTransport> CoapClientTransport<T> {
let result = self
.do_request_response_for_packet_inner(packet, &mut receiver)
.await;
self.synchronizer.remove_sender(packet.get_token()).await;
self.synchronizer.remove_sender(packet.message.get_token()).await;
result
}

Expand All @@ -309,11 +317,11 @@ pub struct UdpTransport {
}
#[async_trait]
impl ClientTransport for UdpTransport {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
self.socket
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, Option<SocketAddr>)> {
let (read, addr) = self.socket
.recv_from(buf)
.await
.map(|(recv_size, _addr)| recv_size)
.await?;
return Ok((read, Some(addr)));
}
async fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
self.socket.send_to(buf, self.peer_addr).await
Expand Down Expand Up @@ -481,7 +489,7 @@ impl UdpCoAPClient {
/// client.send_all_coap(&request, segment).await.unwrap();
/// loop {
/// let recv_packet = receiver.receive().await.unwrap();
/// assert_eq!(recv_packet.payload, b"test-echo".to_vec());
/// assert_eq!(recv_packet.message.payload, b"test-echo".to_vec());
/// }
/// }
/// ```
Expand Down Expand Up @@ -591,7 +599,7 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {
}

/// Execute a single request (GET, POST, PUT, DELETE) with a coap url and a specfic timeout
/// using udp
/// using udp
pub async fn request_with_timeout(
url: &str,
method: Method,
Expand Down Expand Up @@ -619,7 +627,7 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {
self.receive(&mut request).await
}

pub async fn observe<H: FnMut(Packet) + Send + 'static>(
pub async fn observe<H: FnMut(Message) + Send + 'static>(
&self,
resource_path: &str,
handler: H,
Expand All @@ -634,7 +642,7 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {
/// Observe a resource with the handler and specified timeout using the given transport.
/// Use the oneshot sender to cancel observation. If this sender is dropped without explicitly
/// cancelling it, the observation will continue forever.
pub async fn observe_with_timeout<H: FnMut(Packet) + Send + 'static>(
pub async fn observe_with_timeout<H: FnMut(Message) + Send + 'static>(
&mut self,
resource_path: &str,
handler: H,
Expand All @@ -651,7 +659,7 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {
/// Use this method if you need to set some specific options in your
/// requests. This method will add observe flags and a message id as a fallback
/// Use this method if you plan on re-using the same client for requests
pub async fn observe_with<H: FnMut(Packet) + Send + 'static>(
pub async fn observe_with<H: FnMut(Message) + Send + 'static>(
&self,
request: CoapRequest<SocketAddr>,
mut handler: H,
Expand Down Expand Up @@ -723,17 +731,17 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {

let _ = self
.transport
.do_request_response_for_packet(&deregister_packet.message)
.do_request_response_for_packet(&Packet {address:None, message: deregister_packet.message})
.await;
}

async fn receive_and_handle_message_observe<H: FnMut(Packet) + Send + 'static>(
async fn receive_and_handle_message_observe<H: FnMut(Message) + Send + 'static>(
socket_result: IoResult<Packet>,
handler: &mut H,
) {
match socket_result {
Ok(packet) => {
handler(packet);
handler(packet.message);
}
Err(e) => match e.kind() {
ErrorKind::WouldBlock => {
Expand All @@ -755,9 +763,9 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {
) -> IoResult<CoapResponse> {
let response = self
.transport
.do_request_response_for_packet(&request.message)
.do_request_response_for_packet(&Packet {address:None, message:request.message.to_owned()})
.await?;
Ok(CoapResponse { message: response })
Ok(CoapResponse { message: response.message })
}

/// low-level method to send a a request supporting block1 option based on
Expand Down Expand Up @@ -1214,7 +1222,7 @@ mod test {

#[async_trait]
impl ClientTransport for FaultyUdp {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, Option<SocketAddr>)> {
self.udp.recv(buf).await
}

Expand Down Expand Up @@ -1381,7 +1389,7 @@ mod test {
}
#[async_trait]
impl ClientTransport for FaultyReceiver {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, Option<SocketAddr>)> {
let mut mutex = self.should_fail.lock().await;
tokio::select! {
e = mutex.deref_mut() => {
Expand Down
10 changes: 5 additions & 5 deletions src/dtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ pub struct DtlsResponse {

#[async_trait]
impl ClientTransport for DtlsConnection {
async fn recv(&self, buf: &mut [u8]) -> IoResult<usize> {
async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, Option<SocketAddr>)> {
let read = self
.conn
.read(buf, None)
.await
.map_err(|e| Error::new(ErrorKind::Other, e))?;
return Ok(read);
return Ok((read, self.conn.remote_addr()));
}

async fn send(&self, buf: &[u8]) -> IoResult<usize> {
Expand Down Expand Up @@ -108,7 +108,7 @@ pub struct DtlsConnection {
}

impl DtlsConnection {
/// Creates a new DTLS connection from a given connection. This connection can be
/// Creates a new DTLS connection from a given connection. This connection can be
/// a tokio UDP socket or a user-created struct implementing Conn, Send, and Sync
///
///
Expand Down Expand Up @@ -187,7 +187,7 @@ mod test {
use rcgen::KeyPair;
use std::fs::File;
use std::io::{BufReader, Read};
use std::net::{SocketAddr, ToSocketAddrs};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs};
use std::sync::atomic::AtomicBool;
use tokio::sync::mpsc;
use tokio::time::sleep;
Expand Down Expand Up @@ -635,7 +635,7 @@ mod test {
todo!("not needed");
}
fn remote_addr(&self) -> Option<SocketAddr> {
todo!("not needed")
Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0))
}
async fn close(&self) -> WebrtcResult<()> {
Ok(self.0.close().await?)
Expand Down
4 changes: 2 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ pub mod test {
let mut receiver = client.create_receiver_for(&request).await;
client.send_all_coap(&request, segment).await.unwrap();
let recv_packet = receiver.receive().await.unwrap();
assert_eq!(recv_packet.payload, b"test-echo".to_vec());
assert_eq!(recv_packet.message.payload, b"test-echo".to_vec());
}

//This test right now does not work on windows
Expand Down Expand Up @@ -872,7 +872,7 @@ pub mod test {
let mut receiver = client.create_receiver_for(&request).await;
client.send_all_coap(&request, segment).await.unwrap();
let recv_packet = receiver.receive().await.unwrap();
assert_eq!(recv_packet.payload, b"test-echo".to_vec());
assert_eq!(recv_packet.message.payload, b"test-echo".to_vec());
}

#[test]
Expand Down

0 comments on commit 231bb6e

Please sign in to comment.