diff --git a/Cargo.toml b/Cargo.toml index dc77904b..b8e0cbc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,6 @@ version = "0.1.0" edition = "2021" [workspace.dependencies] -async-lock = "3" futures = "0.3" rand = "0.8" bytes = "1" @@ -43,6 +42,9 @@ clap = { version = "4", features = ["derive"] } dashmap = "6" derive_builder = "0.20" env_logger = "0.11" +parking_lot = "0.11" + + [workspace.dependencies.qbase] path = "./qbase" version = "0.1.0" diff --git a/qbase/Cargo.toml b/qbase/Cargo.toml index 5eba07e2..114f7ec1 100644 --- a/qbase/Cargo.toml +++ b/qbase/Cargo.toml @@ -17,6 +17,7 @@ deref-derive = { workspace = true } rustls = { workspace = true } log = { workspace = true } derive_builder = { workspace = true } +parking_lot = { workspace = true } [dev-dependencies] tokio = { workspace = true } diff --git a/qbase/src/cid/local_cid.rs b/qbase/src/cid/local_cid.rs index b78e4303..e02ad59b 100644 --- a/qbase/src/cid/local_cid.rs +++ b/qbase/src/cid/local_cid.rs @@ -1,4 +1,6 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; + +use parking_lot::Mutex; use super::{ConnectionId, UniqueCid}; use crate::{ @@ -8,7 +10,6 @@ use crate::{ util::IndexDeque, varint::{VarInt, VARINT_MAX}, }; - /// 我方负责发放足够的cid,poll_issue_cid,将当前有效的cid注册到连接id路由。 /// 当cid不足时,就发放新的连接id,包括增大active_cid_limit,以及对方淘汰旧的cid。 #[derive(Debug)] @@ -140,7 +141,6 @@ where pub fn active_cids(&self) -> Vec { self.0 .lock() - .unwrap() .cid_deque .iter() .filter_map(|v| v.map(|(cid, _)| cid)) @@ -148,7 +148,7 @@ where } pub fn set_limit(&self, active_cid_limit: u64) -> Result<(), Error> { - self.0.lock().unwrap().set_limit(active_cid_limit) + self.0.lock().set_limit(active_cid_limit) } } @@ -163,7 +163,7 @@ where &mut self, frame: &RetireConnectionIdFrame, ) -> Result { - self.0.lock().unwrap().recv_retire_cid_frame(frame) + self.0.lock().recv_retire_cid_frame(frame) } } @@ -196,7 +196,7 @@ mod tests { fn test_issue_cid() { let initial_scid = ConnectionId::random_gen(8); let local_cids = ArcLocalCids::new(generator, initial_scid, IssuedCids::default()); - let mut guard = local_cids.0.lock().unwrap(); + let mut guard = local_cids.0.lock(); assert_eq!(guard.cid_deque.len(), 2); diff --git a/qbase/src/cid/remote_cid.rs b/qbase/src/cid/remote_cid.rs index 6baf6007..e9da231e 100644 --- a/qbase/src/cid/remote_cid.rs +++ b/qbase/src/cid/remote_cid.rs @@ -1,11 +1,12 @@ use std::{ future::Future, pin::Pin, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, }; use deref_derive::{Deref, DerefMut}; +use parking_lot::Mutex; use super::ConnectionId; use crate::{ @@ -145,7 +146,7 @@ where // retire the cids before seq, including the applied and unapplied for seq in self.cid_cells.offset()..max_retired { let (_, cell) = self.cid_cells.pop_front().unwrap(); - let mut guard = cell.0.lock().unwrap(); + let mut guard = cell.0.lock(); if guard.is_retired() { continue; } else { @@ -226,7 +227,7 @@ where /// - have been allocated again after retirement of last cid /// - have been retired pub fn apply_cid(&self) -> ArcCidCell { - self.0.lock().unwrap().apply_cid() + self.0.lock().apply_cid() } } @@ -237,7 +238,7 @@ where type Output = Option; fn recv_frame(&mut self, frame: &NewConnectionIdFrame) -> Result { - self.0.lock().unwrap().recv_new_cid_frame(frame) + self.0.lock().recv_new_cid_frame(frame) } } @@ -332,15 +333,15 @@ where } fn assign(&self, cid: ConnectionId) { - self.0.lock().unwrap().assign(cid); + self.0.lock().assign(cid); } pub fn set_cid(&self, cid: ConnectionId) { - self.0.lock().unwrap().set_cid(cid); + self.0.lock().set_cid(cid); } pub fn poll_get_cid(&self, cx: &mut Context<'_>) -> Poll> { - self.0.lock().unwrap().poll_get_cid(cx) + self.0.lock().poll_get_cid(cx) } /// Getting the connection ID, if it is not ready, return a future @@ -353,7 +354,7 @@ where /// is marked as no longer in use, with a RetireConnectionIdFrame being sent to peer. #[inline] pub fn retire(&self) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if !guard.is_retired() { guard.state.retire(); let seq = guard.seq; @@ -371,7 +372,7 @@ where type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.0.lock().unwrap().poll_get_cid(cx) + self.0.lock().poll_get_cid(cx) } } @@ -392,10 +393,7 @@ mod tests { // Will return Pending, because the peer hasn't issue any connection id let cid_apply = remote_cids.apply_cid(); assert_eq!(cid_apply.get_cid().poll_unpin(&mut cx), Poll::Pending); - assert!(matches!( - cid_apply.0.lock().unwrap().state, - CidState::Demand(_) - )); + assert!(matches!(cid_apply.0.lock().state, CidState::Demand(_))); let cid = ConnectionId::random_gen(8); let frame = NewConnectionIdFrame { @@ -424,7 +422,7 @@ mod tests { let mut cx = std::task::Context::from_waker(&waker); let retired_cids = ArcAsyncDeque::::new(); let remote_cids = ArcRemoteCids::with_limit(8, retired_cids, None); - let mut guard = remote_cids.0.lock().unwrap(); + let mut guard = remote_cids.0.lock(); let mut cids = vec![]; for seq in 0..8 { @@ -441,8 +439,8 @@ mod tests { let cid_apply1 = guard.apply_cid(); let cid_apply2 = guard.apply_cid(); - assert_eq!(cid_apply1.0.lock().unwrap().seq, 0); - assert_eq!(cid_apply2.0.lock().unwrap().seq, 1); + assert_eq!(cid_apply1.0.lock().seq, 0); + assert_eq!(cid_apply2.0.lock().seq, 1); assert_eq!( cid_apply1.get_cid().poll_unpin(&mut cx), Poll::Ready(Some(cids[0])) @@ -457,8 +455,8 @@ mod tests { assert_eq!(guard.cid_cells.offset(), 4); assert_eq!(guard.retired_cids.len(), 4); - assert_eq!(cid_apply1.0.lock().unwrap().seq, 4); - assert_eq!(cid_apply2.0.lock().unwrap().seq, 5); + assert_eq!(cid_apply1.0.lock().seq, 4); + assert_eq!(cid_apply2.0.lock().seq, 5); for i in 0..4 { assert_eq!( @@ -498,7 +496,7 @@ mod tests { let mut cx = std::task::Context::from_waker(&waker); let retired_cids = ArcAsyncDeque::::new(); let remote_cids = ArcRemoteCids::with_limit(8, retired_cids, None); - let mut guard = remote_cids.0.lock().unwrap(); + let mut guard = remote_cids.0.lock(); let mut cids = vec![]; for seq in 0..8 { @@ -520,8 +518,8 @@ mod tests { let cid_apply1 = guard.apply_cid(); let cid_apply2 = guard.apply_cid(); - assert_eq!(cid_apply1.0.lock().unwrap().seq, 4); - assert_eq!(cid_apply2.0.lock().unwrap().seq, 5); + assert_eq!(cid_apply1.0.lock().seq, 4); + assert_eq!(cid_apply2.0.lock().seq, 5); assert_eq!( cid_apply1.get_cid().poll_unpin(&mut cx), Poll::Ready(Some(cids[4])) diff --git a/qbase/src/flow.rs b/qbase/src/flow.rs index 6b1b3449..51bed960 100644 --- a/qbase/src/flow.rs +++ b/qbase/src/flow.rs @@ -3,12 +3,13 @@ use std::{ pin::Pin, sync::{ atomic::{AtomicBool, AtomicU64, Ordering}, - Arc, Mutex, MutexGuard, + Arc, }, task::{Context, Poll, Waker}, }; use futures::{task::AtomicWaker, Future}; +use parking_lot::{Mutex, MutexGuard}; use thiserror::Error; use crate::{ @@ -95,7 +96,7 @@ impl ArcSendControler { } fn increase_limit(&self, max_data: u64) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if let Ok(inner) = guard.deref_mut() { inner.increase_limit(max_data); } @@ -109,7 +110,7 @@ impl ArcSendControler { /// Apply for sending data. If it has meet error, it will return Err directly. pub fn credit(&self) -> Result, QuicError> { - let guard = self.0.lock().unwrap(); + let guard = self.0.lock(); if let Err(e) = guard.deref() { return Err(e.clone()); } @@ -121,7 +122,7 @@ impl ArcSendControler { /// not require to register the send task on the flow control, as there may still be /// retransmission data that can be sent. pub fn register_waker(&self, waker: Waker) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if let Ok(inner) = guard.deref_mut() { inner.register_waker(waker); } @@ -129,7 +130,7 @@ impl ArcSendControler { /// Flow control can only be terminated if the connection encounters an error pub fn on_error(&self, error: &QuicError) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if guard.deref().is_err() { return; } @@ -156,7 +157,7 @@ impl Future for WouldBlock { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut guard = self.0 .0.lock().unwrap(); + let mut guard = self.0 .0.lock(); match guard.deref_mut() { Ok(inner) => inner.poll_would_block(cx), Err(e) => Poll::Ready(Err(e.clone())), diff --git a/qbase/src/packet/keys.rs b/qbase/src/packet/keys.rs index fb69d80d..b22629f9 100644 --- a/qbase/src/packet/keys.rs +++ b/qbase/src/packet/keys.rs @@ -2,10 +2,11 @@ use std::{ future::Future, ops::DerefMut, pin::Pin, - sync::{Arc, Mutex, MutexGuard}, + sync::Arc, task::{Context, Poll, Waker}, }; +use parking_lot::{Mutex, MutexGuard}; use rustls::quic::{HeaderProtectionKey, Keys, PacketKey, Secrets}; use super::KeyPhaseBit; @@ -22,7 +23,7 @@ pub struct ArcKeys(Arc>); impl ArcKeys { fn lock_guard(&self) -> MutexGuard { - self.0.lock().unwrap() + self.0.lock() } pub fn new_pending() -> Self { @@ -156,7 +157,7 @@ pub struct ArcOneRttPacketKeys(Arc<(Mutex, usize)>); impl ArcOneRttPacketKeys { pub fn lock_guard(&self) -> MutexGuard { - self.0 .0.lock().unwrap() + self.0 .0.lock() } pub fn tag_len(&self) -> usize { @@ -184,7 +185,7 @@ pub struct ArcOneRttKeys(Arc>); impl ArcOneRttKeys { fn lock_guard(&self) -> MutexGuard { - self.0.lock().unwrap() + self.0.lock() } pub fn new_pending() -> Self { diff --git a/qbase/src/streamid.rs b/qbase/src/streamid.rs index b40f4a4a..e0e1df5f 100644 --- a/qbase/src/streamid.rs +++ b/qbase/src/streamid.rs @@ -1,9 +1,10 @@ use std::{ fmt, ops, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, }; +use parking_lot::Mutex; use thiserror::Error; use super::varint::{be_varint, VarInt, WriteVarInt}; @@ -344,13 +345,13 @@ impl ArcLocalStreamIds { } pub fn role(&self) -> Role { - self.0.lock().unwrap().role() + self.0.lock().role() } /// The maximum stream ID that we can create is limited by peer. Therefore, it mainly /// depends on the peer's attitude and is subject to the MAX_STREAM_FRAME frame sent by peer. pub fn permit_max_sid(&self, dir: Dir, val: u64) { - self.0.lock().unwrap().permit_max_sid(dir, val); + self.0.lock().permit_max_sid(dir, val); } /// We are creating a new stream, and it should be incremented based on the previous stream ID. However, @@ -360,7 +361,7 @@ impl ArcLocalStreamIds { /// maximum stream ID and cannot increase it further. In this case, we should close the connection /// because sending MAX_STREAMS will not be received and would violate the protocol. pub fn poll_alloc_sid(&self, cx: &mut Context<'_>, dir: Dir) -> Poll> { - self.0.lock().unwrap().poll_alloc_sid(cx, dir) + self.0.lock().poll_alloc_sid(cx, dir) } } @@ -378,17 +379,17 @@ impl ArcRemoteStreamIds { } pub fn role(&self) -> Role { - self.0.lock().unwrap().role() + self.0.lock().role() } /// RFC9000: Before a stream is created, all streams of the same type /// with lower-numbered stream IDs MUST be created. pub fn try_accept_sid(&self, sid: StreamId) -> Result { - self.0.lock().unwrap().try_accept_sid(sid) + self.0.lock().try_accept_sid(sid) } pub fn poll_extend_sid(&self, cx: &mut Context<'_>, dir: Dir) -> Poll> { - self.0.lock().unwrap().poll_extend_sid(cx, dir) + self.0.lock().poll_extend_sid(cx, dir) } } @@ -444,15 +445,15 @@ mod tests { Poll::Ready(Some(StreamId(0))) ); assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Pending); - assert!(local.0.lock().unwrap().wakers[0].is_some()); + assert!(local.0.lock().wakers[0].is_some()); local.permit_max_sid(Dir::Bi, 1); - let _ = local.0.lock().unwrap().wakers[0].take(); + let _ = local.0.lock().wakers[0].take(); assert_eq!( local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Ready(Some(StreamId(4))) ); assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Pending); - assert!(local.0.lock().unwrap().wakers[0].is_some()); + assert!(local.0.lock().wakers[0].is_some()); local.permit_max_sid(Dir::Uni, 2); assert_eq!( @@ -468,7 +469,7 @@ mod tests { Poll::Ready(Some(StreamId(10))) ); assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Uni), Poll::Pending); - assert!(local.0.lock().unwrap().wakers[1].is_some()); + assert!(local.0.lock().wakers[1].is_some()); } #[test] @@ -482,7 +483,7 @@ mod tests { end: StreamId(21) })) ); - assert_eq!(remote.0.lock().unwrap().unallocated[0], StreamId(25)); + assert_eq!(remote.0.lock().unallocated[0], StreamId(25)); let result = remote.try_accept_sid(StreamId(25)); assert_eq!( @@ -492,7 +493,7 @@ mod tests { end: StreamId(25) })) ); - assert_eq!(remote.0.lock().unwrap().unallocated[0], StreamId(29)); + assert_eq!(remote.0.lock().unallocated[0], StreamId(29)); let result = remote.try_accept_sid(StreamId(41)); assert_eq!( @@ -502,7 +503,7 @@ mod tests { end: StreamId(41) })) ); - assert_eq!(remote.0.lock().unwrap().unallocated[0], StreamId(45)); + assert_eq!(remote.0.lock().unallocated[0], StreamId(45)); if let Ok(AcceptSid::New(mut range)) = result { assert_eq!(range.next(), Some(StreamId(29))); assert_eq!(range.next(), Some(StreamId(33))); diff --git a/qbase/src/token.rs b/qbase/src/token.rs index dbd5f44a..91210519 100644 --- a/qbase/src/token.rs +++ b/qbase/src/token.rs @@ -1,7 +1,8 @@ -use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::Arc; use bytes::BufMut; use nom::{bytes::complete::take, IResult}; +use parking_lot::{Mutex, MutexGuard}; use rand::Rng; use crate::{ @@ -98,7 +99,7 @@ impl ArcTokenRegistry { } pub fn lock_guard(&self) -> MutexGuard { - self.0.lock().unwrap() + self.0.lock() } } pub enum TokenRegistry { @@ -110,7 +111,7 @@ impl ReceiveFrame for ArcTokenRegistry { type Output = (); fn recv_frame(&mut self, frame: &NewTokenFrame) -> Result { - let guard = self.0.lock().unwrap(); + let guard = self.0.lock(); match &*guard { TokenRegistry::Client((server_name, client)) => { client.sink(server_name, frame.token.clone()); diff --git a/qbase/src/util/async_deque.rs b/qbase/src/util/async_deque.rs index 23448687..a340c1e9 100644 --- a/qbase/src/util/async_deque.rs +++ b/qbase/src/util/async_deque.rs @@ -1,11 +1,12 @@ use std::{ collections::VecDeque, pin::Pin, - sync::{Arc, Mutex, MutexGuard}, + sync::Arc, task::{Context, Poll, Waker}, }; use futures::Stream; +use parking_lot::{Mutex, MutexGuard}; #[derive(Debug)] struct AsyncDeque { @@ -25,7 +26,7 @@ impl ArcAsyncDeque { } pub fn push(&self, value: T) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if let Some(queue) = &mut guard.queue { queue.push_back(value); if let Some(waker) = guard.waker.take() { @@ -35,18 +36,12 @@ impl ArcAsyncDeque { } pub fn pop(&self) -> Option { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); guard.queue.as_mut().and_then(|q| q.pop_front()) } pub fn len(&self) -> usize { - self.0 - .lock() - .unwrap() - .queue - .as_ref() - .map(|v| v.len()) - .unwrap_or(0) + self.0.lock().queue.as_ref().map(|v| v.len()).unwrap_or(0) } pub fn is_empty(&self) -> bool { @@ -54,7 +49,7 @@ impl ArcAsyncDeque { } pub fn close(&self) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); guard.queue = None; if let Some(waker) = guard.waker.take() { waker.wake(); @@ -63,7 +58,7 @@ impl ArcAsyncDeque { // pub fn writer<'a>(&'a self) -> ArcFrameQueueWriter<'a, T> { pub fn writer(&self) -> ArcAsyncDequeWriter<'_, T> { - let guard = self.0.lock().unwrap(); + let guard = self.0.lock(); let old_len = guard.queue.as_ref().map(|q| q.len()).unwrap_or(0); ArcAsyncDequeWriter { guard, old_len } } @@ -85,7 +80,7 @@ impl Stream for ArcAsyncDeque { type Item = T; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); match &mut guard.queue { Some(queue) => { if let Some(frame) = queue.pop_front() { @@ -102,7 +97,7 @@ impl Stream for ArcAsyncDeque { impl Extend for ArcAsyncDeque { fn extend>(&mut self, iter: I) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if let Some(queue) = &mut guard.queue { queue.extend(iter); if let Some(waker) = guard.waker.take() { @@ -137,15 +132,12 @@ impl ArcAsyncDequeWriter<'_, T> { impl Drop for ArcAsyncDequeWriter<'_, T> { fn drop(&mut self) { - match &mut self.guard.queue { - Some(queue) => { - if queue.len() > self.old_len { - if let Some(waker) = self.guard.waker.take() { - waker.wake(); - } + if let Some(queue) = &mut self.guard.queue { + if queue.len() > self.old_len { + if let Some(waker) = self.guard.waker.take() { + waker.wake(); } } - None => {} } } } diff --git a/qcongestion/Cargo.toml b/qcongestion/Cargo.toml index 99913095..e71382bc 100644 --- a/qcongestion/Cargo.toml +++ b/qcongestion/Cargo.toml @@ -8,4 +8,5 @@ edition.workspace = true [dependencies] rand = { workspace = true } qbase = { workspace = true } -qrecovery = { workspace = true } \ No newline at end of file +qrecovery = { workspace = true } +parking_lot = { workspace = true } diff --git a/qcongestion/src/congestion.rs b/qcongestion/src/congestion.rs index e231d4a6..d813be79 100644 --- a/qcongestion/src/congestion.rs +++ b/qcongestion/src/congestion.rs @@ -2,11 +2,12 @@ use std::{ cmp::Ordering, collections::VecDeque, future::Future, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, time::{Duration, Instant}, }; +use parking_lot::Mutex; use qbase::frame::{AckFrame, EcnCounts}; use qrecovery::space::Epoch; @@ -425,7 +426,7 @@ impl ArcCC { impl super::CongestionControl for ArcCC { fn poll_send(&self, cx: &mut Context<'_>) -> Poll { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); let now = Instant::now(); if guard.loss_timer.is_timeout(now) { @@ -459,7 +460,7 @@ impl super::CongestionControl for ArcCC { } fn need_ack(&self, space: Epoch) -> Option<(u64, Instant)> { - let guard = self.0.lock().unwrap(); + let guard = self.0.lock(); if let Some(recved) = &guard.largest_ack_eliciting_packet[space] { return Some((recved.pn, recved.recv_time)); } @@ -475,7 +476,7 @@ impl super::CongestionControl for ArcCC { in_flight: bool, ack: Option, ) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); let now = Instant::now(); guard.on_packet_sent(pn, epoch, is_ack_eliciting, in_flight, sent_bytes, now); @@ -489,7 +490,7 @@ impl super::CongestionControl for ArcCC { } fn on_ack(&self, space: Epoch, ack_frame: &AckFrame) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); let now = Instant::now(); guard.on_ack_rcvd(space, ack_frame, now); } @@ -500,7 +501,7 @@ impl super::CongestionControl for ArcCC { } let now = Instant::now(); let recved = Recved { pn, recv_time: now }; - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if let Some(r) = &guard.largest_ack_eliciting_packet[space] { if pn > r.pn { guard.largest_ack_eliciting_packet[space] = Some(recved); @@ -521,16 +522,16 @@ impl super::CongestionControl for ArcCC { } fn get_pto_time(&self, epoch: Epoch) -> Duration { - self.0.lock().unwrap().get_pto_time(epoch) + self.0.lock().get_pto_time(epoch) } fn on_get_handshake_keys(&self) { - let mut gurad = self.0.lock().unwrap(); + let mut gurad = self.0.lock(); gurad.has_handshake_keys = true; } fn on_handshake_done(&self) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); guard.is_handshake_done = true; guard.rtt.on_handshake_done(); } @@ -542,7 +543,7 @@ impl Future for MayLoss { type Output = (Epoch, Vec); fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut guard = self.0 .0.lock().unwrap(); + let mut guard = self.0 .0.lock(); if let Some(loss) = guard.loss_pns.take() { return Poll::Ready(loss); @@ -558,7 +559,7 @@ impl Future for IndicateAck { type Output = (Epoch, Vec); fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut guard = self.0 .0.lock().unwrap(); + let mut guard = self.0 .0.lock(); if let Some(acked) = guard.newly_ack_pns.take() { return Poll::Ready(acked); } @@ -573,7 +574,7 @@ impl Future for Prober { type Output = Epoch; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut guard = self.0 .0.lock().unwrap(); + let mut guard = self.0 .0.lock(); if let Some(space) = guard.pto_space { return Poll::Ready(space); } diff --git a/qcongestion/src/rtt.rs b/qcongestion/src/rtt.rs index 7e71e577..21cd1982 100644 --- a/qcongestion/src/rtt.rs +++ b/qcongestion/src/rtt.rs @@ -1,8 +1,9 @@ use std::{ - sync::{Arc, Mutex}, + sync::Arc, time::{Duration, Instant}, }; +use parking_lot::Mutex; pub const INITIAL_RTT: Duration = Duration::from_millis(333); const GRANULARITY: Duration = Duration::from_millis(1); const TIME_THRESHOLD: f32 = 1.125; @@ -92,27 +93,27 @@ impl ArcRtt { } pub fn update(&self, latest_rtt: Duration, ack_delay: Duration) { - self.0.lock().unwrap().update(latest_rtt, ack_delay); + self.0.lock().update(latest_rtt, ack_delay); } pub fn loss_delay(&self) -> Duration { - self.0.lock().unwrap().loss_delay() + self.0.lock().loss_delay() } pub fn on_handshake_done(&self) { - self.0.lock().unwrap().on_handshake_done(); + self.0.lock().on_handshake_done(); } pub fn pto_base_duration(&self, times: u32) -> Duration { - self.0.lock().unwrap().pto_base_duration(times) + self.0.lock().pto_base_duration(times) } pub fn smoothed_rtt(&self) -> Duration { - self.0.lock().unwrap().smoothed_rtt + self.0.lock().smoothed_rtt } pub fn rttvar(&self) -> Duration { - self.0.lock().unwrap().rttvar + self.0.lock().rttvar } } diff --git a/qconnection/Cargo.toml b/qconnection/Cargo.toml index e536f5a6..305f5b82 100644 --- a/qconnection/Cargo.toml +++ b/qconnection/Cargo.toml @@ -19,3 +19,4 @@ rustls = { workspace = true } log = { workspace = true } deref-derive = { workspace = true } dashmap = { workspace = true } +parking_lot = { workspace = true } diff --git a/qconnection/src/connection.rs b/qconnection/src/connection.rs index 448a70c8..6761071c 100644 --- a/qconnection/src/connection.rs +++ b/qconnection/src/connection.rs @@ -2,13 +2,14 @@ use std::{ fmt::Debug, mem, ops::DerefMut, - sync::{Arc, Mutex}, + sync::Arc, time::{Duration, Instant}, }; use closing::ClosingConnection; use draining::DrainingConnection; use futures::{channel::mpsc, StreamExt}; +use parking_lot::Mutex; use qbase::{ cid::{self, ConnectionId}, config::Parameters, @@ -121,7 +122,7 @@ impl ArcConnection { /// This function is intended for use by the application layer to signal an /// error and initiate the connection closure. pub fn close_with_error(self, error: Error) { - let guard = self.0.lock().unwrap(); + let guard = self.0.lock(); if let ConnState::Raw(ref raw_conn) = *guard { raw_conn.error.set_app_error(error) } @@ -134,7 +135,7 @@ impl ArcConnection { /// confirmation, any remaining data is drained. If the timeout expires without /// confirmation, the connection is forcefully terminated. fn should_enter_closing_with_error(&self, error: Error) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); let ConnState::Raw(raw_conn) = mem::replace(guard.deref_mut(), ConnState::Closed) else { unreachable!() @@ -198,7 +199,7 @@ impl ArcConnection { /// Enter draining state from raw state or closing state. /// Can only be called internally, and the app should not care this method. pub(crate) fn enter_draining(&self, remaining: Duration) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); let draining_conn = match mem::replace(guard.deref_mut(), ConnState::Closed) { Raw(conn) => DrainingConnection::from(conn), Closing(closing_conn) => DrainingConnection::from(closing_conn), @@ -219,7 +220,7 @@ impl ArcConnection { /// Dismiss the connection, remove it from the global router. /// Can only be called internally, and the app should not care this method. pub(crate) fn die(self) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); let local_cids = match mem::replace(guard.deref_mut(), ConnState::Closed) { Raw(conn) => conn.cid_registry.local, Closing(conn) => conn.cid_registry.local, @@ -233,15 +234,15 @@ impl ArcConnection { } pub fn update_path_recv_time(&self, pathway: Pathway) { - let guard = self.0.lock().unwrap(); + let guard = self.0.lock(); if let ConnState::Raw(ref raw_conn) = *guard { raw_conn.update_path_recv_time(pathway); } } pub fn recv_retry_packet(&self, retry: &RetryHeader, pathway: Pathway, usc: ArcUsc) { - if let Raw(conn) = &mut *self.0.lock().unwrap() { - *conn.token.lock().unwrap() = retry.token.to_vec(); + if let Raw(conn) = &mut *self.0.lock() { + *conn.token.lock() = retry.token.to_vec(); let path = conn.pathes.get(pathway, usc); path.set_dcid(retry.scid); let sent_record = conn.initial.space.sent_packets(); diff --git a/qconnection/src/connection/closing.rs b/qconnection/src/connection/closing.rs index cfadae2e..dee387af 100644 --- a/qconnection/src/connection/closing.rs +++ b/qconnection/src/connection/closing.rs @@ -4,12 +4,13 @@ use std::{ pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Mutex, + Arc, }, task::{Context, Poll, Waker}, time::{Duration, Instant}, }; +use parking_lot::Mutex; use qbase::{ error::Error, frame::ConnectionCloseFrame, @@ -60,7 +61,7 @@ impl ClosingConnection { pub fn recv_packet_via_pathway(&mut self, packet: DataPacket, _pathway: Pathway, _usc: ArcUsc) { self.rcvd_packets.fetch_add(1, Ordering::Release); // TODO: 数值从配置中读取, 还是直接固定值? - let mut last_send_ccf = self.last_send_ccf.lock().unwrap(); + let mut last_send_ccf = self.last_send_ccf.lock(); if self.rcvd_packets.load(Ordering::Relaxed) > 5 || last_send_ccf.elapsed() > Duration::from_millis(100) { @@ -116,7 +117,7 @@ impl RcvdCcf { } pub fn on_ccf_rcvd(&self) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if let RcvdCcfState::Pending(waker) = guard.deref_mut() { waker.wake_by_ref(); } @@ -128,7 +129,7 @@ impl Future for RcvdCcf { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); match guard.deref_mut() { RcvdCcfState::None | RcvdCcfState::Pending(_) => { *guard = RcvdCcfState::Pending(cx.waker().clone()); diff --git a/qconnection/src/connection/raw.rs b/qconnection/src/connection/raw.rs index 657eddf1..2b6f33b5 100644 --- a/qconnection/src/connection/raw.rs +++ b/qconnection/src/connection/raw.rs @@ -1,6 +1,7 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use futures::channel::mpsc; +use parking_lot::Mutex; use qbase::{ cid::ConnectionId, flow::FlowController, diff --git a/qconnection/src/connection/scope/initial.rs b/qconnection/src/connection/scope/initial.rs index 4db94974..3e36137b 100644 --- a/qconnection/src/connection/scope/initial.rs +++ b/qconnection/src/connection/scope/initial.rs @@ -1,6 +1,7 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use futures::{channel::mpsc, StreamExt}; +use parking_lot::Mutex; use qbase::{ frame::{AckFrame, Frame, FrameReader, ReceiveFrame}, packet::{ diff --git a/qconnection/src/connection/transmit/initial.rs b/qconnection/src/connection/transmit/initial.rs index d221c10f..f63dbb5a 100644 --- a/qconnection/src/connection/transmit/initial.rs +++ b/qconnection/src/connection/transmit/initial.rs @@ -1,9 +1,7 @@ -use std::{ - sync::{Arc, Mutex}, - time::Instant, -}; +use std::{sync::Arc, time::Instant}; use bytes::BufMut; +use parking_lot::Mutex; use qbase::{ cid::ConnectionId, packet::{ @@ -39,7 +37,7 @@ impl InitialSpaceReader { let k = self.keys.get_local_keys()?; // 2. 生成包头,预留2字节len,根据包头大小,配合constraints、剩余空间,检查是否能发送,不能的话,直接返回 - let token = self.token.lock().unwrap(); + let token = self.token.lock(); let hdr = LongHeaderBuilder::with_cid(dcid, scid).initial(token.clone()); // length字段预留2字节, 20字节为最小Payload长度,为了保护包头的Sample至少16字节 if buf.len() < hdr.size() + 2 + 20 { diff --git a/qconnection/src/error.rs b/qconnection/src/error.rs index fe57b1c8..db7ded3e 100644 --- a/qconnection/src/error.rs +++ b/qconnection/src/error.rs @@ -2,10 +2,11 @@ use std::{ future::Future, ops::DerefMut, pin::Pin, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, }; +use parking_lot::Mutex; use qbase::{error::Error, frame::ConnectionCloseFrame}; #[derive(Debug, Clone, Default)] @@ -53,7 +54,7 @@ impl ConnError { /// When a connection close frame is received, it will change the state and wake the external if necessary. pub fn on_ccf_rcvd(&self, ccf: &ConnectionCloseFrame) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); // ccf具有最高的优先级 if let ConnErrorState::Pending(waker) = guard.deref_mut() { waker.wake_by_ref(); @@ -62,7 +63,7 @@ impl ConnError { } pub fn on_error(&self, error: Error) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); match guard.deref_mut() { ConnErrorState::None => { *guard = ConnErrorState::Closing(error); @@ -77,7 +78,7 @@ impl ConnError { /// App actively close the connection with an error pub fn set_app_error(&self, error: Error) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); match guard.deref_mut() { ConnErrorState::None => { *guard = ConnErrorState::App(error); @@ -111,7 +112,7 @@ impl Future for ConnError { /// - If the state is `Closing` or `App`, it returns `Poll::Ready(true)`, indicating that the connection is closing or has been closed due to an application error. /// - If the state is `Draining`, it returns `Poll::Ready(false)`, indicating that the connection is draining and will be closed gracefully. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); match guard.deref_mut() { ConnErrorState::None | ConnErrorState::Pending(_) => { *guard = ConnErrorState::Pending(cx.waker().clone()); diff --git a/qconnection/src/path/raw.rs b/qconnection/src/path/raw.rs index 612d741e..b8f9087d 100644 --- a/qconnection/src/path/raw.rs +++ b/qconnection/src/path/raw.rs @@ -166,6 +166,6 @@ impl RawPath { /// Sets the receive time to the current instant. pub fn update_recv_time(&self) { - *self.state.deref().lock().unwrap() = time::Instant::now(); + *self.state.deref().lock() = time::Instant::now(); } } diff --git a/qconnection/src/path/state.rs b/qconnection/src/path/state.rs index 98e57996..315ed2b8 100644 --- a/qconnection/src/path/state.rs +++ b/qconnection/src/path/state.rs @@ -1,12 +1,13 @@ use std::{ future::Future, pin::Pin, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, time, }; use deref_derive::Deref; +use parking_lot::Mutex; use qbase::cid::ArcCidCell; use qrecovery::reliable::ArcReliableFrameDeque; @@ -42,7 +43,7 @@ impl ArcPathState { async move { loop { let now = time::Instant::now(); - let recv_time = *state.lock().unwrap(); + let recv_time = *state.lock(); // TODO: 失活时间暂定30s if now.duration_since(recv_time) >= time::Duration::from_secs(30) { state.to_inactive(cid); @@ -80,7 +81,7 @@ impl ArcPathState { pub fn to_inactive(&self, cid: ArcCidCell) { ArcCidCell::retire(&cid); - let mut guard = self.state.lock().unwrap(); + let mut guard = self.state.lock(); if let PathState::Pending(ref mut wakers) = *guard { let wakers = std::mem::take(wakers); *guard = PathState::InActive; @@ -95,7 +96,7 @@ impl Future for ArcPathState { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut guard = self.state.lock().unwrap(); + let mut guard = self.state.lock(); match *guard { PathState::Active => *guard = PathState::Pending(vec![cx.waker().clone()]), diff --git a/qconnection/src/path/util.rs b/qconnection/src/path/util.rs index 8b3771cb..d8385233 100644 --- a/qconnection/src/path/util.rs +++ b/qconnection/src/path/util.rs @@ -2,11 +2,12 @@ use std::{ future::Future, ops::{Deref, DerefMut}, pin::Pin, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, }; use bytes::BufMut; +use parking_lot::Mutex; use qbase::frame::{io::WriteFrame, BeFrame}; #[derive(Default, Clone)] @@ -14,7 +15,7 @@ pub struct SendBuffer(Arc>>); impl SendBuffer { pub fn write(&self, frame: T) { - *self.0.lock().unwrap() = Some(frame); + *self.0.lock() = Some(frame); } } @@ -24,7 +25,7 @@ where for<'a> &'a mut [u8]: WriteFrame, { pub fn try_read(&self, mut buf: &mut [u8]) -> usize { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if let Some(frame) = guard.deref() { let size = frame.encoding_size(); if buf.remaining_mut() >= size { @@ -55,7 +56,7 @@ impl RecvBuffer { } pub fn write(&self, value: T) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); match guard.deref() { RecvState::None => *guard = RecvState::Rcvd(value), RecvState::Pending(waker) => { @@ -92,7 +93,7 @@ impl RecvBuffer { } pub fn dismiss(&self) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if let RecvState::Pending(waker) = guard.deref() { waker.wake_by_ref(); } @@ -104,7 +105,7 @@ impl Future for RecvBuffer { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); match std::mem::take(guard.deref_mut()) { RecvState::None | RecvState::Pending(_) => { *guard = RecvState::Pending(cx.waker().clone()); diff --git a/qconnection/src/tls.rs b/qconnection/src/tls.rs index c44b8f19..937c2e74 100644 --- a/qconnection/src/tls.rs +++ b/qconnection/src/tls.rs @@ -2,10 +2,11 @@ use std::{ future::Future, ops::DerefMut, pin::Pin, - sync::{Arc, Mutex, MutexGuard}, + sync::Arc, task::{Context, Poll, Waker}, }; +use parking_lot::{Mutex, MutexGuard}; use qbase::{ cid::ConnectionId, config::{ @@ -154,7 +155,7 @@ impl ArcTlsSession { } fn lock_guard(&self) -> MutexGuard<'_, RawTlsSession> { - self.0.lock().unwrap() + self.0.lock() } pub fn write_tls_msg(&self, plaintext: &[u8]) -> Result<(), rustls::Error> { @@ -338,7 +339,7 @@ pub struct GetParameters(Arc>); impl GetParameters { fn set_parameters(&self, parameters: Parameters) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); let RawGetParameters::Pending(waker) = guard.deref_mut() else { return; }; @@ -347,7 +348,7 @@ impl GetParameters { } fn on_handshake_done(&self) { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); let RawGetParameters::Pending(waker) = guard.deref_mut() else { return; }; @@ -356,7 +357,7 @@ impl GetParameters { } pub fn poll_get_parameters(&self, cx: &mut Context) -> Poll> { - self.0.lock().unwrap().poll_get_parameters(cx) + self.0.lock().poll_get_parameters(cx) } } @@ -364,6 +365,6 @@ impl Future for GetParameters { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.0.lock().unwrap().poll_get_parameters(cx) + self.0.lock().poll_get_parameters(cx) } } diff --git a/qrecovery/Cargo.toml b/qrecovery/Cargo.toml index d39f7605..f0b966f2 100644 --- a/qrecovery/Cargo.toml +++ b/qrecovery/Cargo.toml @@ -12,8 +12,8 @@ qbase = { workspace = true } rustls = { workspace = true } bytes = { workspace = true } thiserror = { workspace = true } -async-lock = { workspace = true } deref-derive = { workspace = true } rand = { workspace = true } log = { workspace = true } enum_dispatch = { workspace = true } +parking_lot = { workspace = true } diff --git a/qrecovery/src/recv.rs b/qrecovery/src/recv.rs index b7bea8a5..d8af837a 100644 --- a/qrecovery/src/recv.rs +++ b/qrecovery/src/recv.rs @@ -4,9 +4,10 @@ mod recver; pub mod rcvbuf; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; pub use incoming::{Incoming, IsStopped, UpdateWindow}; +use parking_lot::Mutex; pub use reader::Reader; use recver::Recver; diff --git a/qrecovery/src/recv/incoming.rs b/qrecovery/src/recv/incoming.rs index b9efe2a4..605bbd30 100644 --- a/qrecovery/src/recv/incoming.rs +++ b/qrecovery/src/recv/incoming.rs @@ -23,11 +23,11 @@ impl Incoming { } pub fn recv_data(&self, stream_frame: &StreamFrame, body: Bytes) -> Result { - let mut recver = self.0.lock().unwrap(); + let mut recver = self.0.lock(); let inner = recver.deref_mut(); let mut new_data_size = 0; - match inner { - Ok(receiving_state) => match receiving_state { + if let Ok(receiving_state) = inner { + match receiving_state { Recver::Recv(r) => { new_data_size = r.recv(stream_frame, body)?; } @@ -40,34 +40,32 @@ impl Incoming { _ => { log::debug!("ignored stream frame {:?}", stream_frame); } - }, - Err(_) => (), + } } Ok(new_data_size) } pub fn end(&self, final_size: u64) { - let mut recver = self.0.lock().unwrap(); + let mut recver = self.0.lock(); let inner = recver.deref_mut(); - match inner { - Ok(receiving_state) => match receiving_state { + if let Ok(receiving_state) = inner { + match receiving_state { Recver::Recv(r) => { *receiving_state = Recver::SizeKnown(r.determin_size(final_size)); } _ => { log::debug!("there is sth wrong, ignored finish"); } - }, - Err(_) => (), + } } } pub fn recv_reset(&self, reset_frame: &ResetStreamFrame) -> Result<(), QuicError> { // TODO: ResetStream中还有错误信息,比如http3的错误码,看是否能用到 - let mut recver = self.0.lock().unwrap(); + let mut recver = self.0.lock(); let inner = recver.deref_mut(); - match inner { - Ok(receiving_state) => match receiving_state { + if let Ok(receiving_state) = inner { + match receiving_state { Recver::Recv(r) => { let final_size = r.recv_reset(reset_frame)?; *receiving_state = Recver::ResetRcvd(final_size); @@ -80,14 +78,13 @@ impl Incoming { log::error!("there is sth wrong, ignored recv_reset"); unreachable!(); } - }, - Err(_) => (), + } } Ok(()) } pub fn on_conn_error(&self, err: &QuicError) { - let mut recver = self.0.lock().unwrap(); + let mut recver = self.0.lock(); let inner = recver.deref_mut(); match inner { Ok(receiving_state) => match receiving_state { @@ -117,7 +114,7 @@ impl Future for UpdateWindow { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut recver = self.0.lock().unwrap(); + let mut recver = self.0.lock(); let inner = recver.deref_mut(); match inner { Ok(receiving_state) => match receiving_state { @@ -141,7 +138,7 @@ impl Future for IsStopped { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut recver = self.0.lock().unwrap(); + let mut recver = self.0.lock(); let inner = recver.deref_mut(); match inner { Ok(receiving_state) => match receiving_state { diff --git a/qrecovery/src/recv/reader.rs b/qrecovery/src/recv/reader.rs index 8e36acd3..afdafbb1 100644 --- a/qrecovery/src/recv/reader.rs +++ b/qrecovery/src/recv/reader.rs @@ -22,10 +22,10 @@ impl Reader { /// It meaning sending a STOP_SENDING frame to peer. pub fn stop(self, error_code: u64) { debug_assert!(error_code <= VARINT_MAX); - let mut recver = self.0.lock().unwrap(); + let mut recver = self.0.lock(); let inner = recver.deref_mut(); - match inner { - Ok(receiving_state) => match receiving_state { + if let Ok(receiving_state) = inner { + match receiving_state { Recver::Recv(r) => { r.stop(error_code); } @@ -33,8 +33,7 @@ impl Reader { r.stop(error_code); } _ => (), - }, - Err(_) => (), + } } } } @@ -45,7 +44,7 @@ impl AsyncRead for Reader { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let mut recver = self.0.lock().unwrap(); + let mut recver = self.0.lock(); let inner = recver.deref_mut(); // 能相当清楚地看到应用层读取数据驱动的接收状态演变 match inner { @@ -79,29 +78,26 @@ impl AsyncRead for Reader { impl Drop for Reader { fn drop(&mut self) { - let mut recver = self.0.lock().unwrap(); + let mut recver = self.0.lock(); let inner = recver.deref_mut(); - match inner { - // strict mode: don't forget to call stop with the error code when an - // abnormal termination occurs, or it will panic. - Ok(receiving_state) => match receiving_state { + if let Ok(receiving_state) = inner { + match receiving_state { Recver::Recv(r) => { assert!( r.is_stopped(), r#"RecvStream in Recv State must be - stopped with error code before dropped!"# + stopped with error code before dropped!"# ) } Recver::SizeKnown(r) => { assert!( r.is_stopped(), r#"RecvStream in Recv State must be - stopped with error code before dropped!"# + stopped with error code before dropped!"# ) } _ => (), - }, - Err(_) => (), + } } } } diff --git a/qrecovery/src/recv/recver.rs b/qrecovery/src/recv/recver.rs index fbaeab71..8491d5ea 100644 --- a/qrecovery/src/recv/recver.rs +++ b/qrecovery/src/recv/recver.rs @@ -1,10 +1,11 @@ use std::{ io, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, }; use bytes::{BufMut, Bytes}; +use parking_lot::Mutex; use qbase::{ error::{Error, ErrorKind}, frame::{BeFrame, ResetStreamFrame, StreamFrame}, diff --git a/qrecovery/src/reliable.rs b/qrecovery/src/reliable.rs index f10f9034..f0252cdc 100644 --- a/qrecovery/src/reliable.rs +++ b/qrecovery/src/reliable.rs @@ -1,10 +1,8 @@ -use std::{ - collections::VecDeque, - sync::{Arc, Mutex, MutexGuard}, -}; +use std::{collections::VecDeque, sync::Arc}; use deref_derive::{Deref, DerefMut}; use enum_dispatch::enum_dispatch; +use parking_lot::{Mutex, MutexGuard}; use qbase::frame::{io::WriteFrame, BeFrame, CryptoFrame, ReliableFrame, StreamFrame}; pub mod rcvdpkt; @@ -73,7 +71,7 @@ impl ArcReliableFrameDeque { } pub fn lock_guard(&self) -> MutexGuard<'_, RawReliableFrameDeque> { - self.0.lock().unwrap() + self.0.lock() } pub fn try_read(&self, buf: &mut [u8]) -> Option<(ReliableFrame, usize)> { diff --git a/qrecovery/src/reliable/sentpkt.rs b/qrecovery/src/reliable/sentpkt.rs index e4222a34..69cca3c7 100644 --- a/qrecovery/src/reliable/sentpkt.rs +++ b/qrecovery/src/reliable/sentpkt.rs @@ -1,10 +1,7 @@ -use std::{ - collections::VecDeque, - ops::DerefMut, - sync::{Arc, Mutex, MutexGuard}, -}; +use std::{collections::VecDeque, ops::DerefMut, sync::Arc}; use deref_derive::{Deref, DerefMut}; +use parking_lot::{Mutex, MutexGuard}; use qbase::{packet::PacketNumber, util::IndexDeque, varint::VARINT_MAX}; /// 记录发送的数据包的状态,包括 @@ -138,12 +135,12 @@ impl ArcSentPktRecords { pub fn receive(&self) -> RecvGuard<'_, T> { RecvGuard { - inner: self.0.lock().unwrap(), + inner: self.0.lock(), } } pub fn send(&self) -> SendGuard<'_, T> { - let inner = self.0.lock().unwrap(); + let inner = self.0.lock(); let origin_len = inner.queue.len(); SendGuard { necessary: false, diff --git a/qrecovery/src/send.rs b/qrecovery/src/send.rs index 0913a5b0..a46897bd 100644 --- a/qrecovery/src/send.rs +++ b/qrecovery/src/send.rs @@ -1,4 +1,6 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; + +use parking_lot::Mutex; pub mod sndbuf; diff --git a/qrecovery/src/send/outgoing.rs b/qrecovery/src/send/outgoing.rs index 99e361f8..7fdb1da6 100644 --- a/qrecovery/src/send/outgoing.rs +++ b/qrecovery/src/send/outgoing.rs @@ -24,15 +24,10 @@ pub struct Outgoing(pub(super) ArcSender); impl Outgoing { pub fn update_window(&self, max_data_size: u64) { assert!(max_data_size <= VARINT_MAX); - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); - match inner { - Ok(sending_state) => { - if let Sender::Sending(s) = sending_state { - s.update_window(max_data_size); - } - } - Err(_) => (), + if let Ok(Sender::Sending(s)) = inner { + s.update_window(max_data_size); } } @@ -66,7 +61,7 @@ impl Outgoing { let predicate = |offset| { StreamFrame::estimate_max_capacity(capacity, sid, offset).map(|c| tokens.min(c)) }; - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); match inner { @@ -94,10 +89,10 @@ impl Outgoing { /// return true if all data has been rcvd pub fn on_data_acked(&self, range: &Range) -> bool { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); - match inner { - Ok(sending_state) => match sending_state { + if let Ok(sending_state) = inner { + match sending_state { Sender::Ready(_) => { unreachable!("never send data before recv data"); } @@ -113,17 +108,16 @@ impl Outgoing { } // ignore recv _ => {} - }, - Err(_) => (), + } }; false } pub fn may_loss_data(&self, range: &Range) { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); - match inner { - Ok(sending_state) => match sending_state { + if let Ok(sending_state) = inner { + match sending_state { Sender::Ready(_) => { unreachable!("never send data before recv data"); } @@ -135,14 +129,13 @@ impl Outgoing { } // ignore loss _ => (), - }, - Err(_) => (), + } }; } /// 被动stop,返回true说明成功stop了;返回false则表明流没有必要stop,要么已经完成,要么已经reset pub fn stop(&self) -> bool { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); match inner { Ok(sending_state) => match sending_state { @@ -164,27 +157,26 @@ impl Outgoing { } pub fn on_reset_acked(&self) { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); - match inner { - Ok(sending_state) => match sending_state { + if let Ok(sending_state) = inner { + match sending_state { Sender::ResetSent(_) | Sender::ResetRcvd => { *sending_state = Sender::ResetRcvd; } _ => { unreachable!( - "If no RESET_STREAM has been sent, how can there be a received acknowledgment?" - ); + "If no RESET_STREAM has been sent, how can there be a received acknowledgment?" + ); } - }, - Err(_) => (), + } } } /// When a connection-level error occurs, all data streams must be notified. /// Their reading and writing should be terminated, accompanied the error of the connection. pub fn on_conn_error(&self, err: &QuicError) { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); match inner { Ok(sending_state) => match sending_state { @@ -210,7 +202,7 @@ impl Future for IsCancelled { type Output = Option<(u64, u64)>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); match inner { Ok(sending_state) => match sending_state { diff --git a/qrecovery/src/send/sender.rs b/qrecovery/src/send/sender.rs index 727b905c..a415fa72 100644 --- a/qrecovery/src/send/sender.rs +++ b/qrecovery/src/send/sender.rs @@ -1,10 +1,11 @@ use std::{ io, ops::Range, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, }; +use parking_lot::Mutex; use qbase::util::DescribeData; use super::sndbuf::SendBuf; diff --git a/qrecovery/src/send/writer.rs b/qrecovery/src/send/writer.rs index d8aa09fd..9daa2ae2 100644 --- a/qrecovery/src/send/writer.rs +++ b/qrecovery/src/send/writer.rs @@ -19,7 +19,7 @@ impl AsyncWrite for Writer { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); match inner { Ok(sending_state) => match sending_state { @@ -47,7 +47,7 @@ impl AsyncWrite for Writer { } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); match inner { Ok(sending_state) => match sending_state { @@ -75,7 +75,7 @@ impl AsyncWrite for Writer { } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); match inner { Ok(sending_state) => match sending_state { @@ -121,10 +121,10 @@ impl AsyncWrite for Writer { impl Writer { pub fn cancel(self, err_code: u64) { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); - match inner { - Ok(sending_state) => match sending_state { + if let Ok(sending_state) = inner { + match sending_state { Sender::Ready(s) => { s.cancel(err_code); } @@ -135,44 +135,40 @@ impl Writer { s.cancel(err_code); } _ => (), - }, - Err(_) => (), + } }; } } impl Drop for Writer { fn drop(&mut self) { - let mut sender = self.0.lock().unwrap(); + let mut sender = self.0.lock(); let inner = sender.deref_mut(); - match inner { - // strict mode: don't forget to call cancel with the error code when an - // abnormal termination occurs, or it will panic. - Ok(sending_state) => match sending_state { + if let Ok(sending_state) = inner { + match sending_state { Sender::Ready(s) => { assert!( s.is_cancelled(), "SendingStream in Ready State must be - cancelled with error code before dropped!" + cancelled with error code before dropped!" ); } Sender::Sending(s) => { assert!( s.is_cancelled(), "SendingStream in Sending State must be - cancelled with error code before dropped!" + cancelled with error code before dropped!" ); } Sender::DataSent(s) => { assert!( s.is_cancelled(), "SendingStream in DataSent State must be - cancelled with error code before dropped!" + cancelled with error code before dropped!" ); } _ => (), - }, - Err(_) => (), + } }; } } diff --git a/qrecovery/src/streams/crypto.rs b/qrecovery/src/streams/crypto.rs index 33a9ba93..474c06df 100644 --- a/qrecovery/src/streams/crypto.rs +++ b/qrecovery/src/streams/crypto.rs @@ -2,11 +2,12 @@ mod send { use std::{ io, pin::Pin, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, }; use bytes::BufMut; + use parking_lot::Mutex; use qbase::{ frame::{io::WriteDataFrame, CryptoFrame}, util::DescribeData, @@ -99,11 +100,11 @@ mod send { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.0.lock().unwrap().poll_write(cx, buf) + self.0.lock().poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.0.lock().unwrap().poll_flush(cx) + self.0.lock().poll_flush(cx) } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -114,15 +115,15 @@ mod send { impl CryptoStreamOutgoing { pub fn try_read_data(&self, buffer: &mut [u8]) -> Option<(CryptoFrame, usize)> { - self.0.lock().unwrap().try_read_data(buffer) + self.0.lock().try_read_data(buffer) } pub fn on_data_acked(&self, crypto_frame: &CryptoFrame) { - self.0.lock().unwrap().on_data_acked(crypto_frame) + self.0.lock().on_data_acked(crypto_frame) } pub fn may_loss_data(&self, crypto_frame: &CryptoFrame) { - self.0.lock().unwrap().may_loss_data(crypto_frame) + self.0.lock().may_loss_data(crypto_frame) } } @@ -139,11 +140,12 @@ mod recv { use std::{ io, pin::Pin, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, }; use bytes::{BufMut, Bytes}; + use parking_lot::Mutex; use qbase::{ error::Error, frame::{CryptoFrame, ReceiveFrame}, @@ -199,7 +201,7 @@ mod recv { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - self.0.lock().unwrap().poll_read(cx, buf) + self.0.lock().poll_read(cx, buf) } } @@ -210,10 +212,7 @@ mod recv { &mut self, (frame, data): &(CryptoFrame, Bytes), ) -> Result { - self.0 - .lock() - .unwrap() - .recv(frame.offset.into(), data.clone()); + self.0.lock().recv(frame.offset.into(), data.clone()); Ok(()) } } diff --git a/qrecovery/src/streams/data.rs b/qrecovery/src/streams/data.rs index df6df19f..925e33a2 100644 --- a/qrecovery/src/streams/data.rs +++ b/qrecovery/src/streams/data.rs @@ -1,10 +1,11 @@ use std::{ collections::{BTreeMap, HashMap}, - sync::{Arc, Mutex, MutexGuard}, + sync::Arc, task::{ready, Context, Poll}, }; use deref_derive::{Deref, DerefMut}; +use parking_lot::{Mutex, MutexGuard}; use qbase::{ config::Parameters, error::{Error as QuicError, ErrorKind}, @@ -44,7 +45,7 @@ impl Default for ArcOutput { impl ArcOutput { fn guard(&self) -> Result { - let guard = self.0.lock().unwrap(); + let guard = self.0.lock(); match guard.as_ref() { Ok(_) => Ok(ArcOutputGuard { inner: guard }), Err(e) => Err(e.clone()), @@ -88,7 +89,7 @@ impl Default for ArcInput { impl ArcInput { fn guard(&self) -> Result { - let guard = self.0.lock().unwrap(); + let guard = self.0.lock(); match guard.as_ref() { Ok(_) => Ok(ArcInputGuard { inner: guard }), Err(e) => Err(e.clone()), @@ -143,7 +144,7 @@ impl ArcDataStreamParameters { } fn guard(&self) -> MutexGuard<'_, RawDataStreamParameters> { - self.0.lock().unwrap() + self.0.lock() } fn apply_transport_parameters(&self, params: &Parameters) { @@ -187,7 +188,7 @@ impl RawDataStreams { buf: &mut [u8], flow_limit: usize, ) -> Option<(StreamFrame, usize, usize)> { - let guard = &mut self.output.0.lock().unwrap(); + let guard = &mut self.output.0.lock(); let output = guard.as_mut().ok()?; const DEFAULT_TOKENS: usize = 4096; @@ -225,7 +226,7 @@ impl RawDataStreams { } pub fn on_data_acked(&self, frame: StreamFrame) { - if let Ok(set) = self.output.0.lock().unwrap().as_mut() { + if let Ok(set) = self.output.0.lock().as_mut() { if set .get(&frame.id) .map(|o| o.on_data_acked(&frame.range())) @@ -241,7 +242,6 @@ impl RawDataStreams { .output .0 .lock() - .unwrap() .as_mut() .ok() .and_then(|set| set.get(&stream_frame.id)) @@ -251,7 +251,7 @@ impl RawDataStreams { } pub fn on_reset_acked(&self, reset_frame: ResetStreamFrame) { - if let Ok(set) = self.output.0.lock().unwrap().as_mut() { + if let Ok(set) = self.output.0.lock().as_mut() { if let Some(o) = set.remove(&reset_frame.stream_id) { o.on_reset_acked(); } @@ -283,7 +283,6 @@ impl RawDataStreams { .input .0 .lock() - .unwrap() .as_mut() .ok() .and_then(|set| set.get(&sid)) @@ -314,7 +313,7 @@ impl RawDataStreams { )); } } - if let Ok(set) = self.input.0.lock().unwrap().as_mut() { + if let Ok(set) = self.input.0.lock().as_mut() { if let Some(incoming) = set.remove(&sid) { incoming.recv_reset(reset)?; } @@ -339,7 +338,6 @@ impl RawDataStreams { .output .0 .lock() - .unwrap() .as_mut() .ok() .and_then(|set| set.get(&sid)) @@ -376,7 +374,6 @@ impl RawDataStreams { .output .0 .lock() - .unwrap() .as_ref() .ok() .and_then(|set| set.get(&sid)) diff --git a/qrecovery/src/streams/listener.rs b/qrecovery/src/streams/listener.rs index 04cae154..ca5b2338 100644 --- a/qrecovery/src/streams/listener.rs +++ b/qrecovery/src/streams/listener.rs @@ -2,10 +2,11 @@ use std::{ collections::VecDeque, future::Future, pin::Pin, - sync::{Arc, Mutex, MutexGuard}, + sync::Arc, task::{Context, Poll, Waker}, }; +use parking_lot::{Mutex, MutexGuard}; use qbase::error::Error as QuicError; use crate::{recv::Reader, send::Writer}; @@ -67,7 +68,7 @@ impl Default for ArcListener { impl ArcListener { pub(crate) fn guard(&self) -> Result { - let guard = self.0.lock().unwrap(); + let guard = self.0.lock(); match guard.as_ref() { Ok(_) => Ok(ListenerGuard { inner: guard }), Err(e) => Err(e.clone()), @@ -90,14 +91,14 @@ impl ArcListener { &self, cx: &mut Context<'_>, ) -> Poll> { - match self.0.lock().unwrap().as_mut() { + match self.0.lock().as_mut() { Ok(set) => set.poll_accept_bi_stream(cx), Err(e) => Poll::Ready(Err(e.clone())), } } pub fn poll_accept_recv_stream(&self, cx: &mut Context<'_>) -> Poll> { - match self.0.lock().unwrap().as_mut() { + match self.0.lock().as_mut() { Ok(set) => set.poll_accept_recv_stream(cx), Err(e) => Poll::Ready(Err(e.clone())), } diff --git a/qudp/Cargo.toml b/qudp/Cargo.toml index 2eb3e0d8..475c40d2 100644 --- a/qudp/Cargo.toml +++ b/qudp/Cargo.toml @@ -11,6 +11,8 @@ socket2 = { workspace = true } libc = { workspace = true } tokio = { workspace = true } log = { workspace = true } +parking_lot = { workspace = true } + [dev-dependencies] env_logger = "0" diff --git a/qudp/src/lib.rs b/qudp/src/lib.rs index 75a0ee96..276a66cc 100644 --- a/qudp/src/lib.rs +++ b/qudp/src/lib.rs @@ -3,11 +3,12 @@ use std::{ future::Future, io::{self, IoSlice, IoSliceMut}, net::SocketAddr, - sync::{Arc, Mutex}, + sync::Arc, task::{ready, Context, Poll}, }; use msg::Encoder; +use parking_lot::Mutex; use socket2::{Domain, Socket, Type}; use tokio::io::Interest; use unix::DEFAULT_TTL; @@ -146,7 +147,7 @@ impl ArcUsc { hdr: &PacketHeader, cx: &mut Context, ) -> Poll> { - let controller = self.0.lock().unwrap(); + let controller = self.0.lock(); ready!(controller.io.poll_send_ready(cx))?; let ret = controller .io @@ -161,7 +162,7 @@ impl ArcUsc { hdrs: &mut [PacketHeader], cx: &mut Context, ) -> Poll> { - let controller = self.0.lock().unwrap(); + let controller = self.0.lock(); ready!(controller.io.poll_recv_ready(cx))?; let ret = controller .io @@ -176,20 +177,20 @@ impl ArcUsc { } pub fn ttl(&self) -> u8 { - self.0.lock().unwrap().ttl + self.0.lock().ttl } pub fn set_ttl(&self, ttl: u8) -> io::Result<()> { - self.0.lock().unwrap().set_ttl(ttl) + self.0.lock().set_ttl(ttl) } pub fn local_addr(&self) -> SocketAddr { - self.0.lock().unwrap().local_addr() + self.0.lock().local_addr() } // Send synchronously, usc saves a small amount of data packets,and USC sends internal asynchronous tasks pub fn sync_send(&self, packet: Vec, hdr: PacketHeader) -> io::Result<()> { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); if guard.bufs.len() >= BUFFER_CAPACITY { return Err(io::Error::new(io::ErrorKind::WouldBlock, "buffer full")); } @@ -237,7 +238,7 @@ impl<'a> Future for Sender<'a> { type Output = io::Result; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let usc = self.usc.0.lock().unwrap(); + let usc = self.usc.0.lock(); ready!(usc.io.poll_send_ready(cx))?; let ret = usc .io @@ -253,7 +254,7 @@ impl Future for SyncGuard { type Output = io::Result; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut usc = self.0 .0.lock().unwrap(); + let mut usc = self.0 .0.lock(); if let Some((pkt, hdr)) = usc.bufs.pop_front() { ready!(usc.io.poll_send_ready(cx))?; let ret = usc.io.try_io(Interest::WRITABLE, || { diff --git a/quic/Cargo.toml b/quic/Cargo.toml index db065a1d..41ebf7ab 100644 --- a/quic/Cargo.toml +++ b/quic/Cargo.toml @@ -21,6 +21,7 @@ rustls-pemfile = { workspace = true } log = { workspace = true } deref-derive = { workspace = true } dashmap = { workspace = true } +parking_lot = { workspace = true } [dev-dependencies] env_logger = { workspace = true } diff --git a/quic/src/lib.rs b/quic/src/lib.rs index d85a4ca2..d3d12ac7 100644 --- a/quic/src/lib.rs +++ b/quic/src/lib.rs @@ -2,12 +2,13 @@ use std::{ future::Future, io, net::SocketAddr, - sync::{Arc, LazyLock, Mutex}, + sync::{Arc, LazyLock}, task::{Poll, Waker}, }; use bytes::BytesMut; use dashmap::DashMap; +use parking_lot::Mutex; use qbase::{ cid::{ConnectionId, MAX_CID_SIZE}, packet::{ @@ -196,8 +197,8 @@ impl Acceptor { } fn accept(&self, value: (QuicConnection, SocketAddr)) { - *self.0.lock().unwrap() = Some(value); - if let Some(waker) = self.1.lock().unwrap().take() { + *self.0.lock() = Some(value); + if let Some(waker) = self.1.lock().take() { waker.wake(); } } @@ -208,7 +209,7 @@ impl Future for Acceptor { fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let this = self.get_mut(); - if let Some(value) = this.0.lock().unwrap().take() { + if let Some(value) = this.0.lock().take() { Poll::Ready(value) } else { this.1 = Arc::new(Mutex::new(Some(cx.waker().clone()))); diff --git a/qunreliable/Cargo.toml b/qunreliable/Cargo.toml index 5a8ad5fc..d40b96b6 100644 --- a/qunreliable/Cargo.toml +++ b/qunreliable/Cargo.toml @@ -8,3 +8,4 @@ tokio = { workspace = true } bytes = { workspace = true } qbase = { workspace = true } futures = { workspace = true } +parking_lot = { workspace = true } diff --git a/qunreliable/src/flow.rs b/qunreliable/src/flow.rs index 5933e0ca..f9d0f3cf 100644 --- a/qunreliable/src/flow.rs +++ b/qunreliable/src/flow.rs @@ -1,5 +1,6 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use parking_lot::Mutex; use qbase::{ config::Parameters, error::Error, diff --git a/qunreliable/src/reader.rs b/qunreliable/src/reader.rs index 706e5ee3..06bbadf3 100644 --- a/qunreliable/src/reader.rs +++ b/qunreliable/src/reader.rs @@ -4,11 +4,12 @@ use std::{ io, ops::DerefMut, pin::Pin, - sync::{Arc, Mutex}, + sync::Arc, task::{Context, Poll, Waker}, }; use bytes::{BufMut, Bytes}; +use parking_lot::Mutex; use qbase::{ error::{Error, ErrorKind}, frame::{BeFrame, DatagramFrame}, @@ -59,7 +60,7 @@ impl DatagramReader { frame: &DatagramFrame, data: bytes::Bytes, ) -> Result<(), Error> { - let reader = &mut self.0.lock().unwrap(); + let reader = &mut self.0.lock(); let inner = reader.deref_mut(); let Ok(reader) = inner else { return Ok(()); @@ -96,7 +97,7 @@ impl DatagramReader { /// /// if the connection is already closed, the new error will be ignored. pub(super) fn on_conn_error(&self, error: &Error) { - let reader = &mut self.0.lock().unwrap(); + let reader = &mut self.0.lock(); let inner = reader.deref_mut(); if let Ok(reader) = inner { reader.wakers.drain(..).for_each(|waker| waker.wake()); @@ -158,7 +159,7 @@ impl DatagramReader { /// /// Return [`Err`] when the connection is closing or already closed pub fn get_local_max_datagram_frame_size(&self) -> io::Result { - let reader = self.0.lock().unwrap(); + let reader = self.0.lock(); match &*reader { Ok(reader) => Ok(reader.local_max_size), Err(error) => Err(io::Error::new(io::ErrorKind::BrokenPipe, error.to_string())), @@ -178,7 +179,7 @@ impl Future for ReadIntoSlice<'_> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let s = self.get_mut(); - let mut reader = s.reader.lock().unwrap(); + let mut reader = s.reader.lock(); match reader.deref_mut() { Ok(reader) => match reader.queue.pop_front() { Some(bytes) => { @@ -210,7 +211,7 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let s = self.get_mut(); - let mut reader = s.reader.lock().unwrap(); + let mut reader = s.reader.lock(); match reader.deref_mut() { Ok(reader) => match reader.queue.pop_front() { Some(bytes) => { diff --git a/qunreliable/src/writer.rs b/qunreliable/src/writer.rs index 5520d2be..a52c53bf 100644 --- a/qunreliable/src/writer.rs +++ b/qunreliable/src/writer.rs @@ -1,11 +1,7 @@ -use std::{ - collections::VecDeque, - io, - ops::DerefMut, - sync::{Arc, Mutex}, -}; +use std::{collections::VecDeque, io, ops::DerefMut, sync::Arc}; use bytes::Bytes; +use parking_lot::Mutex; use qbase::{ error::{Error, ErrorKind}, frame::{io::WriteDataFrame, BeFrame, DatagramFrame, FrameType}, @@ -47,7 +43,7 @@ impl DatagramWriter { /// contain the datagram. /// pub(super) fn try_read_datagram(&self, mut buf: &mut [u8]) -> Option<(DatagramFrame, usize)> { - let mut guard = self.0.lock().unwrap(); + let mut guard = self.0.lock(); let writer = guard.as_mut().ok()?; let datagram = writer.queue.front()?; @@ -91,7 +87,7 @@ impl DatagramWriter { /// /// if the connection is already closed, the new error will be ignored. pub(super) fn on_conn_error(&self, error: &Error) { - let writer = &mut self.0.lock().unwrap(); + let writer = &mut self.0.lock(); if writer.is_ok() { **writer = Err(io::Error::new(io::ErrorKind::BrokenPipe, error.to_string())); } @@ -111,7 +107,7 @@ impl DatagramWriter { /// /// Return [`Err`] when the connection is closing or already closed pub fn send_bytes(&self, data: Bytes) -> io::Result<()> { - match self.0.lock().unwrap().deref_mut() { + match self.0.lock().deref_mut() { Ok(writer) => { // 这里只考虑最小的编码方式:也就是1字节 if (1 + data.len()) > writer.remote_max_size { @@ -157,7 +153,7 @@ impl DatagramWriter { /// The value may have been set by a previous connection. This method will return [`Err`] when the new size /// is less than the previous size, or the current connection is closing or already closed. pub(crate) fn update_remote_max_datagram_frame_size(&self, size: usize) -> Result<(), Error> { - let mut writer = self.0.lock().unwrap(); + let mut writer = self.0.lock(); let inner = writer.deref_mut(); if let Ok(writer) = inner { @@ -179,7 +175,7 @@ impl DatagramWriter { /// /// Return [`Err`] when the connection is closing or already closed pub fn get_remote_max_datagram_frame_size(&self) -> io::Result { - let reader = self.0.lock().unwrap(); + let reader = self.0.lock(); match &*reader { Ok(reader) => Ok(reader.remote_max_size), Err(error) => Err(io::Error::new(io::ErrorKind::BrokenPipe, error.to_string())), @@ -295,7 +291,7 @@ mod tests { let writer = DatagramWriter(writer); writer.update_remote_max_datagram_frame_size(2048).unwrap(); - let writer_guard = writer.0.lock().unwrap(); + let writer_guard = writer.0.lock(); let writer = writer_guard.as_ref().unwrap(); assert_eq!(writer.remote_max_size, 2048); } @@ -319,7 +315,7 @@ mod tests { FrameType::Datagram(0), "test", )); - let writer_guard = writer.0.lock().unwrap(); + let writer_guard = writer.0.lock(); assert!(writer_guard.as_ref().is_err()); } }