diff --git a/monarch_rdma/src/rdma_components.rs b/monarch_rdma/src/rdma_components.rs index 9a022d700..e4a3becdf 100644 --- a/monarch_rdma/src/rdma_components.rs +++ b/monarch_rdma/src/rdma_components.rs @@ -41,7 +41,7 @@ //! 7. Resources are cleaned up when dropped /// Maximum size for a single RDMA operation in bytes (1 GiB) -const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024; +pub const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024; use std::ffi::CStr; use std::fs; @@ -130,31 +130,15 @@ impl RdmaBuffer { remote.owner.actor_id(), remote, ); - let remote_owner = remote.owner.clone(); - - let local_device = self.device_name.clone(); - let remote_device = remote.device_name.clone(); - let mut qp = self - .owner - .request_queue_pair_deprecated( + self.owner + .read_into( client, - remote_owner.clone(), - local_device.clone(), - remote_device.clone(), + self.clone(), + remote, + tokio::time::Duration::from_secs(timeout), ) .await?; - - qp.put(self.clone(), remote)?; - let result = self - .wait_for_completion(&mut qp, PollTarget::Send, timeout) - .await; - - // Release the queue pair back to the actor - self.owner - .release_queue_pair_deprecated(client, remote_owner, local_device, remote_device, qp) - .await?; - - result + Ok(true) } /// Write from the provided memory into the RdmaBuffer. @@ -182,32 +166,14 @@ impl RdmaBuffer { remote.owner.actor_id(), remote, ); - let remote_owner = remote.owner.clone(); // Clone before the move! - - // Extract device name from buffer, fallback to a default if not present - let local_device = self.device_name.clone(); - let remote_device = remote.device_name.clone(); - - let mut qp = self - .owner - .request_queue_pair_deprecated( + self.owner + .write_from( client, - remote_owner.clone(), - local_device.clone(), - remote_device.clone(), + self.clone(), + remote, + tokio::time::Duration::from_secs(timeout), ) .await?; - qp.get(self.clone(), remote)?; - let result = self - .wait_for_completion(&mut qp, PollTarget::Send, timeout) - .await; - - // Release the queue pair back to the actor - self.owner - .release_queue_pair_deprecated(client, remote_owner, local_device, remote_device, qp) - .await?; - - result?; Ok(true) } /// Waits for the completion of an RDMA operation. @@ -217,21 +183,19 @@ impl RdmaBuffer { /// /// # Arguments /// * `qp` - The RDMA Queue Pair to poll for completion - /// * `timeout` - Timeout in seconds for the RDMA operation to complete. + /// * `timeout` - Timeout for the RDMA operation to complete. /// /// # Returns /// `Ok(true)` if the operation completes successfully within the timeout, /// or an error if the timeout is reached - async fn wait_for_completion( + pub async fn wait_for_completion( &self, qp: &mut RdmaQueuePair, poll_target: PollTarget, - timeout: u64, + timeout: tokio::time::Duration, ) -> Result { - let timeout = Duration::from_secs(timeout); - let start_time = std::time::Instant::now(); - - while start_time.elapsed() < timeout { + RealClock.timeout(timeout, async { + loop { match qp.poll_completion_target(poll_target) { Ok(Some(_wc)) => { tracing::debug!("work completed"); @@ -252,11 +216,14 @@ impl RdmaBuffer { } } } - tracing::error!("timed out while waiting on request completion"); - Err(anyhow::anyhow!( - "[buffer({:?})] rdma operation did not complete in time", - self - )) + }).await.map_err(|_| { + tracing::error!("timed out while waiting on request completion"); + anyhow::anyhow!( + "[buffer({:?})] rdma operation did not complete in time (timeout={:?})", + self, + timeout + ) + })? } /// Drop the buffer and release remote handles. @@ -1057,7 +1024,7 @@ impl RdmaQueuePair { /// * `op_type` - Optional operation type /// * `raddr` - the remote address, representing the memory location on the remote peer /// * `rkey` - the remote key, representing the key required to access the remote memory region - fn post_op( + pub fn post_op( &mut self, laddr: usize, lkey: u32, @@ -1318,6 +1285,51 @@ impl RdmaQueuePair { } } + /// Poll for completions on the specified completion queue. + /// This function does not mutate the various indices (e.g. send_db_idx, send_cq_idx, etc.) + /// However, it does change the state of the device. As such, + /// calling this function while using poll_completion_target() will result in + /// a race condition, since poll_completion_target() will not see the completion retrieved by this function. + pub fn poll_once_stateless(&self, target: PollTarget) -> Result, anyhow::Error> { + let cq = if target == PollTarget::Send { + self.send_cq as *mut rdmaxcel_sys::ibv_cq + } else { + self.recv_cq as *mut rdmaxcel_sys::ibv_cq + }; + let context = self.context as *mut rdmaxcel_sys::ibv_context; + unsafe { + let ops = &mut (*context).ops; + let mut wc = std::mem::MaybeUninit::::zeroed().assume_init(); + let ret = ops.poll_cq.as_mut().unwrap()(cq, 1, &mut wc); + + if ret < 0 { + return Err(anyhow::anyhow!( + "Failed to poll CQ (target={:?}): {}", + target, + Error::last_os_error() + )); + } + + if ret > 0 { + if !wc.is_valid() { + if let Some((status, vendor_err)) = wc.error() { + return Err(anyhow::anyhow!( + "work completion (target={:?}) failed with status: {:?}, vendor error: {}, wr_id: {}", + target, + status, + vendor_err, + wc.wr_id(), + )); + } + } + return Ok(Some(IbvWc::from(wc))); + } + } + + // No completion found + Ok(None) + } + pub fn poll_send_completion(&mut self) -> Result, anyhow::Error> { self.poll_completion_target(PollTarget::Send) } @@ -1326,7 +1338,6 @@ impl RdmaQueuePair { self.poll_completion_target(PollTarget::Recv) } } - /// Utility to validate execution context. /// /// Remote Execution environments do not always have access to the nvidia_peermem module diff --git a/monarch_rdma/src/rdma_manager_actor.rs b/monarch_rdma/src/rdma_manager_actor.rs index 8e7638238..2736e2999 100644 --- a/monarch_rdma/src/rdma_manager_actor.rs +++ b/monarch_rdma/src/rdma_manager_actor.rs @@ -28,8 +28,12 @@ //! //! See test examples: `test_rdma_write_loopback` and `test_rdma_read_loopback`. use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; use async_trait::async_trait; +use dashmap::DashMap; use hyperactor::Actor; use hyperactor::ActorId; use hyperactor::ActorRef; @@ -40,26 +44,110 @@ use hyperactor::Instance; use hyperactor::Named; use hyperactor::OncePortRef; use hyperactor::RefClient; +use hyperactor::clock::Clock; +use hyperactor::clock::RealClock; use hyperactor::supervision::ActorSupervisionEvent; use serde::Deserialize; use serde::Serialize; +use tokio::sync::Notify; +use tokio::sync::Semaphore; +use tokio::sync::oneshot; +use crate::ibverbs_primitives::IbvWc; use crate::ibverbs_primitives::IbverbsConfig; use crate::ibverbs_primitives::RdmaMemoryRegionView; +use crate::ibverbs_primitives::RdmaOperation; use crate::ibverbs_primitives::RdmaQpInfo; use crate::ibverbs_primitives::ibverbs_supported; use crate::ibverbs_primitives::resolve_qp_type; +use crate::rdma_components::MAX_RDMA_MSG_SIZE; +use crate::rdma_components::PollTarget; use crate::rdma_components::RdmaBuffer; use crate::rdma_components::RdmaDomain; use crate::rdma_components::RdmaQueuePair; use crate::rdma_components::get_registered_cuda_segments; use crate::validate_execution_context; -/// Represents the state of a queue pair in the manager, either available or checked out. +#[derive(Debug)] +struct CompletionTracker { + // wr_id -> channel to send completion result. + reply_channels: Arc>>>, + polling_task: Option>, +} + +impl CompletionTracker { + pub fn new() -> Self { + Self { + reply_channels: Arc::new(DashMap::new()), + polling_task: None, + } + } + + /// Submit a tracking request for a work request ID. Returns a channel to receive the result. + pub fn submit(&self, wr_id: u64) -> oneshot::Receiver> { + let (tx, rx) = oneshot::channel(); + self.reply_channels.insert(wr_id, tx); + rx + } + + pub fn start_polling(&mut self, qp: RdmaQueuePair) { + if self.polling_task.is_none() { + let reply_channels = self.reply_channels.clone(); + let polling_task = tokio::spawn(async move { + let mut interval = tokio::time::interval(tokio::time::Duration::from_micros(500)); + loop { + match qp.poll_once_stateless(PollTarget::Send) { + Ok(Some(wc)) => { + if let Some((wr_id, tx)) = reply_channels.remove(&wc.wr_id()) { + tx.send(Ok(wc)).unwrap_or_else(|_| { + tracing::info!( + "Failed to send completion result for wr_id {}, receiver might be cancelled", + wr_id + ); + }); + } else { + tracing::warn!( + "No completion result channel found for wr_id {}, this is likely a bug", + wc.wr_id() + ); + } + } + Ok(None) => { + // No completion available yet + interval.tick().await; + } + Err(e) => { + tracing::error!("Polling error: {}", e); + // FIXME(yuxuanh): broadcast error to all pending requests + break; + } + } + } + }); + self.polling_task = Some(polling_task); + } else { + panic!("Polling task already started"); + } + } +} + +/// Wrapper for a queue pair with a semaphore for access control and a polling task. +#[derive(Debug, Clone)] +pub struct QueuePairEntry { + pub qp: RdmaQueuePair, + pub semaphore: Arc, + completion_tracker: Arc, +} + +/// Represents the state of a queue pair in the manager. #[derive(Debug, Clone)] pub enum QueuePairState { - Available(RdmaQueuePair), - CheckedOut, + /// Connection establishment in progress. Waiters will be notified when ready or on error. + Connecting(Arc), + /// Queue pair is ready and available for use. + Ready(QueuePairEntry), + /// Connection failed. Error is persisted for all current and future requesters. + ConnectionError(Arc), } /// Helper function to get detailed error messages from RDMAXCEL error codes @@ -91,6 +179,7 @@ pub enum RdmaManagerMessage { other: ActorRef, self_device: String, other_device: String, + start_polling: bool, #[reply] /// `reply` - Reply channel to return the queue pair for communication reply: OncePortRef, @@ -125,8 +214,20 @@ pub enum RdmaManagerMessage { other: ActorRef, self_device: String, other_device: String, - /// `qp` - The queue pair to return (ownership transferred back) - qp: RdmaQueuePair, + }, + ReadInto { + local: RdmaBuffer, + remote: RdmaBuffer, + timeout: tokio::time::Duration, + #[reply] + reply: OncePortRef<()>, + }, + WriteFrom { + local: RdmaBuffer, + remote: RdmaBuffer, + timeout: tokio::time::Duration, + #[reply] + reply: OncePortRef<()>, }, } @@ -162,6 +263,8 @@ pub struct RdmaManagerActor { // Map of PCI addresses to their optimal RDMA devices // This is populated during actor initialization using the device selection algorithm pci_to_device: HashMap, + + next_wr_id: AtomicU64, } impl Drop for RdmaManagerActor { @@ -209,12 +312,20 @@ impl Drop for RdmaManagerActor { for (device_name, device_map) in self.device_qps.drain() { for ((actor_id, remote_device), qp_state) in device_map { match qp_state { - QueuePairState::Available(qp) => { - destroy_queue_pair(&qp, &format!("actor {:?}", actor_id)); + QueuePairState::Ready(entry) => { + destroy_queue_pair(&entry.qp, &format!("actor {:?}", actor_id)); } - QueuePairState::CheckedOut => { + QueuePairState::Connecting(_) => { tracing::warn!( - "QP for actor {:?} (device {} -> {}) was checked out during cleanup", + "QP for actor {:?} (device {} -> {}) was still connecting during cleanup", + actor_id, + device_name, + remote_device + ); + } + QueuePairState::ConnectionError(_) => { + tracing::warn!( + "QP for actor {:?} (device {} -> {}) had connection error during cleanup", actor_id, device_name, remote_device @@ -515,6 +626,280 @@ impl RdmaManagerActor { } Ok(()) } + + /// Establishes a connection between this actor and another remote actor. + /// Handles both loopback (same actor, same device) and remote connections. + async fn establish_connection( + &mut self, + cx: &Context<'_, Self>, + other: ActorRef, + self_device: String, + other_device: String, + ) -> Result { + let is_loopback = other.actor_id() == cx.bind::().actor_id() + && self_device == other_device; + + if is_loopback { + // Loopback connection setup + self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) + .await?; + let endpoint = self + .connection_info(cx, other.clone(), other_device.clone(), self_device.clone()) + .await?; + self.connect( + cx, + other.clone(), + self_device.clone(), + other_device.clone(), + endpoint, + ) + .await?; + } else { + // Remote connection setup + self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) + .await?; + other + .initialize_qp( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + ) + .await?; + let other_endpoint: RdmaQpInfo = other + .connection_info( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + ) + .await?; + self.connect( + cx, + other.clone(), + self_device.clone(), + other_device.clone(), + other_endpoint, + ) + .await?; + let local_endpoint = self + .connection_info(cx, other.clone(), self_device.clone(), other_device.clone()) + .await?; + other + .connect( + cx, + cx.bind().clone(), + other_device.clone(), + self_device.clone(), + local_endpoint, + ) + .await?; + } + + // Hardware init delay. apply_first_op_delay no longer works for mysterious reasons. + // FIXME(yuxuanh): refactor & make this behave like apply_first_op_delay + RealClock.sleep(tokio::time::Duration::from_millis(2)).await; + + // Retrieve the connected queue pair + let inner_key = (other.actor_id().clone(), other_device.clone()); + if let Some(device_map) = self.device_qps.get(&self_device) { + if let Some(qp_state) = device_map.get(&inner_key) { + match qp_state { + QueuePairState::Ready(entry) => Ok(entry.qp.clone()), + QueuePairState::Connecting(_) => Err(anyhow::anyhow!( + "Unexpected Connecting state after connection establishment" + )), + QueuePairState::ConnectionError(err) => { + Err(anyhow::anyhow!("Connection failed: {}", err)) + } + } + } else { + Err(anyhow::anyhow!( + "Failed to find connection for actor {} on device {}", + other.actor_id(), + other_device + )) + } + } else { + Err(anyhow::anyhow!( + "Failed to find device map for device {} after connection", + self_device + )) + } + } + + fn next_wr_id(&self) -> u64 { + self.next_wr_id.fetch_add(1, Ordering::Relaxed) + } + + fn post_op_chunked( + &self, + qp_entry: &QueuePairEntry, + lhandle: RdmaBuffer, + rhandle: RdmaBuffer, + op_type: RdmaOperation, + ) -> Result>>, anyhow::Error> { + let total_size = lhandle.size; + match op_type { + RdmaOperation::Write => { + if rhandle.size < total_size { + return Err(anyhow::anyhow!( + "Remote buffer size ({}) is smaller than local buffer size ({})", + rhandle.size, + total_size + )); + } + } + RdmaOperation::Read => { + if rhandle.size > total_size { + return Err(anyhow::anyhow!( + "Remote buffer size ({}) is larger than local buffer size ({}).", + rhandle.size, + total_size + )); + } + } + _ => { + unimplemented!("Unsupported operation type: {:?}", op_type) + } + } + + let mut qp = qp_entry.qp.clone(); + let tracker = &qp_entry.completion_tracker; + + let mut remaining = total_size; + let mut offset = 0; + let mut rxs = Vec::new(); + while remaining > 0 { + let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE); + let wr_id = self.next_wr_id(); + rxs.push(tracker.submit(wr_id)); + qp.post_op( + lhandle.addr + offset, + lhandle.lkey, + chunk_size, + wr_id, + true, + op_type, + rhandle.addr + offset, + rhandle.rkey, + )?; + remaining -= chunk_size; + offset += chunk_size; + } + Ok(rxs) + } + + async fn request_queue_pair_entry( + &mut self, + cx: &Context<'_, RdmaManagerActor>, + other: ActorRef, + self_device: String, + other_device: String, + start_polling: bool, + ) -> Result { + let inner_key = (other.actor_id().clone(), other_device.clone()); + // Phase 1: Get or create the QueuePairEntry + let entry = loop { + let qp_state = self + .device_qps + .get(&self_device) + .and_then(|map| map.get(&inner_key)) + .cloned(); + + match qp_state { + Some(QueuePairState::Ready(entry)) => { + // Queue pair is ready + break entry; + } + Some(QueuePairState::ConnectionError(err)) => { + // Connection previously failed, propagate error + return Err(anyhow::anyhow!("Connection previously failed: {}", err)); + } + Some(QueuePairState::Connecting(ref notify)) => { + // Another task is connecting, wait for notification + let notify = notify.clone(); + drop(qp_state); // Release borrows before awaiting + + notify.notified().await; + // Loop back to re-check state (could be Ready or ConnectionError now) + continue; + } + None => { + // No connection exists, we need to establish it + let notify = Arc::new(Notify::new()); + + // Insert Connecting state + self.device_qps + .entry(self_device.clone()) + .or_insert_with(HashMap::new) + .insert( + inner_key.clone(), + QueuePairState::Connecting(notify.clone()), + ); + + // Establish the connection + let result = self + .establish_connection( + cx, + other.clone(), + self_device.clone(), + other_device.clone(), + ) + .await; + + match result { + Ok(qp) => { + let mut completion_tracker = CompletionTracker::new(); + if start_polling { + completion_tracker.start_polling(qp.clone()); + } + let entry = QueuePairEntry { + qp, + semaphore: Arc::new(Semaphore::new(1)), + completion_tracker: completion_tracker.into(), + }; + + // Update state to Ready + self.device_qps + .get_mut(&self_device) + .unwrap() + .insert(inner_key.clone(), QueuePairState::Ready(entry.clone())); + + // Notify all waiters + notify.notify_waiters(); + break entry; + } + Err(e) => { + let arc_err = Arc::new(e); + + // Insert ConnectionError state for all current and future requesters + self.device_qps.get_mut(&self_device).unwrap().insert( + inner_key.clone(), + QueuePairState::ConnectionError(arc_err.clone()), + ); + + // Notify all waiters to fail + notify.notify_waiters(); + return Err(anyhow::anyhow!("Connection failed: {}", arc_err)); + } + } + } + } + }; + + // Phase 2: Acquire semaphore permit (fair FIFO waiting) + let permit = entry + .semaphore + .acquire() + .await + .map_err(|e| anyhow::anyhow!("Failed to acquire semaphore: {}", e))?; + + // Forget the permit so it doesn't auto-release on drop + permit.forget(); + + Ok(entry) + } } #[async_trait] @@ -568,6 +953,7 @@ impl Actor for RdmaManagerActor { mr_map: HashMap::new(), mrv_id: 0, pci_to_device, + next_wr_id: AtomicU64::new(0), }) } @@ -649,12 +1035,17 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { /// Requests a queue pair for communication with a remote RDMA manager actor. /// - /// Basic logic: if queue pair exists in map, return it; if None, create connection first. + /// This method uses a fair semaphore-based approach that allows multiple concurrent + /// requesters to wait for queue pair availability without failing. /// /// # Arguments /// /// * `cx` - The context of the actor requesting the queue pair. - /// * `remote` - The ActorRef of the remote RDMA manager actor to communicate with. + /// * `other` - The ActorRef of the remote RDMA manager actor to communicate with. + /// * `self_device` - The local device name. + /// * `other_device` - The remote device name. + /// * `start_polling` - Whether to start polling for completions immediately after the connection is established. + /// If the queue pair is already created before, this argument has no effect. /// /// # Returns /// @@ -664,120 +1055,16 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { &mut self, cx: &Context, other: ActorRef, - self_device: String, other_device: String, + start_polling: bool, ) -> Result { - let other_id = other.actor_id().clone(); - - // Use the nested map structure: local_device -> (actor_id, remote_device) -> QueuePairState - let inner_key = (other_id.clone(), other_device.clone()); - - // Check if queue pair exists in map - if let Some(device_map) = self.device_qps.get(&self_device) { - if let Some(qp_state) = device_map.get(&inner_key).cloned() { - match qp_state { - QueuePairState::Available(qp) => { - // Queue pair exists and is available - return it - self.device_qps - .get_mut(&self_device) - .unwrap() - .insert(inner_key, QueuePairState::CheckedOut); - return Ok(qp); - } - QueuePairState::CheckedOut => { - return Err(anyhow::anyhow!( - "queue pair for actor {} on device {} is already checked out", - other_id, - other_device - )); - } - } - } - } - - // Queue pair doesn't exist - need to create connection - let is_loopback = other_id == cx.bind::().actor_id().clone() - && self_device == other_device; - - if is_loopback { - // Loopback connection setup - self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) - .await?; - let endpoint = self - .connection_info(cx, other.clone(), other_device.clone(), self_device.clone()) - .await?; - self.connect( - cx, - other.clone(), - self_device.clone(), - other_device.clone(), - endpoint, - ) - .await?; - } else { - // Remote connection setup - self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone()) - .await?; - other - .initialize_qp( - cx, - cx.bind().clone(), - other_device.clone(), - self_device.clone(), - ) - .await?; - let other_endpoint: RdmaQpInfo = other - .connection_info( - cx, - cx.bind().clone(), - other_device.clone(), - self_device.clone(), - ) - .await?; - self.connect( - cx, - other.clone(), - self_device.clone(), - other_device.clone(), - other_endpoint, - ) - .await?; - let local_endpoint = self - .connection_info(cx, other.clone(), self_device.clone(), other_device.clone()) - .await?; - other - .connect( - cx, - cx.bind().clone(), - other_device.clone(), - self_device.clone(), - local_endpoint, - ) - .await?; - } - - // Now that connection is established, get the queue pair - if let Some(device_map) = self.device_qps.get(&self_device) { - if let Some(QueuePairState::Available(qp)) = device_map.get(&inner_key).cloned() { - self.device_qps - .get_mut(&self_device) - .unwrap() - .insert(inner_key, QueuePairState::CheckedOut); - Ok(qp) - } else { - Err(anyhow::anyhow!( - "Failed to create connection for actor {} on device {}", - other_id, - other_device - )) - } - } else { - Err(anyhow::anyhow!( - "Failed to create connection for actor {} on device {} - no device map", - other_id, - other_device - )) + match self + .request_queue_pair_entry(cx, other, self_device, other_device, start_polling) + .await + { + Ok(entry) => Ok(entry.qp), + Err(e) => Err(e), } } @@ -791,9 +1078,9 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { let other_id = other.actor_id().clone(); let inner_key = (other_id.clone(), other_device.clone()); - // Check if QP already exists in nested structure + // Check if QP already exists and is Ready if let Some(device_map) = self.device_qps.get(&self_device) { - if device_map.contains_key(&inner_key) { + if let Some(QueuePairState::Ready(_)) = device_map.get(&inner_key) { return Ok(true); } } @@ -826,11 +1113,18 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { let qp = RdmaQueuePair::new(domain_context, domain_pd, self.config.clone()) .map_err(|e| anyhow::anyhow!("could not create RdmaQueuePair: {}", e))?; + // Wrap in QueuePairEntry with semaphore + let entry = QueuePairEntry { + qp, + semaphore: Arc::new(Semaphore::new(1)), + completion_tracker: CompletionTracker::new().into(), + }; + // Insert the QP into the nested map structure self.device_qps .entry(self_device.clone()) .or_insert_with(HashMap::new) - .insert(inner_key, QueuePairState::Available(qp)); + .insert(inner_key, QueuePairState::Ready(entry)); tracing::debug!( "successfully created a connection with {:?} for local device {} -> remote device {}", @@ -857,22 +1151,28 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { ) -> Result<(), anyhow::Error> { tracing::debug!("connecting with {:?}", other); let other_id = other.actor_id().clone(); - - // For backward compatibility, use default device let inner_key = (other_id.clone(), other_device.clone()); if let Some(device_map) = self.device_qps.get_mut(&self_device) { match device_map.get_mut(&inner_key) { - Some(QueuePairState::Available(qp)) => { + Some(QueuePairState::Ready(entry)) => { + // Access the QP from the entry and connect + // Note: We need to mutate the QP, but entry is behind Arc/Clone semantics + // So we get a mutable reference to the QP directly from the map + let qp = &mut entry.qp; qp.connect(&endpoint).map_err(|e| { anyhow::anyhow!("could not connect to RDMA endpoint: {}", e) })?; Ok(()) } - Some(QueuePairState::CheckedOut) => Err(anyhow::anyhow!( - "Cannot connect: queue pair for actor {} is checked out", + Some(QueuePairState::Connecting(_)) => Err(anyhow::anyhow!( + "Cannot connect: queue pair for actor {} is still being initialized", other_id )), + Some(QueuePairState::ConnectionError(err)) => Err(anyhow::anyhow!( + "Cannot connect: connection failed: {}", + err + )), None => Err(anyhow::anyhow!( "No connection found for actor {}", other_id @@ -902,19 +1202,22 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { ) -> Result { tracing::debug!("getting connection info with {:?}", other); let other_id = other.actor_id().clone(); - let inner_key = (other_id.clone(), other_device.clone()); if let Some(device_map) = self.device_qps.get_mut(&self_device) { match device_map.get_mut(&inner_key) { - Some(QueuePairState::Available(qp)) => { - let connection_info = qp.get_qp_info()?; + Some(QueuePairState::Ready(entry)) => { + let connection_info = entry.qp.get_qp_info()?; Ok(connection_info) } - Some(QueuePairState::CheckedOut) => Err(anyhow::anyhow!( - "Cannot get connection info: queue pair for actor {} is checked out", + Some(QueuePairState::Connecting(_)) => Err(anyhow::anyhow!( + "Cannot get connection info: queue pair for actor {} is still being initialized", other_id )), + Some(QueuePairState::ConnectionError(err)) => Err(anyhow::anyhow!( + "Cannot get connection info: connection failed: {}", + err + )), None => Err(anyhow::anyhow!( "No connection found for actor {}", other_id @@ -928,49 +1231,120 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { } } - /// Releases a queue pair back to the HashMap + /// Releases a queue pair back to the pool. /// - /// This method returns a queue pair to the HashMap after the caller has finished - /// using it. This completes the request/release cycle, similar to RdmaBuffer. + /// This method releases a semaphore permit, allowing the next waiting requester + /// to acquire the queue pair. This completes the request/release cycle. /// /// # Arguments - /// * `remote` - The ActorRef of the remote actor to return the queue pair for - /// * `qp` - The queue pair to release + /// * `other` - The ActorRef of the remote actor + /// * `self_device` - The local device name + /// * `other_device` - The remote device name + /// * `qp` - The queue pair to release (unused but kept for API compatibility) async fn release_queue_pair( &mut self, _cx: &Context, other: ActorRef, self_device: String, other_device: String, - qp: RdmaQueuePair, ) -> Result<(), anyhow::Error> { let inner_key = (other.actor_id().clone(), other_device.clone()); - match self + // Get the entry from the map + let entry = self .device_qps - .get_mut(&self_device) - .unwrap() - .get_mut(&inner_key) - { - Some(QueuePairState::CheckedOut) => { - self.device_qps - .get_mut(&self_device) - .unwrap() - .insert(inner_key, QueuePairState::Available(qp)); + .get(&self_device) + .and_then(|map| map.get(&inner_key)) + .ok_or_else(|| { + anyhow::anyhow!( + "No queue pair found for actor {}, between devices {} and {}", + other.actor_id(), + self_device, + other_device, + ) + })?; + + match entry { + QueuePairState::Ready(entry) => { + // Release the semaphore permit, allowing next waiter to acquire + entry.semaphore.add_permits(1); Ok(()) } - Some(QueuePairState::Available(_)) => Err(anyhow::anyhow!( - "Cannot release queue pair: queue pair for actor {} is already available between devices {} and {}", + QueuePairState::Connecting(_) => Err(anyhow::anyhow!( + "Cannot release queue pair: connection still in progress for actor {} between devices {} and {}", other.actor_id(), self_device, other_device, )), - None => Err(anyhow::anyhow!( - "No queue pair found for actor {}, between devices {} and {}", - other.actor_id(), - self_device, - other_device, + QueuePairState::ConnectionError(err) => Err(anyhow::anyhow!( + "Cannot release queue pair: connection failed: {}", + err )), } } + + async fn read_into( + &mut self, + cx: &Context, + local: RdmaBuffer, + remote: RdmaBuffer, + timeout: tokio::time::Duration, + ) -> Result<(), anyhow::Error> { + let remote_owner = remote.owner.clone(); + + let local_device = local.device_name.clone(); + let remote_device = remote.device_name.clone(); + let qp_entry = self + .request_queue_pair_entry( + cx, + remote_owner.clone(), + local_device.clone(), + remote_device.clone(), + true, + ) + .await?; + let rxs = self.post_op_chunked(&qp_entry, local, remote, RdmaOperation::Write); + + // Release the queue pair back to the actor + self.release_queue_pair(cx, remote_owner, local_device, remote_device) + .await?; + + for rx in rxs? { + rx.await??; + } + + Ok(()) + } + async fn write_from( + &mut self, + cx: &Context, + local: RdmaBuffer, + remote: RdmaBuffer, + timeout: tokio::time::Duration, + ) -> Result<(), anyhow::Error> { + let remote_owner = remote.owner.clone(); + + let local_device = local.device_name.clone(); + let remote_device = remote.device_name.clone(); + let mut qp_entry = self + .request_queue_pair_entry( + cx, + remote_owner.clone(), + local_device.clone(), + remote_device.clone(), + true, + ) + .await?; + let rxs = self.post_op_chunked(&qp_entry, local, remote, RdmaOperation::Read); + + // Release the queue pair back to the actor + self.release_queue_pair(cx, remote_owner, local_device, remote_device) + .await?; + + for rx in rxs? { + rx.await?; + } + + Ok(()) + } } diff --git a/monarch_rdma/src/rdma_manager_actor_tests.rs b/monarch_rdma/src/rdma_manager_actor_tests.rs index 8c19719da..010fe3d64 100644 --- a/monarch_rdma/src/rdma_manager_actor_tests.rs +++ b/monarch_rdma/src/rdma_manager_actor_tests.rs @@ -39,6 +39,7 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; @@ -52,11 +53,10 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), - qp_1, ) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; Ok(()) } @@ -77,6 +77,7 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; @@ -89,11 +90,10 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), - qp_1, ) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; Ok(()) } @@ -116,6 +116,7 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; qp_1.get(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; @@ -123,7 +124,7 @@ mod tests { // Poll for completion wait_for_completion(&mut qp_1, PollTarget::Send, 2).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; Ok(()) } @@ -146,12 +147,13 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; wait_for_completion(&mut qp_1, PollTarget::Send, 2).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -174,6 +176,7 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; let mut qp_2 = env @@ -183,12 +186,13 @@ mod tests { env.actor_1.clone(), env.rdma_handle_2.device_name.clone(), env.rdma_handle_1.device_name.clone(), + false, ) .await?; qp_2.put_with_recv(env.rdma_handle_2.clone(), env.rdma_handle_1.clone())?; qp_1.recv(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; wait_for_completion(&mut qp_2, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -212,6 +216,7 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; qp_1.enqueue_put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; @@ -219,7 +224,7 @@ mod tests { // Poll for completion wait_for_completion(&mut qp_1, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; Ok(()) } @@ -242,6 +247,7 @@ mod tests { env.actor_1.clone(), env.rdma_handle_2.device_name.clone(), env.rdma_handle_1.device_name.clone(), + false, ) .await?; qp_2.enqueue_get(env.rdma_handle_2.clone(), env.rdma_handle_1.clone())?; @@ -249,7 +255,7 @@ mod tests { // Poll for completion wait_for_completion(&mut qp_2, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; Ok(()) } @@ -270,7 +276,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 2) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -292,7 +298,7 @@ mod tests { .write_from(env.client_1, env.rdma_handle_2.clone(), 2) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; Ok(()) } @@ -335,6 +341,7 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; qp_1.enqueue_put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; @@ -342,7 +349,7 @@ mod tests { // Poll for completion wait_for_completion_gpu(&mut qp_1, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; Ok(()) } @@ -374,6 +381,7 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; qp_1.enqueue_get(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; @@ -381,7 +389,7 @@ mod tests { // Poll for completion wait_for_completion_gpu(&mut qp_1, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; Ok(()) } @@ -412,6 +420,7 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; let mut qp_2 = env @@ -421,6 +430,7 @@ mod tests { env.actor_1.clone(), env.rdma_handle_2.device_name.clone(), env.rdma_handle_1.device_name.clone(), + false, ) .await?; recv_wqe_gpu( @@ -440,7 +450,7 @@ mod tests { ring_db_gpu(&mut qp_2).await?; wait_for_completion_gpu(&mut qp_1, PollTarget::Recv, 10).await?; wait_for_completion_gpu(&mut qp_2, PollTarget::Send, 10).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -469,13 +479,14 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; wait_for_completion(&mut qp_1, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -504,13 +515,14 @@ mod tests { env.actor_2.clone(), env.rdma_handle_1.device_name.clone(), env.rdma_handle_2.device_name.clone(), + false, ) .await?; qp_1.put(env.rdma_handle_1.clone(), env.rdma_handle_2.clone())?; wait_for_completion(&mut qp_1, PollTarget::Send, 5).await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -538,7 +550,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -564,7 +576,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -590,7 +602,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -616,7 +628,7 @@ mod tests { .write_from(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -644,7 +656,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 2) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -672,7 +684,7 @@ mod tests { .write_from(env.client_1, env.rdma_handle_2.clone(), 2) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -704,7 +716,7 @@ mod tests { .read_into(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; env.cleanup().await?; Ok(()) } @@ -736,7 +748,70 @@ mod tests { .write_from(env.client_1, env.rdma_handle_2.clone(), 5) .await?; - env.verify_buffers(BSIZE).await?; + env.verify_buffers().await?; + env.cleanup().await?; + Ok(()) + } + + #[timed_test::async_timed_test(timeout_secs = 60)] + async fn test_concurrent_read_into() -> Result<(), anyhow::Error> { + const BSIZE: usize = 32; + let devices = get_all_devices(); + if devices.len() < 5 { + println!( + "skipping this test as it is only configured on H100 nodes with backend network" + ); + return Ok(()); + } + + let mut env = RdmaManagerTestEnv::setup(BSIZE, "cpu:0", "cpu:1").await?; + + // Create three distinct buffer pairs that share the same actors + let (h1_a, h2_a) = env.create_buffer_pair(BSIZE).await?; + let (h1_b, h2_b) = env.create_buffer_pair(BSIZE).await?; + let (h1_c, h2_c) = env.create_buffer_pair(BSIZE).await?; + + // Launch 3 concurrent read_into operations, each with its own buffer pair + // All operations share the same queue pair connection and should wait fairly + let task1 = h1_a.read_into(env.client_1, h2_a.clone(), 2); + let task2 = h1_b.read_into(env.client_1, h2_b.clone(), 2); + let task3 = h1_c.read_into(env.client_1, h2_c.clone(), 2); + + tokio::try_join!(task1, task2, task3)?; + env.verify_all_buffer_pairs().await?; + + env.cleanup().await?; + Ok(()) + } + + #[timed_test::async_timed_test(timeout_secs = 60)] + async fn test_concurrent_write_from() -> Result<(), anyhow::Error> { + const BSIZE: usize = 32; + let devices = get_all_devices(); + if devices.len() < 5 { + println!( + "skipping this test as it is only configured on H100 nodes with backend network" + ); + return Ok(()); + } + + let mut env = RdmaManagerTestEnv::setup(BSIZE, "cpu:0", "cpu:1").await?; + + // Create three distinct buffer pairs that share the same actors + let (h1_a, h2_a) = env.create_buffer_pair(BSIZE).await?; + let (h1_b, h2_b) = env.create_buffer_pair(BSIZE).await?; + let (h1_c, h2_c) = env.create_buffer_pair(BSIZE).await?; + + // Launch 3 concurrent write_from operations, each with its own buffer pair + // All operations share the same queue pair connection and should wait fairly + let task1 = h1_a.write_from(env.client_1, h2_a.clone(), 2); + let task2 = h1_b.write_from(env.client_1, h2_b.clone(), 2); + let task3 = h1_c.write_from(env.client_1, h2_c.clone(), 2); + + tokio::try_join!(task1, task2, task3)?; + + env.verify_all_buffer_pairs().await?; + env.cleanup().await?; Ok(()) } diff --git a/monarch_rdma/src/test_utils.rs b/monarch_rdma/src/test_utils.rs index d63a0902c..19397bfd1 100644 --- a/monarch_rdma/src/test_utils.rs +++ b/monarch_rdma/src/test_utils.rs @@ -262,8 +262,8 @@ pub mod test_utils { } pub struct RdmaManagerTestEnv<'a> { - buffer_1: Buffer, - buffer_2: Buffer, + pub buffer_1: Buffer, + pub buffer_2: Buffer, pub client_1: &'a Instance<()>, pub client_2: &'a Instance<()>, pub actor_1: ActorRef, @@ -272,6 +272,11 @@ pub mod test_utils { pub rdma_handle_2: RdmaBuffer, cuda_context_1: Option, cuda_context_2: Option, + accel1: String, + accel2: String, + parsed_accel1: (String, usize), + parsed_accel2: (String, usize), + buffer_pairs: Vec<(Buffer, Buffer)>, } #[derive(Debug, Clone)] @@ -500,9 +505,241 @@ pub mod test_utils { rdma_handle_2, cuda_context_1: cuda_contexts.first().cloned().flatten(), cuda_context_2: cuda_contexts.get(1).cloned().flatten(), + accel1: accel1.to_string(), + accel2: accel2.to_string(), + parsed_accel1, + parsed_accel2, + buffer_pairs: Vec::new(), }) } + /// # Arguments + /// * `buffer_size` - The size of the buffers to create + /// + /// # Returns + /// A tuple of (RdmaBuffer, RdmaBuffer) representing handles to the two newly created buffers + pub async fn create_buffer_pair( + &mut self, + buffer_size: usize, + ) -> Result<(RdmaBuffer, RdmaBuffer), anyhow::Error> { + // Allocate buffer 1 + let buf_1 = if self.parsed_accel1.0 == "cpu" { + let mut buffer = vec![0u8; buffer_size].into_boxed_slice(); + let ptr = buffer.as_mut_ptr() as u64; + let len = buffer.len(); + Box::leak(buffer); // Leak the buffer so it lives for the test duration + Buffer { + ptr, + len, + cpu_ref: None, // Already leaked, no need to keep reference + } + } else if self.parsed_accel1.0 == "cuda" { + // CUDA case for buffer 1 + unsafe { + let mut dptr: cuda_sys::CUdeviceptr = std::mem::zeroed(); + let mut handle: cuda_sys::CUmemGenericAllocationHandle = std::mem::zeroed(); + let mut device: cuda_sys::CUdevice = std::mem::zeroed(); + + cu_check!(cuda_sys::cuDeviceGet( + &mut device, + self.parsed_accel1.1 as i32 + )); + cu_check!(cuda_sys::cuCtxSetCurrent( + self.cuda_context_1 + .expect("No CUDA context found for accel1") + )); + + let mut granularity: usize = 0; + let mut prop: cuda_sys::CUmemAllocationProp = std::mem::zeroed(); + prop.type_ = cuda_sys::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type_ = cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.gpuDirectRDMACapable = 1; + prop.requestedHandleTypes = + cuda_sys::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + + cu_check!(cuda_sys::cuMemGetAllocationGranularity( + &mut granularity as *mut usize, + &prop, + cuda_sys::CUmemAllocationGranularity_flags::CU_MEM_ALLOC_GRANULARITY_MINIMUM, + )); + + let padded_size: usize = ((buffer_size - 1) / granularity + 1) * granularity; + assert!(padded_size == buffer_size); + + cu_check!(cuda_sys::cuMemCreate( + &mut handle as *mut cuda_sys::CUmemGenericAllocationHandle, + padded_size, + &prop, + 0 + )); + + cu_check!(cuda_sys::cuMemAddressReserve( + &mut dptr as *mut cuda_sys::CUdeviceptr, + padded_size, + 0, + 0, + 0, + )); + + let err = cuda_sys::cuMemMap( + dptr as cuda_sys::CUdeviceptr, + padded_size, + 0, + handle as cuda_sys::CUmemGenericAllocationHandle, + 0, + ); + if err != cuda_sys::CUresult::CUDA_SUCCESS { + panic!("failed reserving and mapping memory {:?}", err); + } + + let mut access_desc: cuda_sys::CUmemAccessDesc = std::mem::zeroed(); + access_desc.location.type_ = + cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE; + access_desc.location.id = device; + access_desc.flags = + cuda_sys::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + cu_check!(cuda_sys::cuMemSetAccess(dptr, padded_size, &access_desc, 1)); + + Buffer { + ptr: dptr, + len: padded_size, + cpu_ref: None, + } + } + } else { + panic!("Unsupported accelerator type: {}", self.parsed_accel1.0); + }; + + // Allocate buffer 2 + let buf_2 = if self.parsed_accel2.0 == "cpu" { + let mut buffer = vec![0u8; buffer_size].into_boxed_slice(); + let ptr = buffer.as_mut_ptr() as u64; + let len = buffer.len(); + Box::leak(buffer); + Buffer { + ptr, + len, + cpu_ref: None, + } + } else if self.parsed_accel2.0 == "cuda" { + // CUDA case for buffer 2 + unsafe { + let mut dptr: cuda_sys::CUdeviceptr = std::mem::zeroed(); + let mut handle: cuda_sys::CUmemGenericAllocationHandle = std::mem::zeroed(); + let mut device: cuda_sys::CUdevice = std::mem::zeroed(); + + cu_check!(cuda_sys::cuDeviceGet( + &mut device, + self.parsed_accel2.1 as i32 + )); + cu_check!(cuda_sys::cuCtxSetCurrent( + self.cuda_context_2 + .expect("No CUDA context found for accel2") + )); + + let mut granularity: usize = 0; + let mut prop: cuda_sys::CUmemAllocationProp = std::mem::zeroed(); + prop.type_ = cuda_sys::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type_ = cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.gpuDirectRDMACapable = 1; + prop.requestedHandleTypes = + cuda_sys::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + + cu_check!(cuda_sys::cuMemGetAllocationGranularity( + &mut granularity as *mut usize, + &prop, + cuda_sys::CUmemAllocationGranularity_flags::CU_MEM_ALLOC_GRANULARITY_MINIMUM, + )); + + let padded_size: usize = ((buffer_size - 1) / granularity + 1) * granularity; + assert!(padded_size == buffer_size); + + cu_check!(cuda_sys::cuMemCreate( + &mut handle as *mut cuda_sys::CUmemGenericAllocationHandle, + padded_size, + &prop, + 0 + )); + + cu_check!(cuda_sys::cuMemAddressReserve( + &mut dptr as *mut cuda_sys::CUdeviceptr, + padded_size, + 0, + 0, + 0, + )); + + let err = cuda_sys::cuMemMap( + dptr as cuda_sys::CUdeviceptr, + padded_size, + 0, + handle as cuda_sys::CUmemGenericAllocationHandle, + 0, + ); + if err != cuda_sys::CUresult::CUDA_SUCCESS { + panic!("failed reserving and mapping memory {:?}", err); + } + + let mut access_desc: cuda_sys::CUmemAccessDesc = std::mem::zeroed(); + access_desc.location.type_ = + cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE; + access_desc.location.id = device; + access_desc.flags = + cuda_sys::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + cu_check!(cuda_sys::cuMemSetAccess(dptr, padded_size, &access_desc, 1)); + Buffer { + ptr: dptr, + len: padded_size, + cpu_ref: None, + } + } + } else { + panic!("Unsupported accelerator type: {}", self.parsed_accel2.0); + }; + + // Fill buffer1 with test data + if self.parsed_accel1.0 == "cuda" { + let mut temp_buffer = vec![0u8; buffer_size].into_boxed_slice(); + for (i, val) in temp_buffer.iter_mut().enumerate() { + *val = ((i + 42) % 256) as u8; + } + unsafe { + cu_check!(cuda_sys::cuCtxSetCurrent( + self.cuda_context_1.expect("No CUDA context found") + )); + cu_check!(cuda_sys::cuMemcpyHtoD_v2( + buf_1.ptr, + temp_buffer.as_ptr() as *const std::ffi::c_void, + temp_buffer.len() + )); + } + } else if self.parsed_accel1.0 == "cpu" { + unsafe { + let ptr = buf_1.ptr as *mut u8; + for i in 0..buf_1.len { + *ptr.add(i) = ((i + 42) % 256) as u8; + } + } + } else { + panic!("Unsupported accelerator type: {}", self.parsed_accel1.0); + } + + // Register buffers with actors + let rdma_handle_1 = self + .actor_1 + .request_buffer(self.client_1, buf_1.ptr as usize, buffer_size) + .await?; + let rdma_handle_2 = self + .actor_2 + .request_buffer(self.client_2, buf_2.ptr as usize, buffer_size) + .await?; + + self.buffer_pairs.push((buf_1, buf_2)); + Ok((rdma_handle_1, rdma_handle_2)) + } + pub async fn cleanup(self) -> Result<(), anyhow::Error> { self.actor_1 .release_buffer(self.client_1, self.rdma_handle_1.clone()) @@ -567,11 +804,29 @@ pub mod test_utils { .await } - pub async fn verify_buffers(&self, size: usize) -> Result<(), anyhow::Error> { + pub async fn verify_buffers(&self) -> Result<(), anyhow::Error> { + self.verify_buffer_pair(&self.buffer_1, &self.buffer_2) + .await?; + Ok(()) + } + + pub async fn verify_buffer_pair( + &self, + buf_1: &Buffer, + buf_2: &Buffer, + ) -> Result<(), anyhow::Error> { let mut buf_vec = Vec::new(); + if buf_1.len != buf_2.len { + return Err(anyhow::anyhow!( + "Buffers have different lengths: {} vs {}", + buf_1.len, + buf_2.len + )); + } + let size = buf_1.len; for (virtual_addr, cuda_context) in [ - (self.buffer_1.ptr, self.cuda_context_1), - (self.buffer_2.ptr, self.cuda_context_2), + (buf_1.ptr, self.cuda_context_1), + (buf_2.ptr, self.cuda_context_2), ] { if cuda_context.is_some() { let mut temp_buffer = vec![0u8; size].into_boxed_slice(); @@ -605,11 +860,34 @@ pub mod test_utils { let ptr2: *mut u8 = buf_vec[1].ptr as *mut u8; for i in 0..buf_vec[0].len { if *ptr1.add(i) != *ptr2.add(i) { - return Err(anyhow::anyhow!("Buffers are not equal at index {}", i)); + return Err(anyhow::anyhow!( + "Buffers are not equal at index {}, buf_1[{}]={}, buf_2[{}]={}", + i, + i, + *ptr1.add(i), + i, + *ptr2.add(i) + )); } } } Ok(()) } + + pub async fn verify_all_buffer_pairs(&self) -> Result<(), anyhow::Error> { + for i in 0..self.buffer_pairs.len() { + let (buf_1, buf_2) = &self.buffer_pairs[i]; + self.verify_buffer_pair(buf_1, buf_2).await.map_err(|e| { + anyhow::anyhow!( + "Failed to verify buffer pair {}: {:?} and {:?}: {}", + i, + buf_1, + buf_2, + e + ) + })?; + } + Ok(()) + } } } diff --git a/python/tests/rdma_load_test.py b/python/tests/rdma_load_test.py index 3f6df8364..1750a8a36 100644 --- a/python/tests/rdma_load_test.py +++ b/python/tests/rdma_load_test.py @@ -6,12 +6,12 @@ import argparse import asyncio +import dataclasses import os import random import statistics import time - # parse up front to extract env variables. args = None if __name__ == "__main__": @@ -56,6 +56,12 @@ default=10, help="Number of warmup iterations (default: 5)", ) + parser.add_argument( + "--n-concurrent-operations", + type=int, + default=1, + help="Number of concurrent operations (default: 1)", + ) args = parser.parse_args() @@ -72,6 +78,13 @@ from monarch.rdma import RDMABuffer +@dataclasses.dataclass +class RDMATestRequest: + buffer: RDMABuffer + shape: torch.Size + dtype: torch.dtype + + class RDMATest(Actor): def __init__( self, device: str = "cpu", operation: str = "write", size_mb: int = 64 @@ -91,76 +104,97 @@ async def set_other_actor(self, other_actor): self.other_actor = other_actor @endpoint - async def send(self, is_warmup=False) -> None: - shape = int( - 1024 * 1024 * self.size_mb / 4 * (0.5 * random.randint(1, 3)) - ) # Random size with +/- 50% variation based on user size - - # Use the device string directly - tensor = torch.rand(shape, dtype=torch.float32, device=self.device) - size_elem = tensor.numel() * tensor.element_size() - tensor_addr = tensor.data_ptr() - - # Critical validation - this should catch the null pointer issue - assert ( - tensor_addr != 0 - ), f"CRITICAL: Tensor has null pointer! Device: {device}, Shape: {shape}" - assert size_elem > 0, f"CRITICAL: Tensor has zero size! Size: {size_elem}" - - byte_view = tensor.view(torch.uint8).flatten() - # Validate byte_view too - byte_view_addr = byte_view.data_ptr() - assert ( - byte_view_addr != 0 - ), f"CRITICAL: Byte view has null pointer! Original addr: 0x{tensor_addr:x}" - assert ( - byte_view_addr == tensor_addr - ), f"CRITICAL: Address mismatch! Tensor: 0x{tensor_addr:x}, ByteView: 0x{byte_view_addr:x}" - - execution_start = time.time() - buffer = RDMABuffer(byte_view) - execution_end = time.time() - elapsed = execution_end - execution_start - - # Store timing and size data in this actor - size_elem = torch.numel(tensor) * tensor.element_size() - if not is_warmup: - self.timing_data.append(elapsed) - self.size_data.append(size_elem) - buffer_size = buffer.size() - assert buffer_size == size_elem, f"{buffer_size=} != {size_elem=}" + async def send(self, is_warmup=False, n_concurrent_operations=1) -> None: + requests: list[RDMATestRequest] = [] + for _ in range(n_concurrent_operations): + shape = int( + 1024 * 1024 * self.size_mb / 4 * (0.5 * random.randint(1, 3)) + ) # Random size with +/- 50% variation based on user size + + # Use the device string directly + tensor = torch.rand(shape, dtype=torch.float32, device=self.device) + size_elem = tensor.numel() * tensor.element_size() + tensor_addr = tensor.data_ptr() + + # Critical validation - this should catch the null pointer issue + assert ( + tensor_addr != 0 + ), f"CRITICAL: Tensor has null pointer! Device: {device}, Shape: {shape}" + assert size_elem > 0, f"CRITICAL: Tensor has zero size! Size: {size_elem}" + + byte_view = tensor.view(torch.uint8).flatten() + # Validate byte_view too + byte_view_addr = byte_view.data_ptr() + assert ( + byte_view_addr != 0 + ), f"CRITICAL: Byte view has null pointer! Original addr: 0x{tensor_addr:x}" + assert ( + byte_view_addr == tensor_addr + ), f"CRITICAL: Address mismatch! Tensor: 0x{tensor_addr:x}, ByteView: 0x{byte_view_addr:x}" + + execution_start = time.time() + buffer = RDMABuffer(byte_view) + execution_end = time.time() + elapsed = execution_end - execution_start + + # Store timing and size data in this actor + size_elem = torch.numel(tensor) * tensor.element_size() + if not is_warmup: + self.timing_data.append(elapsed) + self.size_data.append(size_elem) + buffer_size = buffer.size() + assert buffer_size == size_elem, f"{buffer_size=} != {size_elem=}" + + requests.append(RDMATestRequest(buffer, tensor.shape, tensor.dtype)) # Call recv - timing happens there - await self.other_actor.recv.call(buffer, tensor.shape, tensor.dtype, is_warmup) + await self.other_actor.recv.call(requests, is_warmup) - # cleanup - await buffer.drop() + for req in requests: + await req.buffer.drop() self.i += 1 @endpoint - async def recv(self, rdma_buffer, shape, dtype, is_warmup): + async def recv(self, requests, is_warmup): # Create receiving tensor on the same device - tensor = torch.rand(shape, dtype=dtype, device=self.device) - byte_view = tensor.view(torch.uint8).flatten() + sizes = [] + byte_views = [] + for req in requests: + shape = req.shape + dtype = req.dtype + tensor = torch.rand(shape, dtype=dtype, device=self.device) + sizes.append(tensor.numel() * tensor.element_size()) + byte_view = tensor.view(torch.uint8).flatten() + byte_views.append(byte_view) + + coros = [] + + for i, req in enumerate(requests): + rdma_buffer = req.buffer + byte_view = byte_views[i] + assert byte_view.numel() == rdma_buffer.size(), "size mismatch" + + async def op_coro(rdma_buffer=rdma_buffer, byte_view=byte_view): + if self.operation == "write": + await rdma_buffer.write_from(byte_view, timeout=5) + elif self.operation == "read": + await rdma_buffer.read_into(byte_view, timeout=5) + elif self.operation == "ping-pong": + if self.i % 2 == 0: + await rdma_buffer.write_from(byte_view, timeout=5) + else: + await rdma_buffer.read_into(byte_view, timeout=5) + + coros.append(op_coro(rdma_buffer=rdma_buffer, byte_view=byte_view)) execution_start = time.time() - - if self.operation == "write": - await rdma_buffer.write_from(byte_view, timeout=5) - elif self.operation == "read": - await rdma_buffer.read_into(byte_view, timeout=5) - elif self.operation == "ping-pong": - if self.i % 2 == 0: - await rdma_buffer.write_from(byte_view, timeout=5) - else: - await rdma_buffer.read_into(byte_view, timeout=5) - + await asyncio.gather(*coros) execution_end = time.time() elapsed = execution_end - execution_start # Store timing and size data in this actor - size_elem = torch.numel(tensor) * tensor.element_size() + size_elem = sum(sizes) if not is_warmup: self.timing_data.append(elapsed) self.size_data.append(size_elem) @@ -227,6 +261,7 @@ async def main( operation: str = "write", size_mb: int = 64, warmup_iterations: int = 10, + n_concurrent_operations: int = 1, ): # Adjust GPU allocation based on the device types device_0, device_1 = devices[0], devices[1] @@ -245,10 +280,12 @@ async def main( await actor_0.set_other_actor.call(actor_1) for i in range(warmup_iterations): - await actor_0.send.call(is_warmup=True) + await actor_0.send.call( + n_concurrent_operations=n_concurrent_operations, is_warmup=True + ) for i in range(iterations): - await actor_0.send.call() + await actor_0.send.call(n_concurrent_operations=n_concurrent_operations) # Have both actors print their statistics print("\n=== ACTOR 0 (Create Buffer) STATISTICS ===") @@ -313,5 +350,6 @@ async def main( args.operation, args.size, args.warmup_iterations, + args.n_concurrent_operations, ) )