Skip to content

Commit db7ce77

Browse files
committed
refactor: switch back to RwLock
1 parent 945ddf6 commit db7ce77

File tree

6 files changed

+117
-108
lines changed

6 files changed

+117
-108
lines changed

Cargo.lock

Lines changed: 0 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/rs-dapi-client/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ lru = { version = "0.12.3" }
3737
serde = { version = "1.0.197", optional = true, features = ["derive"] }
3838
serde_json = { version = "1.0.120", optional = true }
3939
chrono = { version = "0.4.38", features = ["serde"] }
40-
dashmap = "6.1.0"
4140

4241
[dev-dependencies]
4342
tokio = { version = "1.40", features = ["macros"] }

packages/rs-dapi-client/src/address_list.rs

Lines changed: 94 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,21 @@
33
use chrono::Utc;
44
use dapi_grpc::tonic::codegen::http;
55
use dapi_grpc::tonic::transport::Uri;
6-
use dashmap::setref::multiple::RefMulti;
7-
use dashmap::DashSet;
86
use rand::{rngs::SmallRng, seq::IteratorRandom, SeedableRng};
7+
use std::collections::hash_map::Entry;
8+
use std::collections::HashMap;
99
use std::hash::{Hash, Hasher};
10+
use std::mem;
1011
use std::str::FromStr;
11-
use std::sync::Arc;
12+
use std::sync::{Arc, RwLock};
1213
use std::time::Duration;
1314

1415
const DEFAULT_BASE_BAN_PERIOD: Duration = Duration::from_secs(60);
1516

