16
16
std:: {
17
17
borrow:: Borrow ,
18
18
io,
19
- iter:: repeat,
20
19
net:: { SocketAddr , UdpSocket } ,
21
20
} ,
22
21
thiserror:: Error ,
@@ -35,11 +34,15 @@ impl From<SendPktsError> for TransportError {
35
34
}
36
35
}
37
36
37
+ // The type and lifetime constraints are overspecified to match 'linux' code.
38
38
#[ cfg( not( target_os = "linux" ) ) ]
39
- pub fn batch_send < S , T > ( sock : & UdpSocket , packets : & [ ( T , S ) ] ) -> Result < ( ) , SendPktsError >
39
+ pub fn batch_send < ' a , S , T : ' a + ?Sized > (
40
+ sock : & UdpSocket ,
41
+ packets : impl IntoIterator < Item = ( & ' a T , S ) , IntoIter : ExactSizeIterator > ,
42
+ ) -> Result < ( ) , SendPktsError >
40
43
where
41
44
S : Borrow < SocketAddr > ,
42
- T : AsRef < [ u8 ] > ,
45
+ & ' a T : AsRef < [ u8 ] > ,
43
46
{
44
47
let mut num_failed = 0 ;
45
48
let mut erropt = None ;
@@ -158,12 +161,17 @@ fn sendmmsg_retry(sock: &UdpSocket, hdrs: &mut [mmsghdr]) -> Result<(), SendPkts
158
161
const MAX_IOV : usize = libc:: UIO_MAXIOV as usize ;
159
162
160
163
#[ cfg( target_os = "linux" ) ]
161
- pub fn batch_send_max_iov < S , T > ( sock : & UdpSocket , packets : & [ ( T , S ) ] ) -> Result < ( ) , SendPktsError >
164
+ fn batch_send_max_iov < ' a , S , T : ' a + ?Sized > (
165
+ sock : & UdpSocket ,
166
+ packets : impl IntoIterator < Item = ( & ' a T , S ) , IntoIter : ExactSizeIterator > ,
167
+ ) -> Result < ( ) , SendPktsError >
162
168
where
163
169
S : Borrow < SocketAddr > ,
164
- T : AsRef < [ u8 ] > ,
170
+ & ' a T : AsRef < [ u8 ] > ,
165
171
{
166
- assert ! ( packets. len( ) <= MAX_IOV ) ;
172
+ let packets = packets. into_iter ( ) ;
173
+ let num_packets = packets. len ( ) ;
174
+ debug_assert ! ( num_packets <= MAX_IOV ) ;
167
175
168
176
let mut iovs = [ MaybeUninit :: uninit ( ) ; MAX_IOV ] ;
169
177
let mut addrs = [ MaybeUninit :: uninit ( ) ; MAX_IOV ] ;
@@ -177,13 +185,13 @@ where
177
185
// SAFETY: The first `packets.len()` elements of `hdrs`, `iovs`, and `addrs` are
178
186
// guaranteed to be initialized by `mmsghdr_for_packet` before this loop.
179
187
let hdrs_slice =
180
- unsafe { std:: slice:: from_raw_parts_mut ( hdrs. as_mut_ptr ( ) as * mut mmsghdr , packets . len ( ) ) } ;
188
+ unsafe { std:: slice:: from_raw_parts_mut ( hdrs. as_mut_ptr ( ) as * mut mmsghdr , num_packets ) } ;
181
189
182
190
let result = sendmmsg_retry ( sock, hdrs_slice) ;
183
191
184
192
// SAFETY: The first `packets.len()` elements of `hdrs`, `iovs`, and `addrs` are
185
193
// guaranteed to be initialized by `mmsghdr_for_packet` before this loop.
186
- for ( hdr, iov, addr) in izip ! ( & mut hdrs, & mut iovs, & mut addrs) . take ( packets . len ( ) ) {
194
+ for ( hdr, iov, addr) in izip ! ( & mut hdrs, & mut iovs, & mut addrs) . take ( num_packets ) {
187
195
unsafe {
188
196
hdr. assume_init_drop ( ) ;
189
197
iov. assume_init_drop ( ) ;
@@ -194,13 +202,23 @@ where
194
202
result
195
203
}
196
204
205
+ // Need &'a to ensure that raw packet pointers obtained in mmsghdr_for_packet
206
+ // stay valid.
197
207
#[ cfg( target_os = "linux" ) ]
198
- pub fn batch_send < S , T > ( sock : & UdpSocket , packets : & [ ( T , S ) ] ) -> Result < ( ) , SendPktsError >
208
+ pub fn batch_send < ' a , S , T : ' a + ?Sized > (
209
+ sock : & UdpSocket ,
210
+ packets : impl IntoIterator < Item = ( & ' a T , S ) , IntoIter : ExactSizeIterator > ,
211
+ ) -> Result < ( ) , SendPktsError >
199
212
where
200
213
S : Borrow < SocketAddr > ,
201
- T : AsRef < [ u8 ] > ,
214
+ & ' a T : AsRef < [ u8 ] > ,
202
215
{
203
- for chunk in packets. chunks ( MAX_IOV ) {
216
+ let mut packets = packets. into_iter ( ) ;
217
+ loop {
218
+ let chunk = packets. by_ref ( ) . take ( MAX_IOV ) ;
219
+ if chunk. len ( ) == 0 {
220
+ break ;
221
+ }
204
222
batch_send_max_iov ( sock, chunk) ?;
205
223
}
206
224
Ok ( ( ) )
@@ -216,8 +234,8 @@ where
216
234
T : AsRef < [ u8 ] > ,
217
235
{
218
236
let dests = dests. iter ( ) . map ( Borrow :: borrow) ;
219
- let pkts: Vec < _ > = repeat ( & packet) . zip ( dests ) . collect ( ) ;
220
- batch_send ( sock, & pkts)
237
+ let pkts = dests . map ( |addr| ( & packet, addr ) ) ;
238
+ batch_send ( sock, pkts)
221
239
}
222
240
223
241
#[ cfg( test) ]
@@ -246,7 +264,7 @@ mod tests {
246
264
let packets: Vec < _ > = ( 0 ..32 ) . map ( |_| vec ! [ 0u8 ; PACKET_DATA_SIZE ] ) . collect ( ) ;
247
265
let packet_refs: Vec < _ > = packets. iter ( ) . map ( |p| ( & p[ ..] , & addr) ) . collect ( ) ;
248
266
249
- let sent = batch_send ( & sender, & packet_refs[ .. ] ) . ok ( ) ;
267
+ let sent = batch_send ( & sender, packet_refs) . ok ( ) ;
250
268
assert_eq ! ( sent, Some ( ( ) ) ) ;
251
269
252
270
let mut packets = vec ! [ Packet :: default ( ) ; 32 ] ;
@@ -277,7 +295,7 @@ mod tests {
277
295
} )
278
296
. collect ( ) ;
279
297
280
- let sent = batch_send ( & sender, & packet_refs[ .. ] ) . ok ( ) ;
298
+ let sent = batch_send ( & sender, packet_refs) . ok ( ) ;
281
299
assert_eq ! ( sent, Some ( ( ) ) ) ;
282
300
283
301
let mut packets = vec ! [ Packet :: default ( ) ; 32 ] ;
@@ -345,7 +363,7 @@ mod tests {
345
363
let dest_refs: Vec < _ > = vec ! [ & ip4, & ip6, & ip4] ;
346
364
347
365
let sender = bind_to_unspecified ( ) . expect ( "bind" ) ;
348
- let res = batch_send ( & sender, & packet_refs[ .. ] ) ;
366
+ let res = batch_send ( & sender, packet_refs) ;
349
367
assert_matches ! ( res, Err ( SendPktsError :: IoError ( _, /*num_failed*/ 1 ) ) ) ;
350
368
let res = multi_target_send ( & sender, & packets[ 0 ] , & dest_refs) ;
351
369
assert_matches ! ( res, Err ( SendPktsError :: IoError ( _, /*num_failed*/ 1 ) ) ) ;
@@ -366,7 +384,7 @@ mod tests {
366
384
( & packets[ 3 ] [ ..] , & ipv4broadcast) ,
367
385
( & packets[ 4 ] [ ..] , & ipv4local) ,
368
386
] ;
369
- match batch_send ( & sender, & packet_refs[ .. ] ) {
387
+ match batch_send ( & sender, packet_refs) {
370
388
Ok ( ( ) ) => panic ! ( ) ,
371
389
Err ( SendPktsError :: IoError ( ioerror, num_failed) ) => {
372
390
assert_matches ! ( ioerror. kind( ) , ErrorKind :: PermissionDenied ) ;
@@ -382,7 +400,7 @@ mod tests {
382
400
( & packets[ 3 ] [ ..] , & ipv4local) ,
383
401
( & packets[ 4 ] [ ..] , & ipv4broadcast) ,
384
402
] ;
385
- match batch_send ( & sender, & packet_refs[ .. ] ) {
403
+ match batch_send ( & sender, packet_refs) {
386
404
Ok ( ( ) ) => panic ! ( ) ,
387
405
Err ( SendPktsError :: IoError ( ioerror, num_failed) ) => {
388
406
assert_matches ! ( ioerror. kind( ) , ErrorKind :: PermissionDenied ) ;
@@ -398,7 +416,7 @@ mod tests {
398
416
( & packets[ 3 ] [ ..] , & ipv4broadcast) ,
399
417
( & packets[ 4 ] [ ..] , & ipv4local) ,
400
418
] ;
401
- match batch_send ( & sender, & packet_refs[ .. ] ) {
419
+ match batch_send ( & sender, packet_refs) {
402
420
Ok ( ( ) ) => panic ! ( ) ,
403
421
Err ( SendPktsError :: IoError ( ioerror, num_failed) ) => {
404
422
assert_matches ! ( ioerror. kind( ) , ErrorKind :: PermissionDenied ) ;
0 commit comments