diff --git a/python/tests/rdma_load_test.py b/python/tests/rdma_load_test.py index 3f6df8364..e8b5f4f5a 100644 --- a/python/tests/rdma_load_test.py +++ b/python/tests/rdma_load_test.py @@ -6,12 +6,12 @@ import argparse import asyncio +import dataclasses import os import random import statistics import time - # parse up front to extract env variables. args = None if __name__ == "__main__": @@ -56,6 +56,12 @@ default=10, help="Number of warmup iterations (default: 5)", ) + parser.add_argument( + "--n-concurrent-operations", + type=int, + default=1, + help="Number of concurrent operations (default: 1)", + ) args = parser.parse_args() @@ -72,6 +78,13 @@ from monarch.rdma import RDMABuffer +@dataclasses.dataclass +class RDMATestRequest: + buffer: RDMABuffer + shape: torch.Size + dtype: torch.dtype + + class RDMATest(Actor): def __init__( self, device: str = "cpu", operation: str = "write", size_mb: int = 64 @@ -91,76 +104,96 @@ async def set_other_actor(self, other_actor): self.other_actor = other_actor @endpoint - async def send(self, is_warmup=False) -> None: - shape = int( - 1024 * 1024 * self.size_mb / 4 * (0.5 * random.randint(1, 3)) - ) # Random size with +/- 50% variation based on user size - - # Use the device string directly - tensor = torch.rand(shape, dtype=torch.float32, device=self.device) - size_elem = tensor.numel() * tensor.element_size() - tensor_addr = tensor.data_ptr() - - # Critical validation - this should catch the null pointer issue - assert ( - tensor_addr != 0 - ), f"CRITICAL: Tensor has null pointer! Device: {device}, Shape: {shape}" - assert size_elem > 0, f"CRITICAL: Tensor has zero size! Size: {size_elem}" - - byte_view = tensor.view(torch.uint8).flatten() - # Validate byte_view too - byte_view_addr = byte_view.data_ptr() - assert ( - byte_view_addr != 0 - ), f"CRITICAL: Byte view has null pointer! Original addr: 0x{tensor_addr:x}" - assert ( - byte_view_addr == tensor_addr - ), f"CRITICAL: Address mismatch! Tensor: 0x{tensor_addr:x}, ByteView: 0x{byte_view_addr:x}" - - execution_start = time.time() - buffer = RDMABuffer(byte_view) - execution_end = time.time() - elapsed = execution_end - execution_start - - # Store timing and size data in this actor - size_elem = torch.numel(tensor) * tensor.element_size() - if not is_warmup: - self.timing_data.append(elapsed) - self.size_data.append(size_elem) - buffer_size = buffer.size() - assert buffer_size == size_elem, f"{buffer_size=} != {size_elem=}" + async def send(self, is_warmup=False, n_concurrent_operations=1) -> None: + requests: list[RDMATestRequest] = [] + for _ in range(n_concurrent_operations): + shape = int( + 1024 * 1024 * self.size_mb / 4 * (0.5 * random.randint(1, 3)) + ) # Random size with +/- 50% variation based on user size + + # Use the device string directly + tensor = torch.rand(shape, dtype=torch.float32, device=self.device) + size_elem = tensor.numel() * tensor.element_size() + tensor_addr = tensor.data_ptr() + + # Critical validation - this should catch the null pointer issue + assert ( + tensor_addr != 0 + ), f"CRITICAL: Tensor has null pointer! Device: {device}, Shape: {shape}" + assert size_elem > 0, f"CRITICAL: Tensor has zero size! Size: {size_elem}" + + byte_view = tensor.view(torch.uint8).flatten() + # Validate byte_view too + byte_view_addr = byte_view.data_ptr() + assert ( + byte_view_addr != 0 + ), f"CRITICAL: Byte view has null pointer! Original addr: 0x{tensor_addr:x}" + assert ( + byte_view_addr == tensor_addr + ), f"CRITICAL: Address mismatch! Tensor: 0x{tensor_addr:x}, ByteView: 0x{byte_view_addr:x}" + + execution_start = time.time() + buffer = RDMABuffer(byte_view) + execution_end = time.time() + elapsed = execution_end - execution_start + + # Store timing and size data in this actor + size_elem = torch.numel(tensor) * tensor.element_size() + if not is_warmup: + self.timing_data.append(elapsed) + self.size_data.append(size_elem) + buffer_size = buffer.size() + assert buffer_size == size_elem, f"{buffer_size=} != {size_elem=}" + + requests.append(RDMATestRequest(buffer, tensor.shape, tensor.dtype)) # Call recv - timing happens there - await self.other_actor.recv.call(buffer, tensor.shape, tensor.dtype, is_warmup) + await self.other_actor.recv.call(requests, is_warmup) - # cleanup - await buffer.drop() + for req in requests: + await req.buffer.drop() self.i += 1 @endpoint - async def recv(self, rdma_buffer, shape, dtype, is_warmup): + async def recv(self, requests, is_warmup): # Create receiving tensor on the same device - tensor = torch.rand(shape, dtype=dtype, device=self.device) - byte_view = tensor.view(torch.uint8).flatten() + sizes = [] + byte_views = [] + for req in requests: + shape = req.shape + dtype = req.dtype + tensor = torch.rand(shape, dtype=dtype, device=self.device) + sizes.append(tensor.numel() * tensor.element_size()) + byte_view = tensor.view(torch.uint8).flatten() + byte_views.append(byte_view) + + coros = [] + + for i, req in enumerate(requests): + rdma_buffer = req.buffer + byte_view = byte_views[i] + + async def op_coro(rdma_buffer=rdma_buffer, byte_view=byte_view): + if self.operation == "write": + await rdma_buffer.write_from(byte_view, timeout=5) + elif self.operation == "read": + await rdma_buffer.read_into(byte_view, timeout=5) + elif self.operation == "ping-pong": + if self.i % 2 == 0: + await rdma_buffer.write_from(byte_view, timeout=5) + else: + await rdma_buffer.read_into(byte_view, timeout=5) + + coros.append(op_coro(rdma_buffer=rdma_buffer, byte_view=byte_view)) execution_start = time.time() - - if self.operation == "write": - await rdma_buffer.write_from(byte_view, timeout=5) - elif self.operation == "read": - await rdma_buffer.read_into(byte_view, timeout=5) - elif self.operation == "ping-pong": - if self.i % 2 == 0: - await rdma_buffer.write_from(byte_view, timeout=5) - else: - await rdma_buffer.read_into(byte_view, timeout=5) - + await asyncio.gather(*coros) execution_end = time.time() elapsed = execution_end - execution_start # Store timing and size data in this actor - size_elem = torch.numel(tensor) * tensor.element_size() + size_elem = sum(sizes) if not is_warmup: self.timing_data.append(elapsed) self.size_data.append(size_elem) @@ -227,6 +260,7 @@ async def main( operation: str = "write", size_mb: int = 64, warmup_iterations: int = 10, + n_concurrent_operations: int = 1, ): # Adjust GPU allocation based on the device types device_0, device_1 = devices[0], devices[1] @@ -248,7 +282,7 @@ async def main( await actor_0.send.call(is_warmup=True) for i in range(iterations): - await actor_0.send.call() + await actor_0.send.call(n_concurrent_operations=n_concurrent_operations) # Have both actors print their statistics print("\n=== ACTOR 0 (Create Buffer) STATISTICS ===") @@ -313,5 +347,6 @@ async def main( args.operation, args.size, args.warmup_iterations, + args.n_concurrent_operations, ) )