Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion glommio/src/net/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::{
cell::Cell,
io,
net::Shutdown,
os::unix::io::{AsRawFd, FromRawFd, RawFd},
os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
rc::{Rc, Weak},
task::{Context, Poll, Waker},
time::{Duration, Instant},
Expand Down Expand Up @@ -510,3 +510,17 @@ impl<S: AsRawFd, B: Buffered> GlommioStream<S, B> {
self.rx_buf.consume(amt);
}
}

impl<S: IntoRawFd, B: RxBuf> IntoRawFd for GlommioStream<S, B> {
fn into_raw_fd(self) -> RawFd {
// Clean up reactor resources before extracting the fd
let reactor = self.stream.reactor.upgrade();
if let Some(reactor) = reactor {
self.stream.write_timeout.cancel_timer(&reactor);
self.stream.read_timeout.cancel_timer(&reactor);
}

// Extract the raw fd from the underlying stream
self.stream.stream.into_raw_fd()
}
}
83 changes: 83 additions & 0 deletions glommio/src/net/tcp_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,12 @@ impl FromRawFd for TcpStream {
}
}

impl<B: RxBuf> IntoRawFd for TcpStream<B> {
fn into_raw_fd(self) -> RawFd {
self.stream.into_raw_fd()
}
}

fn make_tcp_socket(addr: &SocketAddr) -> io::Result<Socket> {
let domain = if addr.is_ipv6() {
Domain::IPV6
Expand Down Expand Up @@ -1328,4 +1334,81 @@ mod tests {
assert_eq!(s.local_addr().unwrap(), peer_addr.await);
});
}

#[test]
fn tcp_stream_into_raw_fd() {
test_executor!(async move {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();

// Test with non-buffered stream
let stream = TcpStream::connect(addr).await.unwrap();
let original_fd = stream.as_raw_fd();

// Extract the raw fd
let raw_fd = stream.into_raw_fd();
assert_eq!(original_fd, raw_fd);

// Verify we can create a new stream from the raw fd
let restored_stream = unsafe { TcpStream::from_raw_fd(raw_fd) };
assert_eq!(restored_stream.as_raw_fd(), raw_fd);

// Clean up
std::mem::drop(restored_stream);
});
}

#[test]
fn tcp_stream_buffered_into_raw_fd() {
test_executor!(async move {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();

// Test with buffered stream
let stream = TcpStream::connect(addr).await.unwrap().buffered();
let original_fd = stream.as_raw_fd();

// Extract the raw fd
let raw_fd = stream.into_raw_fd();
assert_eq!(original_fd, raw_fd);

// Verify we can create a new stream from the raw fd
let restored_stream = unsafe { TcpStream::from_raw_fd(raw_fd) };
assert_eq!(restored_stream.as_raw_fd(), raw_fd);

// Clean up
std::mem::drop(restored_stream);
});
}

#[test]
fn tcp_stream_into_raw_fd_with_timeouts() {
test_executor!(async move {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();

let stream = TcpStream::connect(addr).await.unwrap();

// Set timeouts to verify they get cleaned up
stream.set_read_timeout(Some(Duration::from_secs(30))).unwrap();
stream.set_write_timeout(Some(Duration::from_secs(30))).unwrap();

assert_eq!(stream.read_timeout(), Some(Duration::from_secs(30)));
assert_eq!(stream.write_timeout(), Some(Duration::from_secs(30)));

let original_fd = stream.as_raw_fd();

// Extract the raw fd - this should clean up timers
let raw_fd = stream.into_raw_fd();
assert_eq!(original_fd, raw_fd);

// Create a new stream and verify timeouts are reset
let restored_stream = unsafe { TcpStream::from_raw_fd(raw_fd) };
assert_eq!(restored_stream.read_timeout(), None);
assert_eq!(restored_stream.write_timeout(), None);

// Clean up
std::mem::drop(restored_stream);
});
}
}
Loading