From d9b00ab5f352f0dd95b35b3f91f785bb9952af38 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Thu, 16 Jun 2022 18:44:06 +0200 Subject: [PATCH] feat(kad): add limit option for getting providers --- protocols/kad/src/behaviour.rs | 96 +++++++++++++++++++++-------- protocols/kad/src/behaviour/test.rs | 88 +++++++++++++++++++++++++- protocols/kad/src/lib.rs | 2 +- 3 files changed, 158 insertions(+), 28 deletions(-) diff --git a/protocols/kad/src/behaviour.rs b/protocols/kad/src/behaviour.rs index d5b322e096a7..d63a67dd89cd 100644 --- a/protocols/kad/src/behaviour.rs +++ b/protocols/kad/src/behaviour.rs @@ -920,17 +920,23 @@ where /// /// The result of this operation is delivered in a /// reported via [`KademliaEvent::OutboundQueryCompleted{QueryResult::GetProviders}`]. - pub fn get_providers(&mut self, key: record::Key) -> QueryId { + pub fn get_providers(&mut self, key: record::Key, limit: ProviderLimit) -> QueryId { let providers = self .store .providers(&key) .into_iter() .filter(|p| !p.is_expired(Instant::now())) - .map(|p| p.provider) - .collect(); + .map(|p| p.provider); + + let providers = match limit { + ProviderLimit::None => providers.collect(), + ProviderLimit::N(limit) => providers.take(limit.into()).collect(), + }; + let info = QueryInfo::GetProviders { key: key.clone(), providers, + limit, }; let target = kbucket::Key::new(key); let peers = self.kbuckets.closest_keys(&target); @@ -1259,17 +1265,19 @@ where })), }), - QueryInfo::GetProviders { key, providers } => { - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::GetProviders(Ok(GetProvidersOk { - key, - providers, - closest_peers: result.peers.collect(), - })), - }) - } + QueryInfo::GetProviders { + key, + providers, + limit: _, + } => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::GetProviders(Ok(GetProvidersOk { + key, + providers, + closest_peers: result.peers.collect(), + })), + }), QueryInfo::AddProvider { context, @@ -1554,17 +1562,19 @@ where })), }), - QueryInfo::GetProviders { key, providers } => { - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::GetProviders(Err(GetProvidersError::Timeout { - key, - providers, - closest_peers: result.peers.collect(), - })), - }) - } + QueryInfo::GetProviders { + key, + providers, + limit: _, + } => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::GetProviders(Err(GetProvidersError::Timeout { + key, + providers, + closest_peers: result.peers.collect(), + })), + }), } } @@ -2332,6 +2342,31 @@ where { query.on_success(&peer_id, vec![]) } + + if let QueryInfo::GetProviders { + key: _, + providers, + limit, + } = &query.inner.info + { + match limit { + ProviderLimit::None => { + // No limit, so wait for enough peers to respond. + } + ProviderLimit::N(n) => { + // Check if we have enough providers. + if usize::from(*n) <= providers.len() { + debug!( + "found enough providers {}/{}, finishing", + providers.len(), + n + ); + query.finish(); + } + } + } + } + if self.connected_peers.contains(&peer_id) { self.queued_events .push_back(NetworkBehaviourAction::NotifyHandler { @@ -2364,6 +2399,15 @@ where } } +/// Specifies the number of provider records fetched. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum ProviderLimit { + /// No limit on the number of records. + None, + /// Finishes the query as soon as this many records have been found. + N(NonZeroUsize), +} + /// A quorum w.r.t. the configured replication factor specifies the minimum /// number of distinct nodes that must be successfully contacted in order /// for a query to succeed. @@ -2863,6 +2907,8 @@ pub enum QueryInfo { key: record::Key, /// The found providers. providers: HashSet, + /// The limit of how many providers to find, + limit: ProviderLimit, }, /// A (repeated) query initiated by [`Kademlia::start_providing`]. diff --git a/protocols/kad/src/behaviour/test.rs b/protocols/kad/src/behaviour/test.rs index 1f67be5a19d3..6da316bf47bc 100644 --- a/protocols/kad/src/behaviour/test.rs +++ b/protocols/kad/src/behaviour/test.rs @@ -1333,7 +1333,7 @@ fn network_behaviour_inject_address_change() { } #[test] -fn get_providers() { +fn get_providers_single() { fn prop(key: record::Key) { let (_, mut single_swarm) = build_node(); single_swarm @@ -1352,7 +1352,9 @@ fn get_providers() { } }); - let query_id = single_swarm.behaviour_mut().get_providers(key.clone()); + let query_id = single_swarm + .behaviour_mut() + .get_providers(key.clone(), ProviderLimit::None); block_on(async { match single_swarm.next().await.unwrap() { @@ -1379,3 +1381,85 @@ fn get_providers() { } QuickCheck::new().tests(10).quickcheck(prop as fn(_)) } + +fn get_providers_limit() { + fn prop(key: record::Key) { + let mut swarms = build_nodes(3); + + // Let first peer know of second peer and second peer know of third peer. + for i in 0..2 { + let (peer_id, address) = ( + Swarm::local_peer_id(&swarms[i + 1].1).clone(), + swarms[i + 1].0.clone(), + ); + swarms[i].1.behaviour_mut().add_address(&peer_id, address); + } + + // Drop the swarm addresses. + let mut swarms = swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>(); + + // Provide the content on peer 2 and 3. + for i in 1..3 { + swarms[i] + .behaviour_mut() + .start_providing(key.clone()) + .expect("could not provide"); + } + + // Query with expecting a single provider. + let query_id = swarms[0] + .behaviour_mut() + .get_providers(key.clone(), ProviderLimit::N(N.try_into().unwrap())); + + block_on(poll_fn(move |ctx| { + for (i, swarm) in swarms.iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: + QueryResult::GetProviders(Ok(GetProvidersOk { + key: found_key, + providers, + .. + })), + .. + }, + ))) if i == 0 && id == query_id => { + // There are a total of 2 providers. + assert_eq!(providers.len(), std::cmp::min(N, 2)); + assert_eq!(key, found_key); + // Providers should be either 2 or 3 + assert_ne!(swarm.local_peer_id(), providers.iter().next().unwrap()); + return Poll::Ready(()); + } + Poll::Ready(..) => {} + Poll::Pending => break, + } + } + } + Poll::Pending + })); + } + + QuickCheck::new().tests(10).quickcheck(prop:: as fn(_)) +} + +#[test] +fn get_providers_limit_n_1() { + get_providers_limit::<1>(); +} + +#[test] +fn get_providers_limit_n_2() { + get_providers_limit::<1>(); +} + +#[test] +fn get_providers_limit_n_5() { + get_providers_limit::<5>(); +} diff --git a/protocols/kad/src/lib.rs b/protocols/kad/src/lib.rs index e46e2b16b54a..0e0374b1d387 100644 --- a/protocols/kad/src/lib.rs +++ b/protocols/kad/src/lib.rs @@ -67,7 +67,7 @@ pub use behaviour::{ }; pub use behaviour::{ Kademlia, KademliaBucketInserts, KademliaCaching, KademliaConfig, KademliaEvent, - KademliaStoreInserts, Quorum, + KademliaStoreInserts, ProviderLimit, Quorum, }; pub use protocol::KadConnectionType; pub use query::QueryId;