|
1 | 1 | // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. |
2 | 2 | // SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
| 4 | +use std::fs::File; |
4 | 5 | use std::mem; |
5 | 6 | use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; |
6 | 7 | use std::os::unix::net::UnixStream; |
@@ -62,6 +63,8 @@ pub trait VhostUserSlaveReqHandler { |
62 | 63 | fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>; |
63 | 64 | fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; |
64 | 65 | fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {} |
| 66 | + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, RawFd)>; |
| 67 | + fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>; |
65 | 68 | fn get_max_mem_slots(&self) -> Result<u64>; |
66 | 69 | fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; |
67 | 70 | fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>; |
@@ -105,6 +108,11 @@ pub trait VhostUserSlaveReqHandlerMut { |
105 | 108 | ) -> Result<Vec<u8>>; |
106 | 109 | fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; |
107 | 110 | fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {} |
| 111 | + fn get_inflight_fd( |
| 112 | + &mut self, |
| 113 | + inflight: &VhostUserInflight, |
| 114 | + ) -> Result<(VhostUserInflight, RawFd)>; |
| 115 | + fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>; |
108 | 116 | fn get_max_mem_slots(&mut self) -> Result<u64>; |
109 | 117 | fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; |
110 | 118 | fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>; |
@@ -197,6 +205,14 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { |
197 | 205 | self.lock().unwrap().set_slave_req_fd(vu_req) |
198 | 206 | } |
199 | 207 |
|
| 208 | + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, RawFd)> { |
| 209 | + self.lock().unwrap().get_inflight_fd(inflight) |
| 210 | + } |
| 211 | + |
| 212 | + fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> { |
| 213 | + self.lock().unwrap().set_inflight_fd(inflight, file) |
| 214 | + } |
| 215 | + |
200 | 216 | fn get_max_mem_slots(&self) -> Result<u64> { |
201 | 217 | self.lock().unwrap().get_max_mem_slots() |
202 | 218 | } |
@@ -435,6 +451,41 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { |
435 | 451 | self.check_request_size(&hdr, size, hdr.get_size() as usize)?; |
436 | 452 | self.set_slave_req_fd(&hdr, rfds)?; |
437 | 453 | } |
| 454 | + MasterReq::GET_INFLIGHT_FD => { |
| 455 | + if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() |
| 456 | + == 0 |
| 457 | + { |
| 458 | + return Err(Error::InvalidOperation); |
| 459 | + } |
| 460 | + |
| 461 | + let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; |
| 462 | + let (inflight, fd) = self.backend.get_inflight_fd(&msg)?; |
| 463 | + let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?; |
| 464 | + self.main_sock |
| 465 | + .send_message(&reply_hdr, &inflight, Some(&[fd]))?; |
| 466 | + } |
| 467 | + MasterReq::SET_INFLIGHT_FD => { |
| 468 | + if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() |
| 469 | + == 0 |
| 470 | + { |
| 471 | + return Err(Error::InvalidOperation); |
| 472 | + } |
| 473 | + let file = if let Some(fds) = rfds { |
| 474 | + if fds.len() != 1 || fds[0] < 0 { |
| 475 | + Endpoint::<MasterReq>::close_rfds(Some(fds)); |
| 476 | + return Err(Error::IncorrectFds); |
| 477 | + } |
| 478 | + |
| 479 | + // Safe because we know the fd is valid. |
| 480 | + unsafe { File::from_raw_fd(fds[0]) } |
| 481 | + } else { |
| 482 | + return Err(Error::IncorrectFds); |
| 483 | + }; |
| 484 | + |
| 485 | + let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; |
| 486 | + let res = self.backend.set_inflight_fd(&msg, file); |
| 487 | + self.send_ack_message(&hdr, res)?; |
| 488 | + } |
438 | 489 | MasterReq::GET_MAX_MEM_SLOTS => { |
439 | 490 | if self.acked_protocol_features |
440 | 491 | & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() |
|
0 commit comments