-
Notifications
You must be signed in to change notification settings - Fork 75
/
tcp.rs
326 lines (280 loc) · 11.1 KB
/
tcp.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
pub use socket2::{TcpKeepalive};
use crate::network::adapter::{
Resource, Remote, Local, Adapter, SendStatus, AcceptedType, ReadStatus, ConnectionInfo,
ListeningInfo, PendingStatus,
};
use crate::network::{RemoteAddr, Readiness, TransportConnect, TransportListen};
use mio::net::{TcpListener, TcpStream};
use mio::event::{Source};
use socket2::{Socket, Domain, Type, Protocol};
use std::net::{SocketAddr};
#[cfg(unix)]
use std::ffi::{CString};
use std::io::{self, ErrorKind, Read, Write};
#[cfg(target_os = "macos")]
use std::num::NonZeroU32;
use std::mem::{forget, MaybeUninit};
use std::os::raw::c_int;
#[cfg(target_os = "windows")]
use std::os::windows::io::{FromRawSocket, AsRawSocket};
#[cfg(not(target_os = "windows"))]
use std::os::{fd::AsRawFd, unix::io::FromRawFd};
/// Size in bytes of the internal reading buffer.
///
/// A single read fills at most this many bytes, so each generated
/// [`crate::network::NetEvent::Message`] contains a chunk of data no larger than this value.
pub const INPUT_BUFFER_SIZE: usize = u16::MAX as usize; // 2^16 - 1
/// The maximum length of the pending (unaccepted) connection queue of a listener.
///
/// This value is passed as the backlog argument when the listening socket is put in listen mode.
pub const LISTENER_BACKLOG: c_int = 1024;
/// Options applied when opening an outgoing TCP connection.
///
/// Every option starts unset (see the derived `Default`); chain the `with_*` builder methods
/// to enable the ones you need.
#[derive(Clone, Debug, Default)]
pub struct TcpConnectConfig {
    /// Name of the network interface the socket should be bound to, if any.
    bind_device: Option<String>,
    /// Local address and port the socket should be bound to before connecting, if any.
    source_address: Option<SocketAddr>,
    /// Keepalive parameters to apply to the connection, if any.
    keepalive: Option<TcpKeepalive>,
}

impl TcpConnectConfig {
    /// Bind the TCP connection to a specific interface, identified by its name.
    ///
    /// This option only takes effect on Unix; on other systems it is ignored.
    pub fn with_bind_device(self, device: String) -> Self {
        Self { bind_device: Some(device), ..self }
    }

    /// Enables TCP keepalive on the socket, using the given settings.
    pub fn with_keepalive(self, keepalive: TcpKeepalive) -> Self {
        Self { keepalive: Some(keepalive), ..self }
    }

    /// Specifies the local (source) address and port used by the connection.
    pub fn with_source_address(self, source_address: SocketAddr) -> Self {
        Self { source_address: Some(source_address), ..self }
    }
}
/// Options applied when opening a TCP listener.
///
/// Every option starts unset (see the derived `Default`); chain the `with_*` builder methods
/// to enable the ones you need.
#[derive(Clone, Debug, Default)]
pub struct TcpListenConfig {
    /// Name of the network interface the listening socket should be bound to, if any.
    bind_device: Option<String>,
    /// Keepalive parameters handed to every accepted connection, if any.
    keepalive: Option<TcpKeepalive>,
}

impl TcpListenConfig {
    /// Bind the TCP listener to a specific interface, identified by its name.
    ///
    /// This option only takes effect on Unix; on other systems it is ignored.
    pub fn with_bind_device(self, device: String) -> Self {
        Self { bind_device: Some(device), ..self }
    }

