diff --git a/examples/python/remote_storage_example/README.md b/examples/python/remote_storage_example/README.md
index cc5e39ee9..6234ea164 100644
--- a/examples/python/remote_storage_example/README.md
+++ b/examples/python/remote_storage_example/README.md
@@ -108,3 +108,21 @@ The system automatically selects the best available storage backend:
 1. Initiator sends memory descriptors to target
 2. Target performs storage-to-memory or memory-to-storage operations
 3. Data is transferred between initiator and target memory
+
+Remote reads are implemented as a read from storage followed by a network write.
+
+Remote writes are implemented as a read from the network followed by a storage write.
+
+### Pipelining
+
+To improve the performance of the remote storage server, we pipeline operations to the network and storage. This pipelining allows multiple threads to handle each request. However, to maintain correctness, the network and storage stages must still execute in order for each individual remote storage operation. To do this, we implemented a simple pipelining scheme. The pipeline for a remote write is a read into NIXL descriptors from the network, followed by a write to storage (also through NIXL, but via a different plugin). A remote read is similar, just reading into NIXL descriptors from storage and then writing to the network.
+
+![Remote Operation Pipelines](storage_pipeline.png)
+
+### Performance Tips
+
+For high-speed storage and network hardware, you may need to tune performance with a couple of environment variables.
+
+First, for optimal GDS performance, ensure you are using the GDS_MT backend with default concurrency. Additionally, you can use the cufile options described in the [GDS README](https://github.com/ai-dynamo/nixl/blob/main/src/plugins/cuda_gds/README.md). Also verify that your GDS setup is performing true GPU-direct I/O and is not running in compatibility mode.
+
+On the network side, remote reads from VRAM to DRAM can be limited by UCX rail selection. This can be tuned by setting `UCX_MAX_RMA_RAILS=1`. However, with larger batch or message sizes, this might limit bandwidth and a higher number of rails might be needed.
diff --git a/examples/python/remote_storage_example/client_server_diagram.png b/examples/python/remote_storage_example/client_server_diagram.png
index 651995d89..28b57034b 100644
Binary files a/examples/python/remote_storage_example/client_server_diagram.png and b/examples/python/remote_storage_example/client_server_diagram.png differ
diff --git a/examples/python/remote_storage_example/nixl_p2p_storage_example.py b/examples/python/remote_storage_example/nixl_p2p_storage_example.py
index 1a63648bf..41141bb96 100644
--- a/examples/python/remote_storage_example/nixl_p2p_storage_example.py
+++ b/examples/python/remote_storage_example/nixl_p2p_storage_example.py
@@ -18,6 +18,7 @@
 Demonstrates peer-to-peer storage transfers using NIXL with initiator and target modes.
""" +import concurrent.futures import time import nixl_storage_utils as nsu @@ -27,14 +28,20 @@ logger = get_logger(__name__) -def execute_transfer(my_agent, local_descs, remote_descs, remote_name, operation): - handle = my_agent.initialize_xfer(operation, local_descs, remote_descs, remote_name) +def execute_transfer( + my_agent, local_descs, remote_descs, remote_name, operation, use_backends=[] +): + handle = my_agent.initialize_xfer( + operation, local_descs, remote_descs, remote_name, backends=use_backends + ) my_agent.transfer(handle) nsu.wait_for_transfer(my_agent, handle) my_agent.release_xfer_handle(handle) -def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name): +def remote_storage_transfer( + my_agent, my_mem_descs, operation, remote_agent_name, iterations +): """Initiate remote memory transfer.""" if operation != "READ" and operation != "WRITE": logger.error("Invalid operation, exiting") @@ -45,14 +52,24 @@ def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name else: operation = b"READ" + iterations_str = bytes(f"{iterations:04d}", "utf-8") # Send the descriptors that you want to read into or write from - logger.info(f"Sending {operation} request to {remote_agent_name}") + logger.info( + "Sending %s request to %s", operation.decode("utf-8"), remote_agent_name + ) test_descs_str = my_agent.get_serialized_descs(my_mem_descs) - my_agent.send_notif(remote_agent_name, operation + test_descs_str) + + start_time = time.time() + + my_agent.send_notif(remote_agent_name, operation + iterations_str + test_descs_str) while not my_agent.check_remote_xfer_done(remote_agent_name, b"COMPLETE"): continue + elapsed = time.time() - start_time + + logger.info("Time for %d iterations: %f seconds", iterations, elapsed) + def connect_to_agents(my_agent, agents_file): target_agents = [] @@ -66,12 +83,12 @@ def connect_to_agents(my_agent, agents_file): my_agent.fetch_remote_metadata(parts[0], parts[1], int(parts[2])) while my_agent.check_remote_metadata(parts[0]) is False: - logger.info(f"Waiting for remote metadata for {parts[0]}...") + logger.info("Waiting for remote metadata for %s...", parts[0]) time.sleep(0.2) - logger.info(f"Remote metadata for {parts[0]} fetched") + logger.info("Remote metadata for %s fetched", parts[0]) else: - logger.error(f"Invalid line in {agents_file}: {line}") + logger.error("Invalid line in %s: %s", agents_file, line) exit(-1) logger.info("All remote metadata fetched") @@ -79,13 +96,141 @@ def connect_to_agents(my_agent, agents_file): return target_agents +def pipeline_reads( + my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations +): + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + n = 0 + s = 0 + futures = [] + + while n < iterations or s < iterations: + if s == 0: + futures.append( + executor.submit( + execute_transfer, + my_agent, + my_mem_descs, + my_file_descs, + my_agent.name, + "READ", + ) + ) + s += 1 + continue + + if s == iterations: + futures.append( + executor.submit( + execute_transfer, + my_agent, + my_mem_descs, + sent_descs, + req_agent, + "WRITE", + ) + ) + n += 1 + continue + + # Do two storage and network in parallel + futures.append( + executor.submit( + execute_transfer, + my_agent, + my_mem_descs, + my_file_descs, + my_agent.name, + "READ", + ) + ) + futures.append( + executor.submit( + execute_transfer, + my_agent, + my_mem_descs, + sent_descs, + req_agent, + "WRITE", + ) + ) + s += 1 + n += 1 + + _, not_done = concurrent.futures.wait( + futures, 
+        )
+        assert not not_done
+
+
+def pipeline_writes(
+    my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
+):
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        n = 0
+        s = 1
+        futures = []
+
+        futures.append(
+            executor.submit(
+                execute_transfer,
+                my_agent,
+                my_mem_descs,
+                sent_descs,
+                req_agent,
+                "READ",
+            )
+        )
+        while n < iterations or s < iterations:
+            if s == iterations:
+                futures.append(
+                    executor.submit(
+                        execute_transfer,
+                        my_agent,
+                        my_mem_descs,
+                        my_file_descs,
+                        my_agent.name,
+                        "WRITE",
+                    )
+                )
+                n += 1
+                continue
+
+            # Overlap one network read with one storage write
+            futures.append(
+                executor.submit(
+                    execute_transfer,
+                    my_agent,
+                    my_mem_descs,
+                    sent_descs,
+                    req_agent,
+                    "READ",
+                )
+            )
+            futures.append(
+                executor.submit(
+                    execute_transfer,
+                    my_agent,
+                    my_mem_descs,
+                    my_file_descs,
+                    my_agent.name,
+                    "WRITE",
+                )
+            )
+            s += 1
+            n += 1
+
+        _, not_done = concurrent.futures.wait(
+            futures, return_when=concurrent.futures.ALL_COMPLETED
+        )
+        assert not not_done
+
+
 def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
     """Handle remote memory and storage transfers as target."""
 
     # Wait for initiator to send list of memory descriptors
     notifs = my_agent.get_new_notifs()
 
-    logger.info("Waiting for a remote transfer request...")
-
     while len(notifs) == 0:
         notifs = my_agent.get_new_notifs()
@@ -101,57 +246,65 @@ def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
         logger.error("Invalid operation, exiting")
         exit(-1)
 
-    sent_descs = my_agent.deserialize_descs(recv_msg[4:])
+    iterations = int(recv_msg[4:8])
 
-    logger.info("Checking to ensure metadata is loaded...")
-    while my_agent.check_remote_metadata(req_agent, sent_descs) is False:
-        continue
+    logger.info("Performing %s with %d iterations", operation, iterations)
 
-    if operation == "READ":
-        logger.info("Starting READ operation")
+    sent_descs = my_agent.deserialize_descs(recv_msg[8:])
 
-        # Read from file first
-        execute_transfer(
-            my_agent, my_mem_descs, my_file_descs, my_agent.name, "READ"
+    if operation == "READ":
+        pipeline_reads(
+            my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
         )
-        # Send to client
-        execute_transfer(my_agent, my_mem_descs, sent_descs, req_agent, "WRITE")
-
     elif operation == "WRITE":
-        logger.info("Starting WRITE operation")
-
-        # Read from client first
-        execute_transfer(my_agent, my_mem_descs, sent_descs, req_agent, "READ")
-        # Write to storage
-        execute_transfer(
-            my_agent, my_mem_descs, my_file_descs, my_agent.name, "WRITE"
+        pipeline_writes(
+            my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
         )
 
     # Send completion notification to initiator
     my_agent.send_notif(req_agent, b"COMPLETE")
-    logger.info("One transfer test complete.")
-
 
-def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file):
+def run_client(
+    my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file, iterations
+):
     logger.info("Client initialized, ready for local transfer test...")
 
     # For sample purposes, write to and then read from local storage
     logger.info("Starting local transfer test...")
-    execute_transfer(
-        my_agent,
-        nixl_mem_reg_descs.trim(),
-        nixl_file_reg_descs.trim(),
-        my_agent.name,
-        "WRITE",
-    )
-    execute_transfer(
-        my_agent,
-        nixl_mem_reg_descs.trim(),
-        nixl_file_reg_descs.trim(),
-        my_agent.name,
-        "READ",
-    )
+
+    start_time = time.time()
+
+    for _ in range(iterations):
+        execute_transfer(
+            my_agent,
+            nixl_mem_reg_descs.trim(),
+            nixl_file_reg_descs.trim(),
+            my_agent.name,
+            "WRITE",
+            ["GDS_MT"],
+        )
+
+    elapsed = time.time() - start_time
+
+    logger.info("Time for %d WRITE iterations: %f seconds", iterations, elapsed)
+
+    start_time = time.time()
+
+    for _ in range(iterations):
+        execute_transfer(
+            my_agent,
+            nixl_mem_reg_descs.trim(),
+            nixl_file_reg_descs.trim(),
+            my_agent.name,
+            "READ",
+            ["GDS_MT"],
+        )
+
+    elapsed = time.time() - start_time
+
+    logger.info("Time for %d READ iterations: %f seconds", iterations, elapsed)
+
     logger.info("Local transfer test complete")
 
     logger.info("Starting remote transfer test...")
@@ -161,10 +314,10 @@ def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file):
     # For sample purposes, write to and then read from each target agent
     for target_agent in target_agents:
         remote_storage_transfer(
-            my_agent, nixl_mem_reg_descs.trim(), "WRITE", target_agent
+            my_agent, nixl_mem_reg_descs.trim(), "WRITE", target_agent, iterations
         )
         remote_storage_transfer(
-            my_agent, nixl_mem_reg_descs.trim(), "READ", target_agent
+            my_agent, nixl_mem_reg_descs.trim(), "READ", target_agent, iterations
         )
 
     logger.info("Remote transfer test complete")
@@ -199,8 +352,19 @@ def run_storage_server(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs):
         type=str,
         help="File containing list of target agents (only needed for client)",
     )
+    parser.add_argument(
+        "--iterations",
+        type=int,
+        default=100,
+        help="Number of iterations for each transfer",
+    )
     args = parser.parse_args()
 
+    mem = "DRAM"
+
+    if args.role == "client":
+        mem = "VRAM"
+
     my_agent = nsu.create_agent_with_plugins(args.name, args.port)
 
     (
@@ -209,7 +373,7 @@
         nixl_mem_reg_descs,
         nixl_file_reg_descs,
     ) = nsu.setup_memory_and_files(
-        my_agent, args.batch_size, args.buf_size, args.fileprefix
+        my_agent, args.batch_size, args.buf_size, args.fileprefix, mem
     )
 
     if args.role == "client":
@@ -217,7 +381,11 @@
             parser.error("--agents_file is required when role is client")
         try:
             run_client(
-                my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, args.agents_file
+                my_agent,
+                nixl_mem_reg_descs,
+                nixl_file_reg_descs,
+                args.agents_file,
+                args.iterations,
             )
         finally:
             nsu.cleanup_resources(
diff --git a/examples/python/remote_storage_example/nixl_storage_utils/common.py b/examples/python/remote_storage_example/nixl_storage_utils/common.py
index 5d25f8acd..13d6e899a 100644
--- a/examples/python/remote_storage_example/nixl_storage_utils/common.py
+++ b/examples/python/remote_storage_example/nixl_storage_utils/common.py
@@ -21,6 +21,8 @@
 import argparse
 import os
 
+import torch
+
 import nixl._utils as nixl_utils
 from nixl._api import nixl_agent, nixl_agent_config
 from nixl._bindings import DRAM_SEG
@@ -36,14 +38,14 @@ def create_agent_with_plugins(agent_name, port):
     plugin_list = new_nixl_agent.get_plugin_list()
 
-    if "GDS" in plugin_list:
-        new_nixl_agent.create_backend("GDS")
+    if "GDS_MT" in plugin_list:
+        new_nixl_agent.create_backend("GDS_MT")
         logger.info("Using GDS storage backend")
 
     if "POSIX" in plugin_list:
         new_nixl_agent.create_backend("POSIX")
         logger.info("Using POSIX storage backend")
 
-    if "GDS" not in plugin_list and "POSIX" not in plugin_list:
+    if "GDS_MT" not in plugin_list and "POSIX" not in plugin_list:
         logger.error("No storage backends available, exiting")
         exit(-1)
 
@@ -57,20 +59,30 @@
     return new_nixl_agent
 
 
-def setup_memory_and_files(agent, batch_size, buf_size, fileprefix):
+def setup_memory_and_files(agent, batch_size, buf_size, fileprefix, mem="DRAM"):
     """Setup memory and file resources."""
     my_mem_list = []
     my_file_list = []
     nixl_mem_reg_list = []
     nixl_file_reg_list = []
 
+    if mem == "VRAM":
+        torch.set_default_device("cuda:0")
+
     for i in range(batch_size):
-        my_mem_list.append(nixl_utils.malloc_passthru(buf_size))
-        my_file_list.append(os.open(f"{fileprefix}_{i}", os.O_RDWR | os.O_CREAT))
-        nixl_mem_reg_list.append((my_mem_list[-1], buf_size, 0, str(i)))
+        if mem == "VRAM":
+            my_mem_list.append(torch.full((buf_size,), 0, dtype=torch.int8))
+            nixl_mem_reg_list.append(my_mem_list[-1])
+        else:
+            my_mem_list.append(nixl_utils.malloc_passthru(buf_size))
+            nixl_mem_reg_list.append((my_mem_list[-1], buf_size, 0, str(i)))
+
+        my_file_list.append(
+            os.open(f"{fileprefix}_{i}", os.O_RDWR | os.O_CREAT | os.O_DIRECT)
+        )
         nixl_file_reg_list.append((0, buf_size, my_file_list[-1], str(i)))
 
-    nixl_mem_reg_descs = agent.register_memory(nixl_mem_reg_list, "DRAM")
+    nixl_mem_reg_descs = agent.register_memory(nixl_mem_reg_list, mem)
     nixl_file_reg_descs = agent.register_memory(nixl_file_reg_list, "FILE")
 
     assert nixl_mem_reg_descs is not None
@@ -89,8 +101,7 @@ def cleanup_resources(agent, mem_reg_descs, file_reg_descs, mem_list, file_list)
         for mem in mem_list:
             nixl_utils.free_passthru(mem)
     else:
-        agent.deregister_memory(file_reg_descs, backends=["GDS"])
-        # TODO: cudaFree
+        agent.deregister_memory(file_reg_descs, backends=["GDS_MT"])
 
     for file in file_list:
         os.close(file)
diff --git a/examples/python/remote_storage_example/storage_pipeline.png b/examples/python/remote_storage_example/storage_pipeline.png
new file mode 100644
index 000000000..9a847a5f3
Binary files /dev/null and b/examples/python/remote_storage_example/storage_pipeline.png differ
diff --git a/src/bindings/python/nixl_utils.cpp b/src/bindings/python/nixl_utils.cpp
index 5a7da8d07..e72544ca2 100644
--- a/src/bindings/python/nixl_utils.cpp
+++ b/src/bindings/python/nixl_utils.cpp
@@ -20,7 +20,7 @@ namespace py = pybind11;
 
 //JUST FOR TESTING
 uintptr_t malloc_passthru(int size) {
-    return (uintptr_t) malloc(size);
+    return (uintptr_t)aligned_alloc(4096, size);
 }
 
 //JUST FOR TESTING
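
The pipelining scheme in the README hunk above overlaps the storage stage of iteration `i` with the network stage of iteration `i - 1`: one storage operation primes the pipeline, and one trailing network operation drains it. A minimal, self-contained sketch of the same pattern, with hypothetical `storage_stage` and `network_stage` placeholders standing in for the NIXL transfers:

```python
import concurrent.futures


def storage_stage(i):
    """Placeholder for the storage-side NIXL transfer of iteration i."""
    print(f"storage op {i}")


def network_stage(i):
    """Placeholder for the network-side NIXL transfer of iteration i."""
    print(f"network op {i}")


def pipelined(iterations):
    # Two workers, mirroring the example: one tends to drain storage
    # tasks while the other drains network tasks.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        # Prime the pipeline with the first storage operation.
        futures = [executor.submit(storage_stage, 0)]
        for i in range(1, iterations):
            # Overlap storage for iteration i with network for i - 1.
            futures.append(executor.submit(storage_stage, i))
            futures.append(executor.submit(network_stage, i - 1))
        # Drain the final network operation.
        futures.append(executor.submit(network_stage, iterations - 1))
        _, not_done = concurrent.futures.wait(futures)
        assert not not_done


pipelined(4)
```

As in the example itself, FIFO submission order only approximates the per-iteration ordering constraint; a stricter implementation would chain each network task on the completion of its storage counterpart.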
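The request notification the client sends is a fixed-width byte string — a 4-byte operation tag, a zero-padded 4-digit iteration count, then the serialized descriptors — which is why the server slices `recv_msg[:4]`, `recv_msg[4:8]`, and `recv_msg[8:]`. A small sketch of that framing (the helper names are illustrative, and the tag must be exactly 4 bytes for the offsets to hold):

```python
def frame_request(op_tag: bytes, iterations: int, descs: bytes) -> bytes:
    # 4-byte tag + 4-digit zero-padded count + serialized descriptors.
    assert len(op_tag) == 4
    return op_tag + f"{iterations:04d}".encode("utf-8") + descs


def parse_request(msg: bytes):
    op_tag = msg[:4]
    iterations = int(msg[4:8])  # int() accepts ASCII digits in bytes
    descs = msg[8:]
    return op_tag, iterations, descs


msg = frame_request(b"READ", 100, b"<serialized descriptors>")
assert parse_request(msg) == (b"READ", 100, b"<serialized descriptors>")
```

A side effect of the fixed four-digit field is that a single request is capped at 9999 iterations.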
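Since UCX reads its tuning variables from the environment at initialization, the rail setting mentioned in the Performance Tips can be applied either in the launching shell or from Python before the NIXL agent is created. A sketch, assuming the single-rail value suggested in the README:

```python
import os

# Must be set before the UCX backend is initialized by the agent.
os.environ["UCX_MAX_RMA_RAILS"] = "1"
```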