Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ tower = "0.5"
tower-http = "0.6"

# p2p
discv5 = "0.10"
discv5 = { git = "https://github.com/sigp/discv5", rev = "7663c00" }
if-addrs = "0.14"

# rpc
Expand Down
210 changes: 137 additions & 73 deletions crates/net/discv4/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ use secp256k1::SecretKey;
use std::{
cell::RefCell,
collections::{btree_map, hash_map::Entry, BTreeMap, HashMap, VecDeque},
fmt,
future::poll_fn,
io,
fmt, io,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
pin::Pin,
rc::Rc,
Expand Down Expand Up @@ -243,17 +241,56 @@ impl Discv4 {
/// ```
pub async fn bind(
local_address: SocketAddr,
local_node_record: NodeRecord,
secret_key: SecretKey,
config: Discv4Config,
) -> io::Result<(Self, Discv4Service)> {
let socket = Arc::new(UdpSocket::bind(local_address).await?);
trace!(target: "discv4", local_addr=?socket.local_addr(), "opened UDP socket");
let (tx, rx) = mpsc::channel(config.udp_ingress_message_buffer);

Self::bind_with_socket(socket, Some(tx), rx, local_node_record, secret_key, config)
}

/// Creates a new `Discv4` instance using a pre-bound shared socket. No receive loop is
/// spawned; instead returns an [`IngressHandler`] that should be used to forward raw packets
/// received by the socket owner (e.g. discv5 unrecognized frames).
pub fn bind_shared(
socket: Arc<UdpSocket>,
local_node_record: NodeRecord,
secret_key: SecretKey,
config: Discv4Config,
) -> io::Result<(Self, Discv4Service, IngressHandler)> {
let (tx, rx) = mpsc::channel(config.udp_ingress_message_buffer);
let local_id = local_node_record.id;
let (discv4, service) =
Self::bind_with_socket(socket, None, rx, local_node_record, secret_key, config)?;

let handler = IngressHandler::new(tx, local_id);

Ok((discv4, service, handler))
}

fn bind_with_socket(
socket: Arc<UdpSocket>,
ingress_tx: Option<IngressSender>,
ingress_rx: IngressReceiver,
mut local_node_record: NodeRecord,
secret_key: SecretKey,
config: Discv4Config,
) -> io::Result<(Self, Discv4Service)> {
let socket = UdpSocket::bind(local_address).await?;
let local_addr = socket.local_addr()?;
local_node_record.udp_port = local_addr.port();
trace!(target: "discv4", ?local_addr,"opened UDP socket");

let mut service =
Discv4Service::new(socket, local_addr, local_node_record, secret_key, config);
let mut service = Discv4Service::new(
socket,
ingress_tx,
ingress_rx,
local_addr,
local_node_record,
secret_key,
config,
);

// resolve the external address immediately
service.resolve_external_ip();
Expand Down Expand Up @@ -520,20 +557,25 @@ pub struct Discv4Service {

impl Discv4Service {
/// Create a new instance for a bound [`UdpSocket`].
///
/// If `ingress_tx` is `Some`, the receive loop is spawned to read from the socket. If `None`,
/// the caller feeds packets into `ingress_rx` externally (shared socket mode).
pub(crate) fn new(
socket: UdpSocket,
socket: Arc<UdpSocket>,
ingress_tx: Option<IngressSender>,
ingress_rx: IngressReceiver,
local_address: SocketAddr,
local_node_record: NodeRecord,
secret_key: SecretKey,
config: Discv4Config,
) -> Self {
let socket = Arc::new(socket);
let (ingress_tx, ingress_rx) = mpsc::channel(config.udp_ingress_message_buffer);
let (egress_tx, egress_rx) = mpsc::channel(config.udp_egress_message_buffer);
let mut tasks = JoinSet::<()>::new();

let udp = Arc::clone(&socket);
tasks.spawn(receive_loop(udp, ingress_tx, local_node_record.id));
if let Some(ingress_tx) = ingress_tx {
let udp = Arc::clone(&socket);
tasks.spawn(receive_loop(udp, ingress_tx, local_node_record.id));
}

let udp = Arc::clone(&socket);
tasks.spawn(send_loop(udp, egress_rx));
Expand Down Expand Up @@ -947,7 +989,7 @@ impl Discv4Service {
let key = kad_key(peer_id);
match self.kbuckets.entry(&key) {
BucketEntry::Present(entry, _) => Some(f(entry.value())),
BucketEntry::Pending(mut entry, _) => Some(f(entry.value())),
BucketEntry::Pending(entry, _) => Some(f(entry.value())),
_ => None,
}
}
Expand All @@ -973,7 +1015,9 @@ impl Discv4Service {
kbucket::Entry::Present(mut entry, _) => {
entry.value_mut().update_with_enr(last_enr_seq)
}
kbucket::Entry::Pending(mut entry, _) => entry.value().update_with_enr(last_enr_seq),
kbucket::Entry::Pending(mut entry, _) => {
entry.value_mut().update_with_enr(last_enr_seq)
}
_ => return,
};

Expand Down Expand Up @@ -1025,8 +1069,8 @@ impl Discv4Service {
}
kbucket::Entry::Pending(mut entry, mut status) => {
// endpoint is now proven
entry.value().establish_proof();
entry.value().update_with_enr(last_enr_seq);
entry.value_mut().establish_proof();
entry.value_mut().update_with_enr(last_enr_seq);

if !status.is_connected() {
status.state = ConnectionState::Connected;
Expand Down Expand Up @@ -1158,7 +1202,7 @@ impl Discv4Service {
} else {
is_proven = entry.value().has_endpoint_proof;
}
entry.value().update_with_enr(ping.enr_sq)
entry.value_mut().update_with_enr(ping.enr_sq)
}
kbucket::Entry::Absent(entry) => {
let mut node = NodeEntry::new(record);
Expand Down Expand Up @@ -1388,7 +1432,7 @@ impl Discv4Service {
(entry.value().record, id)
}
kbucket::Entry::Pending(mut entry, _) => {
let id = entry.value().update_with_fork_id(fork_id);
let id = entry.value_mut().update_with_fork_id(fork_id);
(entry.value().record, id)
}
_ => return,
Expand Down Expand Up @@ -1538,7 +1582,7 @@ impl Discv4Service {
}
}
}
BucketEntry::Pending(mut entry, _) => {
BucketEntry::Pending(entry, _) => {
if entry.value().has_endpoint_proof {
if entry
.value()
Expand Down Expand Up @@ -1642,7 +1686,7 @@ impl Discv4Service {
entry.value().find_node_failures
}
kbucket::Entry::Pending(mut entry, _) => {
entry.value().inc_failed_request();
entry.value_mut().inc_failed_request();
entry.value().find_node_failures
}
_ => continue,
Expand Down Expand Up @@ -1962,80 +2006,100 @@ const MAX_INCOMING_PACKETS_PER_MINUTE_BY_IP: usize = 60usize;

/// Continuously awaits new incoming messages and sends them back through the channel.
///
/// The receive loop enforce primitive rate limiting for ips to prevent message spams from
/// individual IPs
/// The receive loop enforces primitive rate limiting for IPs to prevent message spams from
/// individual IPs.
pub(crate) async fn receive_loop(udp: Arc<UdpSocket>, tx: IngressSender, local_id: PeerId) {
let send = |event: IngressEvent| async {
let _ = tx.send(event).await.map_err(|err| {
debug!(
target: "discv4",
%err,
"failed send incoming packet",
)
});
};

let mut cache = ReceiveCache::default();

// tick at half the rate of the limit
let tick = MAX_INCOMING_PACKETS_PER_MINUTE_BY_IP / 2;
let mut interval = tokio::time::interval(Duration::from_secs(tick as u64));

let mut handler = IngressHandler::new(tx, local_id);
let mut buf = [0; MAX_PACKET_SIZE];
loop {
let res = udp.recv_from(&mut buf).await;
match res {
Err(err) => {
debug!(target: "discv4", %err, "Failed to read datagram.");
send(IngressEvent::RecvError(err)).await;
handler.send(IngressEvent::RecvError(err)).await;
}
Ok((read, remote_addr)) => {
// rate limit incoming packets by IP
if cache.inc_ip(remote_addr.ip()) > MAX_INCOMING_PACKETS_PER_MINUTE_BY_IP {
trace!(target: "discv4", ?remote_addr, "Too many incoming packets from IP.");
continue
}
handler.handle_packet(&buf[..read], remote_addr).await;
}
}
}
}

let packet = &buf[..read];
match Message::decode(packet) {
Ok(packet) => {
if packet.node_id == local_id {
// received our own message
debug!(target: "discv4", ?remote_addr, "Received own packet.");
continue
}
/// Handles decoding, rate-limiting, and deduplication of incoming discv4 packets.
///
/// Used by both the standalone receive loop and the shared-port mode via
/// [`Discv4::bind_shared`].
#[derive(Debug)]
pub struct IngressHandler {
tx: IngressSender,
local_id: PeerId,
tick: usize,
tick_interval: Duration,
cache: ReceiveCache,
last_tick: Instant,
}

// skip if we've already received the same packet
if cache.contains_packet(packet.hash) {
debug!(target: "discv4", ?remote_addr, "Received duplicate packet.");
continue
}
impl IngressHandler {
fn new(tx: IngressSender, local_id: PeerId) -> Self {
let tick = MAX_INCOMING_PACKETS_PER_MINUTE_BY_IP / 2;
Self {
tx,
local_id,
tick,
tick_interval: Duration::from_secs(tick as u64),
cache: ReceiveCache::default(),
last_tick: Instant::now(),
}
}

send(IngressEvent::Packet(remote_addr, packet)).await;
}
Err(err) => {
trace!(target: "discv4", %err,"Failed to decode packet");
send(IngressEvent::BadPacket(remote_addr, err, packet.to_vec())).await
}
}
}
async fn send(&self, event: IngressEvent) {
let _ = self.tx.send(event).await.map_err(|err| {
debug!(target: "discv4", %err, "failed send incoming packet");
});
}

/// Handles an incoming raw packet: decodes, rate-limits, deduplicates, and forwards to the
/// discv4 service. Used in shared-port mode to process unrecognized frames from discv5.
pub async fn handle_packet(&mut self, data: &[u8], src: SocketAddr) {
if self.last_tick.elapsed() >= self.tick_interval {
self.cache.tick_ips(self.tick);
self.last_tick = Instant::now();
}

// reset the tracked ips if the interval has passed
if poll_fn(|cx| match interval.poll_tick(cx) {
Poll::Ready(_) => Poll::Ready(true),
Poll::Pending => Poll::Ready(false),
})
.await
{
cache.tick_ips(tick);
// rate limit incoming packets by IP
if self.cache.inc_ip(src.ip()) > MAX_INCOMING_PACKETS_PER_MINUTE_BY_IP {
trace!(target: "discv4", ?src, "Too many incoming packets from IP.");
return
}

let event = match Message::decode(data) {
Ok(packet) => {
if packet.node_id == self.local_id {
debug!(target: "discv4", ?src, "Received own packet.");
return
}

if self.cache.contains_packet(packet.hash) {
debug!(target: "discv4", ?src, "Received duplicate packet.");
return
}

IngressEvent::Packet(src, packet)
}
Err(err) => {
trace!(target: "discv4", %err, "Failed to decode packet");
IngressEvent::BadPacket(src, err, data.to_vec())
}
};

self.send(event).await;
}
}

/// A cache for received packets and their source address.
///
/// This is used to discard duplicated packets and rate limit messages from the same source.
#[derive(Debug)]
struct ReceiveCache {
/// keeps track of how many messages we've received from a given IP address since the last
/// tick.
Expand Down
Loading
Loading