    /// Enables TCP keepalive on the sockets of accepted client connections.
    pub fn with_keepalive(self, keepalive: TcpKeepalive) -> Self {
        Self { keepalive: Some(keepalive), ..self }
    }
}
/// Marker type that selects the TCP resource types for the generic [`Adapter`] machinery.
pub(crate) struct TcpAdapter;
impl Adapter for TcpAdapter {
/// Resource representing one TCP stream (outgoing or accepted).
type Remote = RemoteResource;
/// Resource representing a TCP listener.
type Local = LocalResource;
}
/// A single TCP stream together with the keepalive settings to apply once it is connected.
pub(crate) struct RemoteResource {
/// The non-blocking mio stream.
stream: TcpStream,
/// Keepalive settings applied in `pending()` when the stream is reported as ready;
/// `None` leaves the socket's keepalive configuration untouched.
keepalive: Option<TcpKeepalive>,
}
impl Resource for RemoteResource {
/// Exposes the underlying stream as the mio event source for this resource.
fn source(&mut self) -> &mut dyn Source {
&mut self.stream
}
}
impl Remote for RemoteResource {
fn connect_with(
config: TransportConnect,
remote_addr: RemoteAddr,
) -> io::Result<ConnectionInfo<Self>> {
let config = match config {
TransportConnect::Tcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let peer_addr = *remote_addr.socket_addr();
let socket = Socket::new(
match peer_addr {
SocketAddr::V4 { .. } => Domain::IPV4,
SocketAddr::V6 { .. } => Domain::IPV6,
},
Type::STREAM,
Some(Protocol::TCP),
)?;
socket.set_nonblocking(true)?;
if let Some(source_address) = config.source_address {
socket.bind(&source_address.into())?;
}
#[cfg(unix)]
if let Some(bind_device) = config.bind_device {
let device = CString::new(bind_device)?;
#[cfg(not(target_os = "macos"))]
socket.bind_device(Some(device.as_bytes()))?;
#[cfg(target_os = "macos")]
match NonZeroU32::new(unsafe { libc::if_nametoindex(device.as_ptr()) }) {
Some(index) => socket.bind_device_by_index_v4(Some(index))?,
None => {
return Err(io::Error::new(
ErrorKind::NotFound,
"Bind device interface not found",
))
}
}
}
match socket.connect(&peer_addr.into()) {
#[cfg(unix)]
Err(e) if e.raw_os_error() != Some(libc::EINPROGRESS) => return Err(e),
#[cfg(windows)]
Err(e) if e.kind() != io::ErrorKind::WouldBlock => return Err(e),
_ => {}
}
let stream = TcpStream::from_std(socket.into());
let local_addr = stream.local_addr()?;
Ok(ConnectionInfo {
remote: Self { stream, keepalive: config.keepalive },
local_addr,
peer_addr,
})
}
fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
let buffer: MaybeUninit<[u8; INPUT_BUFFER_SIZE]> = MaybeUninit::uninit();
let mut input_buffer = unsafe { buffer.assume_init() }; // Avoid to initialize the array
loop {
let mut stream = &self.stream;
match stream.read(&mut input_buffer) {
Ok(0) => break ReadStatus::Disconnected,
Ok(size) => process_data(&input_buffer[..size]),
Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
Err(ref err) if err.kind() == ErrorKind::WouldBlock => {
break ReadStatus::WaitNextEvent
}
Err(ref err) if err.kind() == ErrorKind::ConnectionReset => {
break ReadStatus::Disconnected
}
Err(err) => {
log::error!("TCP receive error: {}", err);
break ReadStatus::Disconnected // should not happen
}
}
}
}
fn send(&self, data: &[u8]) -> SendStatus {
// TODO: The current implementation implies an active waiting,
// improve it using POLLIN instead to avoid active waiting.
// Note: Despite the fear that an active waiting could generate,
// this only occurs in the case when the receiver is full because reads slower that it sends.
let mut total_bytes_sent = 0;
loop {
let mut stream = &self.stream;
match stream.write(&data[total_bytes_sent..]) {
Ok(bytes_sent) => {
total_bytes_sent += bytes_sent;
if total_bytes_sent == data.len() {
break SendStatus::Sent
}
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue,
// Others errors are considered fatal for the connection.
// a Event::Disconnection will be generated later.
Err(err) => {
log::error!("TCP receive error: {}", err);
break SendStatus::ResourceNotFound // should not happen
}
}
}
}
fn pending(&self, _readiness: Readiness) -> PendingStatus {
let status = check_stream_ready(&self.stream);
if status == PendingStatus::Ready {
if let Some(keepalive) = &self.keepalive {
#[cfg(target_os = "windows")]
let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
#[cfg(not(target_os = "windows"))]
let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };
if let Err(e) = socket.set_tcp_keepalive(keepalive) {
log::warn!("TCP set keepalive error: {}", e);
}
// Don't drop so the underlying socket is not closed.
forget(socket);
}
}
status
}
}
/// Determines whether a non-blocking `TcpStream` has finished connecting.
///
/// Returns `PendingStatus::Disconnected` when the socket reports a pending error or the peer
/// address lookup fails for an unexpected reason, `PendingStatus::Incomplete` while the
/// connection is still being established, and `PendingStatus::Ready` once a peer address is
/// available.
pub fn check_stream_ready(stream: &TcpStream) -> PendingStatus {
    // A multiplatform non-blocking way to determine if the TCP stream is connected.
    // Extracted from: https://github.com/tokio-rs/mio/issues/1486
    if matches!(stream.take_error(), Ok(Some(_))) {
        return PendingStatus::Disconnected;
    }

    // Only the error kind matters for the decision below.
    match stream.peer_addr().map_err(|err| err.kind()) {
        Ok(_) => PendingStatus::Ready,
        Err(io::ErrorKind::NotConnected | io::ErrorKind::InvalidInput) => {
            PendingStatus::Incomplete
        }
        Err(_) => PendingStatus::Disconnected,
    }
}
/// A TCP listener together with the keepalive settings handed to the connections it accepts.
pub(crate) struct LocalResource {
/// The non-blocking mio listener.
listener: TcpListener,
/// Keepalive settings cloned into every accepted `RemoteResource`; `None` when unset.
keepalive: Option<TcpKeepalive>,
}
impl Resource for LocalResource {
/// Exposes the underlying listener as the mio event source for this resource.
fn source(&mut self) -> &mut dyn Source {
&mut self.listener
}
}
impl Local for LocalResource {
type Remote = RemoteResource;
fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
let config = match config {
TransportListen::Tcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let socket = Socket::new(
match addr {
SocketAddr::V4 { .. } => Domain::IPV4,
SocketAddr::V6 { .. } => Domain::IPV6,
},
Type::STREAM,
Some(Protocol::TCP),
)?;
socket.set_nonblocking(true)?;
socket.set_reuse_address(true)?;
#[cfg(unix)]
if let Some(bind_device) = config.bind_device {
let device = CString::new(bind_device)?;
#[cfg(not(target_os = "macos"))]
socket.bind_device(Some(device.as_bytes()))?;
#[cfg(target_os = "macos")]
match NonZeroU32::new(unsafe { libc::if_nametoindex(device.as_ptr()) }) {
Some(index) => socket.bind_device_by_index_v4(Some(index))?,
None => {
return Err(io::Error::new(
ErrorKind::NotFound,
"Bind device interface not found",
))
}
}
}
socket.bind(&addr.into())?;
socket.listen(LISTENER_BACKLOG)?;
let listener = TcpListener::from_std(socket.into());
let local_addr = listener.local_addr().unwrap();
Ok(ListeningInfo {
local: { LocalResource { listener, keepalive: config.keepalive } },
local_addr,
})
}
fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
loop {
match self.listener.accept() {
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
addr,
RemoteResource { stream, keepalive: self.keepalive.clone() },
)),
Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
Err(err) => break log::error!("TCP accept error: {}", err), // Should not happen
}
}
}
}