diff --git a/src/devices/src/virtio/vsock/csm/connection.rs b/src/devices/src/virtio/vsock/csm/connection.rs index 4b7f78f0ff8..21d60055dba 100644 --- a/src/devices/src/virtio/vsock/csm/connection.rs +++ b/src/devices/src/virtio/vsock/csm/connection.rs @@ -125,6 +125,8 @@ pub struct VsockConnection { /// Instant when this connection should be scheduled for immediate termination, due to some /// timeout condition having been fulfilled. expiry: Option, + /// If this is true, reply with the connection status before transferring data or closing this connection. + need_reply: bool, } impl VsockChannel for VsockConnection @@ -304,9 +306,30 @@ where // Next up: receiving a response / confirmation for a host-initiated connection. // We'll move to an Established state, and pass on the good news through the host // stream. - ConnState::LocalInit if pkt.op() == uapi::VSOCK_OP_RESPONSE => { - self.expiry = None; - self.state = ConnState::Established; + ConnState::LocalInit + if pkt.op() == uapi::VSOCK_OP_RESPONSE || pkt.op() == uapi::VSOCK_OP_RST => + { + let is_response = pkt.op() == uapi::VSOCK_OP_RESPONSE; + if self.need_reply { + self.need_reply = false; + if let Err(err) = self.send_bytes(if is_response { b"101\n" } else { b"503\n" }) + { + // If we can't write to the host stream, that's an unrecoverable error, so + // we'll terminate this connection. + warn!( + "vsock: error writing to local stream (lp={}, pp={}): {:?}", + self.local_port, self.peer_port, err + ); + if is_response { + self.kill(); + } + return Ok(()); + } + } + if is_response { + self.expiry = None; + self.state = ConnState::Established; + } } // The peer wants to shut down an established connection. 
If they have nothing @@ -478,6 +501,7 @@ where last_fwd_cnt_to_peer: Wrapping(0), pending_rx: PendingRxSet::from(PendingRx::Response), expiry: None, + need_reply: false, } } @@ -488,6 +512,7 @@ where peer_cid: u64, local_port: u32, peer_port: u32, + need_reply: bool, ) -> Self { Self { local_cid, @@ -504,6 +529,7 @@ where last_fwd_cnt_to_peer: Wrapping(0), pending_rx: PendingRxSet::from(PendingRx::Request), expiry: None, + need_reply, } } @@ -738,7 +764,7 @@ mod tests { Self::new(ConnState::Established) } - fn new(conn_state: ConnState) -> Self { + fn new_maybe_need_reply(conn_state: ConnState, need_reply: bool) -> Self { let vsock_test_ctx = TestContext::new(); let mut handler_ctx = vsock_test_ctx.create_epoll_handler_context(); let stream = TestStream::new(); @@ -756,7 +782,7 @@ mod tests { PEER_BUF_ALLOC, ), ConnState::LocalInit => VsockConnection::::new_local_init( - stream, LOCAL_CID, PEER_CID, LOCAL_PORT, PEER_PORT, + stream, LOCAL_CID, PEER_CID, LOCAL_PORT, PEER_PORT, need_reply, ), ConnState::Established => { let mut conn = VsockConnection::::new_peer_init( @@ -782,6 +808,10 @@ mod tests { } } + fn new(conn_state: ConnState) -> Self { + Self::new_maybe_need_reply(conn_state, false) + } + fn set_stream(&mut self, stream: TestStream) { self.conn.stream = stream; } @@ -826,6 +856,7 @@ mod tests { fn test_peer_request() { let mut ctx = CsmTestContext::new(ConnState::PeerInit); assert!(ctx.conn.has_pending_rx()); + assert_eq!(ctx.conn.need_reply, false); ctx.recv(); // For peer-initiated requests, our connection should always yield a vsock reponse packet, // in order to establish the connection. @@ -844,6 +875,29 @@ mod tests { #[test] fn test_local_request() { let mut ctx = CsmTestContext::new(ConnState::LocalInit); + assert_eq!(ctx.conn.need_reply, false); + // Host-initiated connections should first yield a connection request packet. 
+ assert!(ctx.conn.has_pending_rx()); + // Before yielding the connection request packet, the timeout kill timer shouldn't be + // armed. + assert!(!ctx.conn.will_expire()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_REQUEST); + // Since the request might time-out, the kill timer should now be armed. + assert!(ctx.conn.will_expire()); + assert!(!ctx.conn.has_expired()); + ctx.init_pkt(uapi::VSOCK_OP_RESPONSE, 0); + ctx.send(); + // Upon receiving a connection response, the connection should have transitioned to the + // established state, and the kill timer should've been disarmed. + assert_eq!(ctx.conn.state, ConnState::Established); + assert!(!ctx.conn.will_expire()); + } + + #[test] + fn test_local_request_need_reply() { + let mut ctx = CsmTestContext::new_maybe_need_reply(ConnState::LocalInit, true); + assert_eq!(ctx.conn.need_reply, true); // Host-initiated connections should first yield a connection request packet. assert!(ctx.conn.has_pending_rx()); // Before yielding the connection request packet, the timeout kill timer shouldn't be @@ -865,6 +919,21 @@ mod tests { #[test] fn test_local_request_timeout() { let mut ctx = CsmTestContext::new(ConnState::LocalInit); + assert_eq!(ctx.conn.need_reply, false); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_REQUEST); + assert!(ctx.conn.will_expire()); + assert!(!ctx.conn.has_expired()); + std::thread::sleep(std::time::Duration::from_millis( + defs::CONN_REQUEST_TIMEOUT_MS, + )); + assert!(ctx.conn.has_expired()); + } + + #[test] + fn test_local_request_timeout_need_reply() { + let mut ctx = CsmTestContext::new_maybe_need_reply(ConnState::LocalInit, true); + assert_eq!(ctx.conn.need_reply, true); ctx.recv(); assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_REQUEST); assert!(ctx.conn.will_expire()); diff --git a/src/devices/src/virtio/vsock/unix/muxer.rs b/src/devices/src/virtio/vsock/unix/muxer.rs index 10ad4aaf058..a429d341a84 100644 --- a/src/devices/src/virtio/vsock/unix/muxer.rs +++ 
b/src/devices/src/virtio/vsock/unix/muxer.rs @@ -227,20 +227,19 @@ impl VsockChannel for VsockMuxer { return Ok(()); } + // Alright, everything looks in order - forward this packet to its owning connection. + let mut res: VsockResult<()> = Ok(()); + self.apply_conn_mutation(conn_key, |conn| { + res = conn.send_pkt(pkt); + }); + // Right, we know where to send this packet, then (to `conn_key`). // However, if this is an RST, we have to forcefully terminate the connection, so // there's no point in forwarding it the packet. if pkt.op() == uapi::VSOCK_OP_RST { self.remove_connection(conn_key); - return Ok(()); } - // Alright, everything looks in order - forward this packet to its owning connection. - let mut res: VsockResult<()> = Ok(()); - self.apply_conn_mutation(conn_key, |conn| { - res = conn.send_pkt(pkt); - }); - res } @@ -381,8 +380,10 @@ impl VsockMuxer { Some(EpollListener::LocalStream(_)) => { if let Some(EpollListener::LocalStream(mut stream)) = self.remove_listener(fd) { Self::read_local_stream_port(&mut stream) - .and_then(|peer_port| Ok((self.allocate_local_port(), peer_port))) - .and_then(|(local_port, peer_port)| { + .and_then(|(peer_port, need_reply)| { + Ok((self.allocate_local_port(), peer_port, need_reply)) + }) + .and_then(|(local_port, peer_port, need_reply)| { self.add_connection( ConnMapKey { local_port, @@ -394,6 +395,7 @@ impl VsockMuxer { self.cid, local_port, peer_port, + need_reply, ), ) }) @@ -409,8 +411,8 @@ impl VsockMuxer { } } - /// Parse a host "connect" command, and extract the destination vsock port. - fn read_local_stream_port(stream: &mut UnixStream) -> Result { + /// Parse a host "connect" or "upgrade" command, and extract the destination vsock port. + fn read_local_stream_port(stream: &mut UnixStream) -> Result<(u32, bool)> { let mut buf = [0u8; 32]; // This is the minimum number of bytes that we should be able to read, when parsing a @@ -437,19 +439,25 @@ impl VsockMuxer { .map_err(|_| Error::InvalidPortRequest)? 
.split_whitespace(); - word_iter + let mut need_reply = false; + let port = word_iter .next() .ok_or(Error::InvalidPortRequest) .and_then(|word| { if word.to_lowercase() == "connect" { Ok(()) + } else if word.to_lowercase() == "upgrade" { + need_reply = true; + Ok(()) } else { Err(Error::InvalidPortRequest) } }) .and_then(|_| word_iter.next().ok_or(Error::InvalidPortRequest)) .and_then(|word| word.parse::().map_err(|_| Error::InvalidPortRequest)) - .map_err(|_| Error::InvalidPortRequest) + .map_err(|_| Error::InvalidPortRequest)?; + + Ok((port, need_reply)) } /// Add a new connection to the active connection pool. @@ -839,7 +847,14 @@ mod tests { LocalListener::new(format!("{}_{}", self.muxer.host_sock_path, port)) } - fn local_connect(&mut self, peer_port: u32) -> (UnixStream, u32) { + fn local_connect_maybe_upgrade_oprst( + &mut self, + peer_port: u32, + is_upgrade: bool, + is_oprst: bool, + ) -> (UnixStream, u32) { + assert!(!is_oprst || (is_oprst && is_upgrade)); + let (init_local_lsn_count, init_conn_lsn_count) = self.count_epoll_listeners(); let mut stream = UnixStream::connect(self.muxer.host_sock_path.clone()).unwrap(); @@ -853,7 +868,11 @@ mod tests { let (local_lsn_count, _) = self.count_epoll_listeners(); assert_eq!(local_lsn_count, init_local_lsn_count + 1); - let buf = format!("CONNECT {}\n", peer_port); + let buf = if is_upgrade { + format!("Upgrade {}\n", peer_port) + } else { + format!("CONNECT {}\n", peer_port) + }; stream.write_all(buf.as_bytes()).unwrap(); // The muxer would now get notified that data is available for reading from the locally // initiated connection. 
@@ -882,11 +901,27 @@ mod tests { assert_eq!(self.pkt.dst_port(), peer_port); assert_eq!(self.pkt.src_port(), local_port); - self.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RESPONSE); + self.init_pkt( + local_port, + peer_port, + if is_oprst { + uapi::VSOCK_OP_RST + } else { + uapi::VSOCK_OP_RESPONSE + }, + ); self.send(); (stream, local_port) } + + fn local_connect(&mut self, peer_port: u32) -> (UnixStream, u32) { + self.local_connect_maybe_upgrade_oprst(peer_port, false, false) + } + + fn local_connect_with_upgrade(&mut self, peer_port: u32) -> (UnixStream, u32) { + self.local_connect_maybe_upgrade_oprst(peer_port, true, false) + } } struct LocalListener { @@ -1063,6 +1098,53 @@ mod tests { assert_eq!(ctx.pkt.buf().unwrap()[..data.len()], data); } + #[test] + fn test_local_connection_with_upgrade() { + let mut ctx = MuxerTestContext::new("local_connection_with_upgrade"); + let peer_port = 1025; + let (mut stream, local_port) = ctx.local_connect_with_upgrade(peer_port); + + // Test the handshake + let mut buf = vec![0; 4]; + stream.read_exact(buf.as_mut_slice()).unwrap(); + let buf = String::from_utf8(buf).unwrap(); + assert_eq!(buf, "101\n".to_string()); + + // Test guest -> host data flow. + let data = [1, 2, 3, 4]; + ctx.init_data_pkt(local_port, peer_port, &data); + ctx.send(); + + let mut buf = vec![0u8; data.len()]; + stream.read_exact(buf.as_mut_slice()).unwrap(); + assert_eq!(buf.as_slice(), &data); + + // Test host -> guest data flow. 
+ let data = [5, 6, 7, 8]; + stream.write_all(&data).unwrap(); + ctx.notify_muxer(); + + assert!(ctx.muxer.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RW); + assert_eq!(ctx.pkt.src_port(), local_port); + assert_eq!(ctx.pkt.dst_port(), peer_port); + assert_eq!(ctx.pkt.buf().unwrap()[..data.len()], data); + } + + #[test] + fn test_local_connection_with_upgrade_get_oprst() { + let mut ctx = MuxerTestContext::new("local_connection_with_upgrade_get_oprst"); + let peer_port = 1025; + let (mut stream, _local_port) = + ctx.local_connect_maybe_upgrade_oprst(peer_port, true, true); + + let mut buf = vec![0; 4]; + stream.read_exact(buf.as_mut_slice()).unwrap(); + let buf = String::from_utf8(buf).unwrap(); + assert_eq!(buf, "503\n".to_string()); + } + #[test] fn test_local_close() { let peer_port = 1025; @@ -1096,6 +1178,46 @@ mod tests { assert!(!ctx.muxer.local_port_set.contains(&local_port)); } + #[test] + fn test_local_close_with_upgrade() { + let peer_port = 1025; + let mut ctx = MuxerTestContext::new("local_close_with_upgrade"); + let local_port; + { + let (mut stream, local_port_) = ctx.local_connect_with_upgrade(peer_port); + + // Test the handshake + let mut buf = vec![0; 4]; + stream.read_exact(buf.as_mut_slice()).unwrap(); + let buf = String::from_utf8(buf).unwrap(); + assert_eq!(buf, "101\n".to_string()); + + local_port = local_port_; + } + // Local var `_stream` was now dropped, thus closing the local stream. After the muxer gets + // notified via EPOLLIN, it should attempt to gracefully shutdown the connection, issuing a + // VSOCK_OP_SHUTDOWN with both no-more-send and no-more-recv indications set. 
+ ctx.notify_muxer(); + assert!(ctx.muxer.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_SHUTDOWN); + assert_ne!(ctx.pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND, 0); + assert_ne!(ctx.pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV, 0); + assert_eq!(ctx.pkt.src_port(), local_port); + assert_eq!(ctx.pkt.dst_port(), peer_port); + + // The connection should get removed (and its local port freed), after the peer replies + // with an RST. + ctx.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RST); + ctx.send(); + let key = ConnMapKey { + local_port, + peer_port, + }; + assert!(!ctx.muxer.conn_map.contains_key(&key)); + assert!(!ctx.muxer.local_port_set.contains(&local_port)); + } + #[test] fn test_peer_close() { let peer_port = 1025; diff --git a/tests/integration_tests/build/test_coverage.py b/tests/integration_tests/build/test_coverage.py index 4e8c77d32f2..37012d1835b 100644 --- a/tests/integration_tests/build/test_coverage.py +++ b/tests/integration_tests/build/test_coverage.py @@ -19,7 +19,7 @@ import host_tools.cargo_build as host # pylint: disable=import-error -COVERAGE_TARGET_PCT = 85.3 +COVERAGE_TARGET_PCT = 85.2 COVERAGE_MAX_DELTA = 0.01 CARGO_KCOV_REL_PATH = os.path.join(host.CARGO_BUILD_REL_PATH, 'kcov')