diff --git a/glommio/src/net/stream.rs b/glommio/src/net/stream.rs index 21d5b0e84..6f3e2e59a 100644 --- a/glommio/src/net/stream.rs +++ b/glommio/src/net/stream.rs @@ -13,7 +13,7 @@ use std::{ cell::Cell, io, net::Shutdown, - os::unix::io::{AsRawFd, FromRawFd, RawFd}, + os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, rc::{Rc, Weak}, task::{Context, Poll, Waker}, time::{Duration, Instant}, @@ -510,3 +510,17 @@ impl GlommioStream { self.rx_buf.consume(amt); } } + +impl IntoRawFd for GlommioStream { + fn into_raw_fd(self) -> RawFd { + // Clean up reactor resources before extracting the fd + let reactor = self.stream.reactor.upgrade(); + if let Some(reactor) = reactor { + self.stream.write_timeout.cancel_timer(&reactor); + self.stream.read_timeout.cancel_timer(&reactor); + } + + // Extract the raw fd from the underlying stream + self.stream.stream.into_raw_fd() + } +} \ No newline at end of file diff --git a/glommio/src/net/tcp_socket.rs b/glommio/src/net/tcp_socket.rs index 37cc2b1da..5be573367 100644 --- a/glommio/src/net/tcp_socket.rs +++ b/glommio/src/net/tcp_socket.rs @@ -433,6 +433,12 @@ impl FromRawFd for TcpStream { } } +impl IntoRawFd for TcpStream { + fn into_raw_fd(self) -> RawFd { + self.stream.into_raw_fd() + } +} + fn make_tcp_socket(addr: &SocketAddr) -> io::Result { let domain = if addr.is_ipv6() { Domain::IPV6 @@ -1328,4 +1334,81 @@ mod tests { assert_eq!(s.local_addr().unwrap(), peer_addr.await); }); } + + #[test] + fn tcp_stream_into_raw_fd() { + test_executor!(async move { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + // Test with non-buffered stream + let stream = TcpStream::connect(addr).await.unwrap(); + let original_fd = stream.as_raw_fd(); + + // Extract the raw fd + let raw_fd = stream.into_raw_fd(); + assert_eq!(original_fd, raw_fd); + + // Verify we can create a new stream from the raw fd + let restored_stream = unsafe { TcpStream::from_raw_fd(raw_fd) }; + assert_eq!(restored_stream.as_raw_fd(), raw_fd); + + // Clean up + std::mem::drop(restored_stream); + }); + } + + #[test] + fn tcp_stream_buffered_into_raw_fd() { + test_executor!(async move { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + // Test with buffered stream + let stream = TcpStream::connect(addr).await.unwrap().buffered(); + let original_fd = stream.as_raw_fd(); + + // Extract the raw fd + let raw_fd = stream.into_raw_fd(); + assert_eq!(original_fd, raw_fd); + + // Verify we can create a new stream from the raw fd + let restored_stream = unsafe { TcpStream::from_raw_fd(raw_fd) }; + assert_eq!(restored_stream.as_raw_fd(), raw_fd); + + // Clean up + std::mem::drop(restored_stream); + }); + } + + #[test] + fn tcp_stream_into_raw_fd_with_timeouts() { + test_executor!(async move { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let stream = TcpStream::connect(addr).await.unwrap(); + + // Set timeouts to verify they get cleaned up + stream.set_read_timeout(Some(Duration::from_secs(30))).unwrap(); + stream.set_write_timeout(Some(Duration::from_secs(30))).unwrap(); + + assert_eq!(stream.read_timeout(), Some(Duration::from_secs(30))); + assert_eq!(stream.write_timeout(), Some(Duration::from_secs(30))); + + let original_fd = stream.as_raw_fd(); + + // Extract the raw fd - this should clean up timers + let raw_fd = stream.into_raw_fd(); + assert_eq!(original_fd, raw_fd); + + // Create a new stream and verify timeouts are reset + let restored_stream = unsafe { TcpStream::from_raw_fd(raw_fd) }; + assert_eq!(restored_stream.read_timeout(), None); + assert_eq!(restored_stream.write_timeout(), None); + + // Clean up + std::mem::drop(restored_stream); + }); + } }