diff --git a/src/vmm/src/devices/virtio/mmio.rs b/src/vmm/src/devices/virtio/mmio.rs index cd9c1a5e0335..4de940549ae3 100644 --- a/src/vmm/src/devices/virtio/mmio.rs +++ b/src/vmm/src/devices/virtio/mmio.rs @@ -201,7 +201,10 @@ impl MmioTransport { let mut device_status = self.device_status; let reset_result = self.locked_device().reset(); match reset_result { - Some((_interrupt_evt, mut _queue_evts)) => {} + Some((_interrupt_evt, mut _queue_evts)) => { + // The device MUST initialize device status to 0 upon reset. + device_status = INIT; + } None => { device_status |= FAILED; } diff --git a/src/vmm/src/devices/virtio/net/device.rs b/src/vmm/src/devices/virtio/net/device.rs index 9d8cdfd73723..0e0433242f8d 100755 --- a/src/vmm/src/devices/virtio/net/device.rs +++ b/src/vmm/src/devices/virtio/net/device.rs @@ -870,6 +870,26 @@ impl VirtioDevice for Net { fn is_activated(&self) -> bool { self.device_state.is_activated() } + + fn reset(&mut self) -> Option<(EventFd, Vec)> { + self.device_state = DeviceState::Inactive; + self.rx_bytes_read = 0; + self.rx_deferred_frame = false; + self.rx_frame_buf = [0u8; MAX_BUFFER_SIZE]; + self.metrics = NetMetricsPerDevice::alloc(self.id.clone()); + + let queue_evts: Vec<_> = self + .queue_evts + .iter() + .filter_map(|q| q.try_clone().ok()) + .collect(); + + if let Ok(irq_evt) = self.irq_trigger.irq_evt.try_clone() { + Some((irq_evt, queue_evts)) + } else { + None + } + } } #[cfg(test)] @@ -2015,17 +2035,29 @@ pub mod tests { th.activate_net(); let net = th.net.lock().unwrap(); - // Test queues count (TX and RX). - let queues = net.queues(); - assert_eq!(queues.len(), NET_QUEUE_SIZES.len()); - assert_eq!(queues[RX_INDEX].size, th.rxq.size()); - assert_eq!(queues[TX_INDEX].size, th.txq.size()); + let validate = |net: &Net| { + // Test queues count (TX and RX). + let queues = net.queues(); + assert_eq!(queues.len(), NET_QUEUE_SIZES.len()); + assert_eq!(queues[RX_INDEX].size, th.rxq.size()); + assert_eq!(queues[TX_INDEX].size, th.txq.size()); + + // Test corresponding queues events. + assert_eq!(net.queue_events().len(), NET_QUEUE_SIZES.len()); + + // Test interrupts. + assert!(!&net.irq_trigger.has_pending_irq(IrqType::Vring)); + }; + + validate(&net); - // Test corresponding queues events. - assert_eq!(net.queue_events().len(), NET_QUEUE_SIZES.len()); + // Test reset. + let mut net = net; + assert!(net.device_state.is_activated()); + let (_interrupt_evt, _queue_evts) = net.reset().unwrap(); + assert!(!net.device_state.is_activated()); - // Test interrupts. - assert!(!&net.irq_trigger.has_pending_irq(IrqType::Vring)); + validate(&net); } #[test]