6
6
7
7
use crate :: io:: { AsyncRead , AsyncWrite , ReadBuf } ;
8
8
9
- use std:: cell:: UnsafeCell ;
10
9
use std:: fmt;
11
10
use std:: io;
12
11
use std:: pin:: Pin ;
13
- use std:: sync:: atomic:: AtomicBool ;
14
- use std:: sync:: atomic:: Ordering :: { Acquire , Release } ;
15
12
use std:: sync:: Arc ;
13
+ use std:: sync:: Mutex ;
16
14
use std:: task:: { Context , Poll } ;
17
15
18
16
cfg_io_util ! {
@@ -38,8 +36,7 @@ cfg_io_util! {
38
36
let is_write_vectored = stream. is_write_vectored( ) ;
39
37
40
38
let inner = Arc :: new( Inner {
41
- locked: AtomicBool :: new( false ) ,
42
- stream: UnsafeCell :: new( stream) ,
39
+ stream: Mutex :: new( stream) ,
43
40
is_write_vectored,
44
41
} ) ;
45
42
@@ -54,13 +51,19 @@ cfg_io_util! {
54
51
}
55
52
56
53
struct Inner < T > {
57
- locked : AtomicBool ,
58
- stream : UnsafeCell < T > ,
54
+ stream : Mutex < T > ,
59
55
is_write_vectored : bool ,
60
56
}
61
57
62
- struct Guard < ' a , T > {
63
- inner : & ' a Inner < T > ,
58
+ impl < T > Inner < T > {
59
+ fn with_lock < R > ( & self , f : impl FnOnce ( Pin < & mut T > ) -> R ) -> R {
60
+ let mut guard = self . stream . lock ( ) . unwrap ( ) ;
61
+
62
+ // safety: we do not move the stream.
63
+ let stream = unsafe { Pin :: new_unchecked ( & mut * guard) } ;
64
+
65
+ f ( stream)
66
+ }
64
67
}
65
68
66
69
impl < T > ReadHalf < T > {
@@ -90,7 +93,7 @@ impl<T> ReadHalf<T> {
90
93
. ok ( )
91
94
. expect ( "`Arc::try_unwrap` failed" ) ;
92
95
93
- inner. stream . into_inner ( )
96
+ inner. stream . into_inner ( ) . unwrap ( )
94
97
} else {
95
98
panic ! ( "Unrelated `split::Write` passed to `split::Read::unsplit`." )
96
99
}
@@ -111,8 +114,7 @@ impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
111
114
cx : & mut Context < ' _ > ,
112
115
buf : & mut ReadBuf < ' _ > ,
113
116
) -> Poll < io:: Result < ( ) > > {
114
- let mut inner = ready ! ( self . inner. poll_lock( cx) ) ;
115
- inner. stream_pin ( ) . poll_read ( cx, buf)
117
+ self . inner . with_lock ( |stream| stream. poll_read ( cx, buf) )
116
118
}
117
119
}
118
120
@@ -122,67 +124,31 @@ impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
122
124
cx : & mut Context < ' _ > ,
123
125
buf : & [ u8 ] ,
124
126
) -> Poll < Result < usize , io:: Error > > {
125
- let mut inner = ready ! ( self . inner. poll_lock( cx) ) ;
126
- inner. stream_pin ( ) . poll_write ( cx, buf)
127
+ self . inner . with_lock ( |stream| stream. poll_write ( cx, buf) )
127
128
}
128
129
129
130
fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
130
- let mut inner = ready ! ( self . inner. poll_lock( cx) ) ;
131
- inner. stream_pin ( ) . poll_flush ( cx)
131
+ self . inner . with_lock ( |stream| stream. poll_flush ( cx) )
132
132
}
133
133
134
134
fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
135
- let mut inner = ready ! ( self . inner. poll_lock( cx) ) ;
136
- inner. stream_pin ( ) . poll_shutdown ( cx)
135
+ self . inner . with_lock ( |stream| stream. poll_shutdown ( cx) )
137
136
}
138
137
139
138
fn poll_write_vectored (
140
139
self : Pin < & mut Self > ,
141
140
cx : & mut Context < ' _ > ,
142
141
bufs : & [ io:: IoSlice < ' _ > ] ,
143
142
) -> Poll < Result < usize , io:: Error > > {
144
- let mut inner = ready ! ( self . inner. poll_lock ( cx ) ) ;
145
- inner . stream_pin ( ) . poll_write_vectored ( cx, bufs)
143
+ self . inner
144
+ . with_lock ( |stream| stream . poll_write_vectored ( cx, bufs) )
146
145
}
147
146
148
147
fn is_write_vectored ( & self ) -> bool {
149
148
self . inner . is_write_vectored
150
149
}
151
150
}
152
151
153
- impl < T > Inner < T > {
154
- fn poll_lock ( & self , cx : & mut Context < ' _ > ) -> Poll < Guard < ' _ , T > > {
155
- if self
156
- . locked
157
- . compare_exchange ( false , true , Acquire , Acquire )
158
- . is_ok ( )
159
- {
160
- Poll :: Ready ( Guard { inner : self } )
161
- } else {
162
- // Spin... but investigate a better strategy
163
-
164
- std:: thread:: yield_now ( ) ;
165
- cx. waker ( ) . wake_by_ref ( ) ;
166
-
167
- Poll :: Pending
168
- }
169
- }
170
- }
171
-
172
- impl < T > Guard < ' _ , T > {
173
- fn stream_pin ( & mut self ) -> Pin < & mut T > {
174
- // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual
175
- // exclusion.
176
- unsafe { Pin :: new_unchecked ( & mut * self . inner . stream . get ( ) ) }
177
- }
178
- }
179
-
180
- impl < T > Drop for Guard < ' _ , T > {
181
- fn drop ( & mut self ) {
182
- self . inner . locked . store ( false , Release ) ;
183
- }
184
- }
185
-
186
152
unsafe impl < T : Send > Send for ReadHalf < T > { }
187
153
unsafe impl < T : Send > Send for WriteHalf < T > { }
188
154
unsafe impl < T : Sync > Sync for ReadHalf < T > { }
0 commit comments