Remote storage pipeline #899
base: main
@@ -108,3 +108,21 @@ The system automatically selects the best available storage backend:

1. Initiator sends memory descriptors to target
2. Target performs storage-to-memory or memory-to-storage operations
3. Data is transferred between initiator and target memory

Remote reads are implemented as a read from storage followed by a network write.

Remote writes are implemented as a network read followed by a storage write.
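As a rough illustration, here is a sketch of the request notification the updated example builds. The field layout is inferred from the example code in this PR and is not a formal NIXL wire format; `my_agent`, `my_mem_descs`, and `remote_agent_name` are assumed to come from the example's setup:

```python
# Request layout inferred from this PR's example code (not a NIXL API):
#   msg[0:4]  operation token, 4 bytes (b"READ" in the visible hunk)
#   msg[4:8]  iteration count as four ASCII digits, e.g. b"0016"
#   msg[8:]   serialized memory descriptors
iterations = 16
serialized_descs = my_agent.get_serialized_descs(my_mem_descs)
msg = b"READ" + bytes(f"{iterations:04d}", "utf-8") + serialized_descs
my_agent.send_notif(remote_agent_name, msg)
```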
### Pipelining

To improve the performance of the remote storage server, we can pipeline operations to network and storage, which lets multiple threads handle each request. To maintain correctness, however, the network and storage steps of each individual remote storage operation must still execute in order. To do this, we implemented a simple pipelining scheme, illustrated in the diagram and sketch below.
|  | ||
|
|
||
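A minimal sketch of the idea, simplified from the example code in this PR; `read_from_storage` and `write_to_network` are hypothetical stand-ins for the NIXL storage and network transfers:

```python
import concurrent.futures


def read_from_storage(i):
    """Placeholder for the NIXL storage READ of iteration i."""


def write_to_network(i):
    """Placeholder for the NIXL network WRITE of iteration i."""


def pipelined_remote_read(iterations):
    # Two workers: while one thread performs the network write for
    # iteration i-1, the other can already perform the storage read for
    # iteration i. Each iteration's write is only submitted after its own
    # read has completed, preserving per-operation ordering.
    # Assumes iterations >= 1.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        writes = []
        pending_read = executor.submit(read_from_storage, 0)
        for i in range(1, iterations):
            pending_read.result()  # storage read for i-1 has finished
            writes.append(executor.submit(write_to_network, i - 1))
            pending_read = executor.submit(read_from_storage, i)
        pending_read.result()
        writes.append(executor.submit(write_to_network, iterations - 1))
        concurrent.futures.wait(writes)
```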
### Performance Tips

For high-speed storage and network hardware, you may need to tune performance with a couple of environment variables.
Contributor
Please provide a suggested list of env vars (e.g., an example configuration with CX-7 and SPX, and what UCX/GDS tuning is seen to be beneficial).
Contributor (Author)
The next two lines are these suggestions. I don't know what other example configuration options you might be talking about.
First, for optimal GDS performance, ensure you are using the GDS_MT backend with default concurrency. Additionally, you can use the cufile options described in the [GDS README](https://github.com/ai-dynamo/nixl/blob/main/src/plugins/cuda_gds/README.md).

On the network side, remote reads from VRAM to DRAM can be limited by UCX rail selection. This can be tuned by setting UCX_MAX_RMA_RAILS=1. However, with larger batch or message sizes, this might limit bandwidth and a higher number of rails might be needed.
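For example, a launch script for such a setup might export the following; the exact values are assumptions to validate against your own hardware, not tested recommendations:

```bash
# Limit UCX rail selection; can help VRAM->DRAM remote reads at small sizes.
# Raise (or unset) this for large batch/message sizes if bandwidth drops.
export UCX_MAX_RMA_RAILS=1

# Point cufile at a tuned config; see the GDS README for the available options.
export CUFILE_ENV_PATH_JSON=/etc/cufile.json
```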
@@ -18,6 +18,7 @@
 Demonstrates peer-to-peer storage transfers using NIXL with initiator and target modes.
 """

+import concurrent.futures
 import time

 import nixl_storage_utils as nsu
@@ -27,14 +28,20 @@
 logger = get_logger(__name__)


-def execute_transfer(my_agent, local_descs, remote_descs, remote_name, operation):
-    handle = my_agent.initialize_xfer(operation, local_descs, remote_descs, remote_name)
+def execute_transfer(
+    my_agent, local_descs, remote_descs, remote_name, operation, use_backends=[]
+):
+    handle = my_agent.initialize_xfer(
+        operation, local_descs, remote_descs, remote_name, backends=use_backends
+    )
     my_agent.transfer(handle)
     nsu.wait_for_transfer(my_agent, handle)
     my_agent.release_xfer_handle(handle)


-def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name):
+def remote_storage_transfer(
+    my_agent, my_mem_descs, operation, remote_agent_name, iterations
+):
     """Initiate remote memory transfer."""
     if operation != "READ" and operation != "WRITE":
         logger.error("Invalid operation, exiting")
@@ -45,14 +52,22 @@ def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name
     else:
         operation = b"READ"

+    iterations_str = bytes(f"{iterations:04d}", "utf-8")
     # Send the descriptors that you want to read into or write from
     logger.info(f"Sending {operation} request to {remote_agent_name}")
     test_descs_str = my_agent.get_serialized_descs(my_mem_descs)
-    my_agent.send_notif(remote_agent_name, operation + test_descs_str)
+
+    start_time = time.time()
+
+    my_agent.send_notif(remote_agent_name, operation + iterations_str + test_descs_str)

     while not my_agent.check_remote_xfer_done(remote_agent_name, b"COMPLETE"):
         continue

+    end_time = time.time()
+
+    logger.info(f"Time for {iterations} iterations: {end_time - start_time} seconds")


 def connect_to_agents(my_agent, agents_file):
     target_agents = []
@@ -79,13 +94,145 @@ def connect_to_agents(my_agent, agents_file):
     return target_agents


+def pipeline_reads(
+    my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
+):
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        # n counts network writes issued; s counts storage reads issued
+        n = 0
+        s = 0
+        futures = []
+
+        while n < iterations and s < iterations:
+            if s == 0:
+                # Prime the pipeline with the first storage read
+                futures.append(
+                    executor.submit(
+                        execute_transfer,
+                        my_agent,
+                        my_mem_descs,
+                        my_file_descs,
+                        my_agent.name,
+                        "READ",
+                    )
+                )
+                s += 1
+                continue
Comment on lines +108 to +120

Contributor
Maybe we can move this line before the loop, initiating the first read before entering it. I think it would simplify the loop and help avoid a branch.

Contributor (Author)
Sure.
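A hypothetical sketch of that refactor (not part of this PR), reusing the same helpers as the example; priming the pipeline before the loop removes the `s == 0` branch:

```python
def pipeline_reads_primed(
    my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
):
    # Hypothetical variant of pipeline_reads: issue the first storage read
    # before the loop, then pair each network write with the next read.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        futures = [
            executor.submit(
                execute_transfer,
                my_agent, my_mem_descs, my_file_descs, my_agent.name, "READ",
            )
        ]
        for i in range(iterations):
            # Network write for iteration i
            futures.append(
                executor.submit(
                    execute_transfer,
                    my_agent, my_mem_descs, sent_descs, req_agent, "WRITE",
                )
            )
            # Storage read for iteration i + 1, if any
            if i + 1 < iterations:
                futures.append(
                    executor.submit(
                        execute_transfer,
                        my_agent, my_mem_descs, my_file_descs, my_agent.name, "READ",
                    )
                )
        concurrent.futures.wait(futures)
```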
+
+            if s == iterations:
Contributor
How can this branch be reached with the `while` condition above?

Contributor (Author)
Good catch, the while condition should use an `or`.
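The fix the author refers to (not yet applied in this diff) would be:

Suggested change
-        while n < iterations and s < iterations:
+        while n < iterations or s < iterations: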
+                futures.append(
+                    executor.submit(
+                        execute_transfer,
+                        my_agent,
+                        my_mem_descs,
+                        sent_descs,
+                        req_agent,
+                        "WRITE",
+                    )
+                )
+                n += 1
+                continue
+
+            # Submit one storage read and one network write in parallel
+            futures.append(
+                executor.submit(
+                    execute_transfer,
+                    my_agent,
+                    my_mem_descs,
+                    my_file_descs,
+                    my_agent.name,
+                    "READ",
+                )
+            )
+            futures.append(
+                executor.submit(
+                    execute_transfer,
+                    my_agent,
+                    my_mem_descs,
+                    sent_descs,
+                    req_agent,
+                    "WRITE",
+                )
+            )
+            s += 1
+            n += 1
+
+        _, not_done = concurrent.futures.wait(
+            futures, return_when=concurrent.futures.ALL_COMPLETED
+        )
+        assert not not_done
+
+
+def pipeline_writes(
+    my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
+):
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        # n counts storage writes issued; s counts network reads issued
+        n = 0
+        s = 0
+        futures = []
+
+        while n < iterations and s < iterations:
+            if s == 0:
+                # Prime the pipeline with the first network read
+                futures.append(
+                    executor.submit(
+                        execute_transfer,
+                        my_agent,
+                        my_mem_descs,
+                        sent_descs,
+                        req_agent,
+                        "READ",
+                    )
+                )
+                s += 1
+                continue
+
+            if s == iterations:
+                futures.append(
+                    executor.submit(
+                        execute_transfer,
+                        my_agent,
+                        my_mem_descs,
+                        my_file_descs,
+                        my_agent.name,
+                        "WRITE",
+                    )
+                )
+                n += 1
+                continue
+
+            # Submit one network read and one storage write in parallel
+            futures.append(
+                executor.submit(
+                    execute_transfer,
+                    my_agent,
+                    my_mem_descs,
+                    sent_descs,
+                    req_agent,
+                    "READ",
+                )
+            )
+            futures.append(
+                executor.submit(
+                    execute_transfer,
+                    my_agent,
+                    my_mem_descs,
+                    my_file_descs,
+                    my_agent.name,
+                    "WRITE",
+                )
+            )
+            s += 1
+            n += 1
+
+        _, not_done = concurrent.futures.wait(
+            futures, return_when=concurrent.futures.ALL_COMPLETED
+        )
+        assert not not_done
+
+
 def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
     """Handle remote memory and storage transfers as target."""
     # Wait for initiator to send list of memory descriptors
     notifs = my_agent.get_new_notifs()

     logger.info("Waiting for a remote transfer request...")

     while len(notifs) == 0:
         notifs = my_agent.get_new_notifs()

@@ -101,57 +248,69 @@ def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
         logger.error("Invalid operation, exiting")
         exit(-1)

-    sent_descs = my_agent.deserialize_descs(recv_msg[4:])
+    iterations = int(recv_msg[4:8])
+    sent_descs = my_agent.deserialize_descs(recv_msg[8:])

     logger.info("Checking to ensure metadata is loaded...")
     while my_agent.check_remote_metadata(req_agent, sent_descs) is False:
         continue
+    logger.info(f"Performing {operation} with {iterations} iterations")

-    if operation == "READ":
-        logger.info("Starting READ operation")
-
-        # Read from file first
-        execute_transfer(
-            my_agent, my_mem_descs, my_file_descs, my_agent.name, "READ"
-        )
-        # Send to client
-        execute_transfer(my_agent, my_mem_descs, sent_descs, req_agent, "WRITE")
+    if operation == "READ":
+        pipeline_reads(
+            my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
+        )

     elif operation == "WRITE":
-        logger.info("Starting WRITE operation")
-
-        # Read from client first
-        execute_transfer(my_agent, my_mem_descs, sent_descs, req_agent, "READ")
-        # Write to storage
-        execute_transfer(
-            my_agent, my_mem_descs, my_file_descs, my_agent.name, "WRITE"
-        )
+        pipeline_writes(
+            my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
+        )

     # Send completion notification to initiator
     my_agent.send_notif(req_agent, b"COMPLETE")

     logger.info("One transfer test complete.")

-def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file):
+def run_client(
+    my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file, iterations
+):
     logger.info("Client initialized, ready for local transfer test...")

     # For sample purposes, write to and then read from local storage
     logger.info("Starting local transfer test...")
     execute_transfer(
         my_agent,
         nixl_mem_reg_descs.trim(),
         nixl_file_reg_descs.trim(),
         my_agent.name,
         "WRITE",
     )
     execute_transfer(
         my_agent,
         nixl_mem_reg_descs.trim(),
         nixl_file_reg_descs.trim(),
         my_agent.name,
         "READ",
     )

+    start_time = time.time()
+
+    for i in range(1, iterations):
+        execute_transfer(
+            my_agent,
+            nixl_mem_reg_descs.trim(),
+            nixl_file_reg_descs.trim(),
+            my_agent.name,
+            "WRITE",
+            ["GDS_MT"],
+        )
+
+    end_time = time.time()
+
+    elapsed = end_time - start_time
Suggested change (Outdated)
-    end_time = time.time()
-    elapsed = end_time - start_time
+    elapsed = time.time() - start_time

Suggested change (Outdated)
-    logger.info(f"Time for {iterations} READ iterations: {elapsed} seconds")
+    logger.info("Time for %s READ iterations: %s seconds", iterations, elapsed)

We shouldn't use f-strings in loggers; they're not optimized, and pylint would warn about it.
Don't we have some constants for "DRAM", "VRAM", "GDS_MT" and so on?

We pretty consistently expect Python applications to just use strings; it's a little clearer than constant definitions in the bindings. The Python API will do these translations.
I meant constant strings. Something like:

    class MemoryType(StrEnum):
        DRAM = "DRAM"
        VRAM = "VRAM"
        ...

Then usage will be:

    MemoryType.DRAM

We have those enums in Rust and C++; I thought we had something similar here, but if there isn't, it requires a separate PR as it's unrelated.
@tstamler can you please add some more description of how this is implemented with NIXL in this example?

Added a brief description.