diff --git a/crates/core/src/rlp/decode.rs b/crates/core/src/rlp/decode.rs index 1994155a221..3f11d3e384c 100644 --- a/crates/core/src/rlp/decode.rs +++ b/crates/core/src/rlp/decode.rs @@ -317,7 +317,7 @@ impl RLPDecode for (T1, T2, T3) { /// - A boolean indicating if the item is a list or not. /// - The payload of the item, without its prefix. /// - The remaining bytes after the item. -pub(crate) fn decode_rlp_item(data: &[u8]) -> Result<(bool, &[u8], &[u8]), RLPDecodeError> { +pub fn decode_rlp_item(data: &[u8]) -> Result<(bool, &[u8], &[u8]), RLPDecodeError> { if data.is_empty() { return Err(RLPDecodeError::InvalidLength); } diff --git a/crates/net/src/discv4.rs b/crates/net/src/discv4.rs index 220e18ee3a7..807b6b5d379 100644 --- a/crates/net/src/discv4.rs +++ b/crates/net/src/discv4.rs @@ -66,6 +66,10 @@ impl Message { let packet_type = encoded_msg[packet_index]; let msg = &encoded_msg[packet_index + 1..]; match packet_type { + 0x01 => { + let ping = PingMessage::decode(msg)?; + Ok(Message::Ping(ping)) + } 0x02 => { let pong = PongMessage::decode(msg)?; Ok(Message::Pong(pong)) @@ -82,6 +86,16 @@ pub(crate) struct Endpoint { pub tcp_port: u16, } +impl RLPEncode for Endpoint { + fn encode(&self, buf: &mut dyn BufMut) { + structs::Encoder::new(buf) + .encode_field(&self.ip) + .encode_field(&self.udp_port) + .encode_field(&self.tcp_port) + .finish(); + } +} + impl RLPDecode for Endpoint { fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> { let decoder = Decoder::new(rlp)?; @@ -98,16 +112,6 @@ impl RLPDecode for Endpoint { } } -impl RLPEncode for Endpoint { - fn encode(&self, buf: &mut dyn BufMut) { - structs::Encoder::new(buf) - .encode_field(&self.ip) - .encode_field(&self.udp_port) - .encode_field(&self.tcp_port) - .finish(); - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) struct PingMessage { /// The Ping message version. Should be set to 4, but mustn't be enforced. @@ -156,6 +160,28 @@ impl RLPEncode for PingMessage { } } +impl RLPDecode for PingMessage { + fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> { + let decoder = Decoder::new(rlp)?; + let (version, decoder): (u8, Decoder) = decoder.decode_field("version")?; + let (from, decoder) = decoder.decode_field("from")?; + let (to, decoder) = decoder.decode_field("to")?; + let (expiration, decoder) = decoder.decode_field("expiration")?; + let (enr_seq, decoder) = decoder.decode_optional_field(); + + let ping = PingMessage { + version, + from, + to, + expiration, + enr_seq, + }; + + let remaining = decoder.finish()?; + Ok((ping, remaining)) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) struct PongMessage { /// The endpoint of the receiver. @@ -319,6 +345,65 @@ mod tests { .collect() } + #[test] + fn test_decode_ping_message() { + let expiration: u64 = 17195043770; + + let from = Endpoint { + ip: IpAddr::from_str("1.2.3.4").unwrap(), + udp_port: 1613, + tcp_port: 6363, + }; + let to = Endpoint { + ip: IpAddr::from_str("255.255.2.5").unwrap(), + udp_port: 3063, + tcp_port: 0, + }; + + let msg = Message::Ping(PingMessage::new(from, to, expiration)); + + let key_bytes = + H256::from_str("577d8278cc7748fad214b5378669b420f8221afb45ce930b7f22da49cbc545f3") + .unwrap(); + let signer = SigningKey::from_slice(key_bytes.as_bytes()).unwrap(); + + let mut buf = Vec::new(); + + msg.encode_with_header(&mut buf, signer.clone()); + let result = Message::decode_with_header(&buf).expect("Failed decoding PingMessage"); + assert_eq!(result, msg); + } + + #[test] + fn test_decode_ping_message_with_enr_seq() { + let expiration: u64 = 17195043770; + + let from = Endpoint { + ip: IpAddr::from_str("1.2.3.4").unwrap(), + udp_port: 1613, + tcp_port: 6363, + }; + let to = Endpoint { + ip: IpAddr::from_str("255.255.2.5").unwrap(), + udp_port: 3063, + tcp_port: 0, + }; + + let enr_seq = 1704896740573; + let msg = Message::Ping(PingMessage::new(from, to, expiration).with_enr_seq(enr_seq)); + + let key_bytes = + H256::from_str("577d8278cc7748fad214b5378669b420f8221afb45ce930b7f22da49cbc545f3") + .unwrap(); + let signer = SigningKey::from_slice(key_bytes.as_bytes()).unwrap(); + + let mut buf = Vec::new(); + + msg.encode_with_header(&mut buf, signer.clone()); + let result = Message::decode_with_header(&buf).expect("Failed decoding PingMessage"); + assert_eq!(result, msg); + } + #[test] fn test_decode_endpoint() { let endpoint = Endpoint {