1617
/// DAPI address.
1718
#[derive(Debug, Clone, Eq)]
1819
#[cfg_attr(feature = "mocks", derive(serde::Serialize, serde::Deserialize))]
19-
pub struct Address {
20-
ban_count: usize,
21-
banned_until: Option<chrono::DateTime<Utc>>,
22-
#[cfg_attr(feature = "mocks", serde(with = "http_serde::uri"))]
23-
uri: Uri,
24-
}
20+
pub struct Address(#[cfg_attr(feature = "mocks", serde(with = "http_serde::uri"))] Uri);
2521

2622
impl FromStr for Address {
2723
type Err = AddressListError;
@@ -35,35 +31,46 @@ impl FromStr for Address {
3531

3632
impl PartialEq<Self> for Address {
3733
fn eq(&self, other: &Self) -> bool {
38-
self.uri == other.uri
34+
self.0 == other.0
3935
}
4036
}
4137

4238
impl PartialEq<Uri> for Address {
4339
fn eq(&self, other: &Uri) -> bool {
44-
self.uri == *other
40+
self.0 == *other
4541
}
4642
}
4743

4844
impl Hash for Address {
4945
fn hash<H: Hasher>(&self, state: &mut H) {
50-
self.uri.hash(state);
46+
self.0.hash(state);
5147
}
5248
}
5349

5450
impl From<Uri> for Address {
5551
fn from(uri: Uri) -> Self {
56-
Address {
57-
ban_count: 0,
58-
banned_until: None,
59-
uri,
60-
}
52+
Address(uri)
6153
}
6254
}
6355

6456
impl Address {
57+
/// Get [Uri] of a node.
58+
pub fn uri(&self) -> &Uri {
59+
&self.0
60+
}
61+
}
62+
63+
/// Address status
64+
/// Contains information about the number of bans and the time until the next ban is lifted.
65+
#[derive(Debug, Default, Clone)]
66+
pub struct AddressStatus {
67+
ban_count: usize,
68+
banned_until: Option<chrono::DateTime<Utc>>,
69+
}
70+
71+
impl AddressStatus {
6572
/// Ban the [Address] so it won't be available through [AddressList::get_live_address] for some time.
66-
fn ban(&mut self, base_ban_period: &Duration) {
73+
pub fn ban(&mut self, base_ban_period: &Duration) {
6774
let coefficient = (self.ban_count as f64).exp();
6875
let ban_period = Duration::from_secs_f64(base_ban_period.as_secs_f64() * coefficient);
6976

@@ -77,24 +84,16 @@ impl Address {
7784
}
7885

7986
/// Clears ban record.
80-
fn unban(&mut self) {
87+
pub fn unban(&mut self) {
8188
self.ban_count = 0;
8289
self.banned_until = None;
8390
}
84-
85-
/// Get [Uri] of a node.
86-
pub fn uri(&self) -> &Uri {
87-
&self.uri
88-
}
8991
}
9092

9193
/// [AddressList] errors
9294
#[derive(Debug, thiserror::Error)]
9395
#[cfg_attr(feature = "mocks", derive(serde::Serialize, serde::Deserialize))]
9496
pub enum AddressListError {
95-
/// Specified address is not present in the list
96-
#[error("address {0} not found in the list")]
97-
AddressNotFound(#[cfg_attr(feature = "mocks", serde(with = "http_serde::uri"))] Uri),
9897
/// A valid uri is required to create an Address
9998
#[error("unable parse address: {0}")]
10099
#[cfg_attr(feature = "mocks", serde(skip))]
@@ -105,7 +104,7 @@ pub enum AddressListError {
105104
/// for [DapiRequest](crate::DapiRequest) execution.
106105
#[derive(Debug, Clone)]
107106
pub struct AddressList {
108-
addresses: Arc<DashSet<Address>>,
107+
addresses: Arc<RwLock<HashMap<Address, AddressStatus>>>,
109108
base_ban_period: Duration,
110109
}
111110

@@ -117,7 +116,7 @@ impl Default for AddressList {
117116

118117
impl std::fmt::Display for Address {
119118
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120-
self.uri.fmt(f)
119+
self.0.fmt(f)
121120
}
122121
}
123122

@@ -130,43 +129,70 @@ impl AddressList {
130129
/// Creates an empty [AddressList] with adjustable base ban time.
131130
pub fn with_settings(base_ban_period: Duration) -> Self {
132131
AddressList {
133-
addresses: Arc::new(DashSet::new()),
132+
addresses: Arc::new(RwLock::new(HashMap::new())),
134133
base_ban_period,
135134
}
136135
}
137136

138137
/// Bans address
139-
pub fn ban_address(&self, address: &Address) -> Result<(), AddressListError> {
140-
if self.addresses.remove(address).is_none() {
141-
return Err(AddressListError::AddressNotFound(address.uri.clone()));
142-
};
138+
/// Returns false if the address is not in the list.
139+
pub fn ban(&self, address: &Address) -> bool {
140+
let mut guard = self.addresses.write().unwrap();
143141

144-
let mut banned_address = address.clone();
145-
banned_address.ban(&self.base_ban_period);
142+
let Some(mut status) = guard.get_mut(address) else {
143+
return false;
144+
};
146145

147-
self.addresses.insert(banned_address);
146+
status.ban(&self.base_ban_period);
148147

149-
Ok(())
148+
true
150149
}
151150

152151
/// Clears address' ban record
153-
pub fn unban_address(&self, address: &Address) -> Result<(), AddressListError> {
154-
if self.addresses.remove(address).is_none() {
155-
return Err(AddressListError::AddressNotFound(address.uri.clone()));
152+
/// Returns false if the address is not in the list.
153+
pub fn unban(&self, address: &Address) -> bool {
154+
let mut guard = self.addresses.write().unwrap();
155+
156+
let Some(mut status) = guard.get_mut(address) else {
157+
return false;
156158
};
157159

158-
let mut unbanned_address = address.clone();
159-
unbanned_address.unban();
160+
status.unban();
161+
162+
true
163+
}
160164

161-
self.addresses.insert(unbanned_address);
165+
/// Check if the address is banned.
166+
pub fn is_banned(&self, address: &Address) -> bool {
167+
let guard = self.addresses.read().unwrap();
162168

163-
Ok(())
169+
guard
170+
.get(address)
171+
.map(|status| status.is_banned())
172+
.unwrap_or(false)
164173
}
165174

166175
/// Adds a node [Address] to [AddressList]
167176
/// Returns false if the address is already in the list.
168177
pub fn add(&mut self, address: Address) -> bool {
169-
self.addresses.insert(address)
178+
let mut guard = self.addresses.write().unwrap();
179+
180+
match guard.entry(address) {
181+
Entry::Occupied(_) => false,
182+
Entry::Vacant(e) => {
183+
e.insert(AddressStatus::default());
184+
185+
true
186+
}
187+
}
188+
}
189+
190+
/// Remove address from the list
191+
/// Returns [AddressStatus] if the address was in the list.
192+
pub fn remove(&mut self, address: &Address) -> Option<AddressStatus> {
193+
let mut guard = self.addresses.write().unwrap();
194+
195+
guard.remove(address)
170196
}
171197

172198
// TODO: this is the most simple way to add an address
@@ -175,41 +201,53 @@ impl AddressList {
175201
/// Add a node [Address] to [AddressList] by [Uri].
176202
/// Returns false if the address is already in the list.
177203
pub fn add_uri(&mut self, uri: Uri) -> bool {
178-
self.addresses.insert(uri.into())
204+
self.add(Address::from(uri))
179205
}
180206

181207
/// Randomly select a not banned address.
182-
pub fn get_live_address(&self) -> Option<RefMulti<Address>> {
208+
pub fn get_live_address(&self) -> Option<Address> {
209+
let mut guard = self.addresses.read().unwrap();
210+
183211
let mut rng = SmallRng::from_entropy();
184212

185213
let now = chrono::Utc::now();
186214

187-
self.addresses
215+
guard
188216
.iter()
189-
.filter(|addr| {
190-
addr.banned_until
217+
.filter(|(addr, status)| {
218+
status
219+
.banned_until
191220
.map(|banned_until| banned_until < now)
192221
.unwrap_or(true)
193222
})
194223
.choose(&mut rng)
224+
.map(|(addr, _)| addr.clone())
195225
}
196226

197227
/// Get number of all addresses, both banned and not banned.
198228
pub fn len(&self) -> usize {
199-
self.addresses.len()
229+
self.addresses.read().unwrap().len()
200230
}
201231

202232
/// Check if the list is empty.
203233
/// Returns true if there are no addresses in the list.
204234
/// Returns false if there is at least one address in the list.
205235
/// Banned addresses are also counted.
206236
pub fn is_empty(&self) -> bool {
207-
self.addresses.is_empty()
237+
self.addresses.read().unwrap().is_empty()
208238
}
239+
}
240+
241+
impl IntoIterator for AddressList {
242+
type Item = (Address, AddressStatus);
243+
type IntoIter = std::collections::hash_map::IntoIter<Address, AddressStatus>;
244+
245+
fn into_iter(self) -> Self::IntoIter {
246+
let mut guard = self.addresses.write().unwrap();
247+
248+
let addresses_map = mem::take(&mut *guard);
209249

210-
/// Get an iterator over all addresses.
211-
pub fn iter(&self) -> impl Iterator<Item = RefMulti<Address>> {
212-
self.addresses.iter()
250+
addresses_map.into_iter()
213251
}
214252
}
215253

0 commit comments

Comments
 (0)