From 6a6cbf0fbaf131ae7744eee5241cc7e014e5605b Mon Sep 17 00:00:00 2001
From: Kota Tsuyuzaki
Date: Fri, 19 Sep 2025 13:07:52 +0000
Subject: [PATCH 1/3] Add New VRAM Example

This commit provides a VRAM transfer example in Python. The block-based
transfer model is inspired by vLLM and shows the transfer bandwidth for
the specified configuration.

Signed-off-by: Kota Tsuyuzaki

---
 examples/python/vram_example/README.md | 83 +++++++++
 examples/python/vram_example/client.py | 229 +++++++++++++++++++++++++
 examples/python/vram_example/server.py | 143 +++++++++++++++
 examples/python/vram_example/utils.py | 82 +++++++++
 4 files changed, 537 insertions(+)
 create mode 100644 examples/python/vram_example/README.md
 create mode 100755 examples/python/vram_example/client.py
 create mode 100755 examples/python/vram_example/server.py
 create mode 100644 examples/python/vram_example/utils.py

diff --git a/examples/python/vram_example/README.md b/examples/python/vram_example/README.md
new file mode 100644
index 000000000..523dc6948
--- /dev/null
+++ b/examples/python/vram_example/README.md
@@ -0,0 +1,83 @@

# NIXL VRAM Transfer Example

This is an example of VRAM transfer using NIXL, inspired by the [vLLM](https://github.com/vllm-project/vllm) 0.10.0 KV cache transfer algorithm. It demonstrates a basic server-client model for transferring VRAM memory. The server process allocates tensor memory filled with ones, and the client copies these tensors into locally allocated tensors on the client side.

## Memory Alignment

As described above, this example follows vLLM's KV cache memory management. The transfer unit is controlled by a specific memory block size within each tensor.

The memory layout is designed as follows:

- Let **N** be the number of layers. The process creates **N** tensors.
- Each tensor maintains contiguous memory alignment but is logically divided into blocks of `block_size` tokens.
- The memory size of each block is calculated using the following formula:

`KV * Heads * Dimensions * Tokens per block * Precision`

For example, considering [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/tree/main) with the default vLLM config:

- N (Attention layers) = 40
- KV = 2
- Heads = 8
- Dimensions = 128
- Tokens per block = 16
- Precision = bf16 (2 bytes)

The block size in bytes is:

`2 * 8 * 128 * 16 * 2 = 65536 (64KB)`

The memory alignment looks like this:

```
tensor 0 [(64KB)|(64KB)|(64KB)| ... |(64KB)]
tensor 1 [(64KB)|(64KB)|(64KB)| ... |(64KB)]
...
tensor 39 [(64KB)|(64KB)|(64KB)| ... |(64KB)]
```

If 4GB of VRAM is available for Mistral-Small-3.1, the number of blocks per tensor is calculated as:

`((4 * 1024 * 1024 * 1024) / (64 * 1024)) // 40 = 1638 (blocks)`


## Transfer Model

This example uses the memory alignment described above. NIXL can address memory blocks by index. Following vLLM's approach, the process first reserves blocks for the KV cache. These reserved blocks keep consistent indices across all layer tensors, so this example transfers data from the head block up to the number of blocks required for the given input tokens.

## Usage
Start the server process to wait for incoming requests:

```
python server.py
```

Then, launch the client process:

```
python client.py
```
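
As a quick sanity check of the block-size math in the Memory Alignment section, here is a minimal Python sketch (it simply mirrors the defaults in `utils.py`; the variable names are illustrative):

```python
kv = 2          # key and value caches
heads = 8       # KV heads
dims = 128      # head dimension
tokens = 16     # tokens per block (block_size)
precision = 2   # bf16 -> 2 bytes
layers = 40     # attention layers

block_bytes = kv * heads * dims * tokens * precision
print(block_bytes)  # 65536 (64KB)

available = 4 * 1024 ** 3  # 4GB of VRAM reserved for the KV cache
print(available // block_bytes // layers)  # 1638 blocks per tensor
```
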
> [!NOTE]
> The server runs in a `while` loop. Use `Ctrl-C` or terminate the process manually to stop the server.

> [!TIP]
> To use different GPUs for each process, set the `CUDA_VISIBLE_DEVICES` environment variable to specify the GPU index.

You can configure the transfer model parameters using command-line arguments. Use `--help` to view all available options.

> [!NOTE]
> The parameters must be identical between the server and client processes to ensure successful transfer.

diff --git a/examples/python/vram_example/client.py b/examples/python/vram_example/client.py
new file mode 100755
index 000000000..c2b081fef
--- /dev/null
+++ b/examples/python/vram_example/client.py
@@ -0,0 +1,229 @@
#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2025 Kota Tsuyuzaki
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys
import time

import torch

from nixl._api import nixl_agent, nixl_agent_config
from utils import calc_memory_blocks, get_logger, parse_args

logger = get_logger(__name__)


def transfer(
    agent,
    size_in_bytes,
    local_prep_handle,
    addrs,
    trans_blocks,
    layers,
    num_blocks,
):
    # Ensure remote metadata has arrived.
    ready = False
    while not ready:
        ready = agent.check_remote_metadata("server")

    # Handshake to server.
    agent.send_notif("server", "SYN")

    notifs = agent.get_new_notifs()

    while len(notifs) == 0:
        notifs = agent.get_new_notifs()

    target_descs = agent.deserialize_descs(notifs["server"][0])
    logger.debug("target_descs: %s", target_descs)
    logger.debug("target descCount: %d", target_descs.descCount())
    logger.debug("target isEmpty: %s", target_descs.isEmpty())
    remote_prep_handle = agent.prep_xfer_dlist("server", target_descs, "VRAM")

    assert local_prep_handle != 0
    assert remote_prep_handle != 0

    # Ensure remote metadata, again.
    ready = False
    while not ready:
        ready = agent.check_remote_metadata("server")

    start = time.monotonic()

    # Calculate transfer data block indices. Simply this prepares first
    # trans_blocks elements for each layer.
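    # For example, with layers=2, num_blocks=4 and trans_blocks=2 this yields
    # [0, 1, 4, 5]: block i of layer l sits at flat index l * num_blocks + i,
    # so the same head blocks are selected in every layer tensor.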
    indices = []
    for layer in range(layers):
        block_offset = layer * num_blocks
        indices.extend([block_offset + i for i in range(trans_blocks)])

    logger.debug("%d blocks will be transferred", len(indices))
    xfer_handle = agent.make_prepped_xfer(
        "READ",
        local_prep_handle,
        indices,
        remote_prep_handle,
        indices,
        b"UUID",
        ["UCX"],
    )

    if not xfer_handle:
        logger.error("Creating transfer failed.")
        sys.exit()

    state = agent.transfer(xfer_handle)
    if state == "ERR":
        logger.error("Posting transfer failed.")
        sys.exit()

    while True:
        state = agent.check_xfer_state(xfer_handle)
        if state == "ERR":
            logger.error("Transfer got to Error state.")
            sys.exit()
        elif state == "DONE":
            break

    end = time.monotonic()
    ratio = len(indices) / len(addrs)
    logger.info(
        "Throughput: %f Mib/sec, in %f sec",
        size_in_bytes * ratio / (end - start) / 1024 / 1024,
        (end - start),
    )
    agent.release_xfer_handle(xfer_handle)


def main():
    args = parse_args()
    if args.debug:
        logger.setLevel(logging.DEBUG)

    tensor_size, shape_len, num_blocks = calc_memory_blocks(args)

    count = args.count

    # Emulate prefill with input_tokens.
    input_tokens = args.input_tokens
    trans_blocks = input_tokens // args.block_size
    logger.info("This test will transfer %d blocks", trans_blocks)

    device = "cuda"
    config = nixl_agent_config(True, True, 0)
    agent = nixl_agent("client", config)

    # Allocate memory and register with NIXL
    tensors = [
        torch.zeros(tensor_size, dtype=torch.bfloat16, device=device)
        for x in range(args.layers)
    ]

    size_in_bytes = tensors[0].nelement() * tensors[0].element_size() * len(tensors)
    logger.info("Client Tensors in MB: %d", size_in_bytes / 1024 / 1024)

    block_len = shape_len * tensors[0].element_size()  # bytes per block
    logger.debug("block_len: %d", block_len)

    reg_addrs = []
    for t in tensors:
        reg_addrs.append((t[0].data_ptr(), tensor_size * t.element_size(), 0, ""))

    reg_descs = agent.get_reg_descs(reg_addrs, "VRAM")

    success = agent.register_memory(reg_descs)
    if not success:  # Same as reg_descs if successful
        logger.error("Memory registration failed.")
        sys.exit()

    # Create data block chunks to emulate vLLM 0.10.0 data transfer.
    xfer_addrs = []
    for t in tensors:
        base_addr = t.data_ptr()
        for block_id in range(num_blocks):
            offset = block_id * block_len
            addr = base_addr + offset
            xfer_addrs.append((addr, block_len, 0))

    logger.info(
        "addrs info: layers: %d, elements: %d, shape: %d, block_len: %d, "
        "addrs_len: %d",
        len(tensors),
        tensors[0].nelement(),
        shape_len,
        block_len,
        len(xfer_addrs),
    )

    xfer_descs = agent.get_xfer_descs(xfer_addrs, "VRAM")

    logger.debug("xfer_descs: %s", xfer_descs)
    logger.debug("descCount: %d", xfer_descs.descCount())
    logger.debug("isEmpty: %s", xfer_descs.isEmpty())

    logger.info("Client sending to %s", args.ip)
    agent.fetch_remote_metadata("server", args.ip, args.port)
    agent.send_local_metadata(args.ip, args.port)

    # Prepare descriptor list.
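    # The list is prepared once here and reused by every transfer() call
    # above; "NIXL_INIT_AGENT" stands in for this agent's own name, marking
    # the list as local memory (as opposed to the remote list prepared from
    # the server's descriptors inside transfer()).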
+ local_prep_handle = agent.prep_xfer_dlist( + "NIXL_INIT_AGENT", + xfer_addrs, + "VRAM", + ) + + assert local_prep_handle != 0 + + trans_strategy = ["KEEPALIVE" for x in range(count)] + trans_strategy[-1] = "COMPLETE" + + for i, msg in enumerate(trans_strategy): + logger.debug("trans with %s", msg) + logger.debug(i) + transfer( + agent, + size_in_bytes, + local_prep_handle, + xfer_addrs, + trans_blocks, + args.layers, + num_blocks, + ) + msg = f"{i}:{msg}" + agent.send_notif("server", msg.encode()) + + # Verify data after read. + for i, tensor in enumerate(tensors): + check_blocks = trans_blocks * shape_len + if not torch.allclose( + tensor[:check_blocks], + torch.ones(check_blocks, dtype=torch.bfloat16, device=device), + ): + logger.error("Data verification failed for tensor %d.", i) + sys.exit() + + logger.info("Client Data verification passed") + logger.debug(tensors) + + agent.remove_remote_agent("server") + agent.invalidate_local_metadata(args.ip, args.port) + logger.info("Test Complete.") + + +if __name__ == "__main__": + main() diff --git a/examples/python/vram_example/server.py b/examples/python/vram_example/server.py new file mode 100755 index 000000000..af961ef68 --- /dev/null +++ b/examples/python/vram_example/server.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2025 Kota Tsuyuzaki +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys + +import torch + +from nixl._api import nixl_agent, nixl_agent_config +from utils import calc_memory_blocks, get_logger, parse_args + +logger = get_logger(__name__) + + +def _handle_request(agent, xfer_desc_str, msgs): + # Send desc list to initiator when metadata is ready. + ready = False + while not ready: + ready = agent.check_remote_metadata("client") + + # Handshake from client. + logger.debug("Waiting for handshake") + if not msgs: + notifs = agent.get_new_notifs() + while len(notifs) == 0: + notifs = agent.get_new_notifs() + msgs.extend(notifs["client"]) + assert msgs.pop(0) == b"SYN" + logger.info("Request received") + + agent.send_notif("client", xfer_desc_str) + + logger.debug("Waiting for transfer") + + # Waiting for transfer. + # For now the notification is just UUID, could be any python bytes. + # Also can have more than UUID, and check_remote_xfer_done returns + # the full python bytes, here it would be just UUID. + while not agent.check_remote_xfer_done("client", b"UUID"): + continue + + if not msgs: + # Check if we have to keep this connection. 
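        # The client tags each notification as "<seq>:KEEPALIVE" or
        # "<seq>:COMPLETE". KEEPALIVE means another transfer will follow on
        # this connection; COMPLETE makes the server drop the remote agent
        # below and go back to waiting for a fresh handshake.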
+ client_notifs = agent.update_notifs(["UCX"])["client"] + while len(client_notifs) == 0: + notifs = agent.get_new_notifs() + client_notifs = notifs.get("client", []).copy() + msgs.extend(client_notifs) + agent.notifs["client"].clear() + + msg = msgs.pop(0).decode() + logger.debug("msgs: %s", msgs) + seq, msg = msg.split(":") + logger.debug(seq) + logger.info("Finalize request") + if msg != "KEEPALIVE": + logger.debug("Got trans message %s, %s, %s", notifs["client"], msg, seq) + logger.debug("Remove remote agent and fall back to initiation") + agent.remove_remote_agent("client") + + +def main(): + args = parse_args() + if args.debug: + logger.setLevel(logging.DEBUG) + + tensor_size, shape_len, num_blocks = calc_memory_blocks(args) + listen_port = args.port + device = "cuda" + + config = nixl_agent_config(True, True, listen_port) + agent = nixl_agent("server", config) + + # Allocate memory and register with NIXL. + tensors = [ + torch.ones(tensor_size, dtype=torch.bfloat16, device=device) + for x in range(args.layers) + ] + logger.debug("Tensor buffer for transfer... %s", tensors) + size_in_bytes = tensors[0].nelement() * tensors[0].element_size() * len(tensors) + logger.info("Server Tensor Buffer in MB: %d", size_in_bytes / 1024 / 1024) + + block_len = shape_len * tensors[0].element_size() # Bytes of tensor. + logger.debug("block_len: %d", block_len) + logger.debug("num_blocks: %d", num_blocks) + + reg_addrs = [] + t = tensors[0] + logger.debug( + "first ptr: %d, second ptr %d", t[0].data_ptr(), t[shape_len].data_ptr() + ) + logger.debug("distance: %d", t[shape_len].data_ptr() - t[0].data_ptr()) + logger.debug("nelement: %d", t.nelement()) + for t in tensors: + reg_addrs.append((t[0].data_ptr(), tensor_size * t.element_size(), 0, "")) + + reg_descs = agent.get_reg_descs(reg_addrs, "VRAM") + success = agent.register_memory(reg_descs) + + if not success: # Same as reg_descs if successful. + logger.error("Memory registration failed.") + sys.exit() + + xfer_addrs = [] + for t in tensors: + base_addr = t.data_ptr() + for block_id in range(num_blocks): + offset = block_id * block_len + addr = base_addr + offset + xfer_addrs.append((addr, block_len, 0)) + + xfer_descs = agent.get_xfer_descs(xfer_addrs, "VRAM") + xfer_desc_str = agent.get_serialized_descs(xfer_descs) + logger.info("Serialized xfer_desc str len: %d", len(xfer_desc_str)) + + try: + # Daemonize server process for testing until killed by hand. + msgs = [] + while True: + logger.debug("Waiting for initialization with msg: %s", msgs) + logger.debug(tensors) + _handle_request(agent, xfer_desc_str, msgs) + finally: + agent.deregister_memory(reg_descs) + logger.info("Test Complete.") + + +if __name__ == "__main__": + main() diff --git a/examples/python/vram_example/utils.py b/examples/python/vram_example/utils.py new file mode 100644 index 000000000..372818f1b --- /dev/null +++ b/examples/python/vram_example/utils.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 Kota Tsuyuzaki +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging


def parse_args():
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ip", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=5555)
    parser.add_argument("--block-size", type=int, default=16)
    parser.add_argument(
        "--available-memory",
        type=int,
        default=4,
        help="Estimated available memory size for kv cache (GB)",
    )
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--dims", type=int, default=128)
    parser.add_argument("--layers", type=int, default=40)
    parser.add_argument("--input-tokens", type=int, default=1024)
    parser.add_argument("--count", type=int, default=1)
    parser.add_argument("--debug", action="store_true", default=False)
    return parser.parse_args()


def get_logger(name):
    """
    Create the default logger instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    formatter = logging.Formatter("%(levelname)s: %(message)s")
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    return logger


def calc_memory_blocks(args):
    """
    Calculate tensor_size and num_blocks to allocate for the test.
    """

    # Emulate a vLLM KV cache for a dense model.
    # By default, the parameters come from Mistral-Small-3.1:
    # 2 (kv), 8 heads, 128 dimensions, 16 tokens per block, bf16 (2 bytes),
    # 40 layers
    #
    # The block size in bytes is then:
    # 2 * 8 * 128 * 16 * 2 = 65536 (64KB)
    #
    # And if the GPU has 4GB of memory available for the KV cache, the
    # number of blocks per layer is:
    # ((4 * 1024 * 1024 * 1024) / (64 * 1024)) // 40 = 1638 (blocks)

    # Calculate the block size and alignment.
    gpu_mem = args.available_memory * 1024 * 1024 * 1024
    block_size = args.block_size
    heads = args.heads
    dims = args.dims
    layers = args.layers
    num_blocks = gpu_mem // layers // (2 * heads * dims * block_size * 2)
    shape_len = 2 * block_size * heads * dims
    tensor_size = shape_len * num_blocks
    return tensor_size, shape_len, num_blocks

From bcba809044bef49ef415416c74b3f058f21355d1 Mon Sep 17 00:00:00 2001
From: Kota Tsuyuzaki
Date: Fri, 3 Oct 2025 00:50:34 +0000
Subject: [PATCH 2/3] Fix Copyright Identifier

Because the copyright check apparently requires NVIDIA's copyright:

https://github.com/ai-dynamo/nixl/blob/main/.github/workflows/copyright-check.sh#L32-L35

---
 examples/python/vram_example/README.md | 2 +-
 examples/python/vram_example/client.py | 2 +-
 examples/python/vram_example/server.py | 2 +-
 examples/python/vram_example/utils.py | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/python/vram_example/README.md b/examples/python/vram_example/README.md
index 523dc6948..f8af662cb 100644
--- a/examples/python/vram_example/README.md
+++ b/examples/python/vram_example/README.md
@@ -1,5 +1,5 @@

diff --git a/examples/python/vram_example/client.py b/examples/python/vram_example/client.py
index c2b081fef..dcd6cefb0 100755
--- a/examples/python/vram_example/client.py
+++ b/examples/python/vram_example/client.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

-# SPDX-FileCopyrightText: Copyright (c) 2025 Kota Tsuyuzaki
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/examples/python/vram_example/server.py b/examples/python/vram_example/server.py index af961ef68..186945781 100755 --- a/examples/python/vram_example/server.py +++ b/examples/python/vram_example/server.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2025 Kota Tsuyuzaki +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/examples/python/vram_example/utils.py b/examples/python/vram_example/utils.py index 372818f1b..aa1bb7ccb 100644 --- a/examples/python/vram_example/utils.py +++ b/examples/python/vram_example/utils.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 Kota Tsuyuzaki +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); From 78e15c09864afda0158dc33c719a0b8b4045c98e Mon Sep 17 00:00:00 2001 From: Kota Tsuyuzaki Date: Fri, 7 Nov 2025 08:03:42 +0000 Subject: [PATCH 3/3] Move check_remote_metadata call only once per connection --- examples/python/vram_example/client.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/examples/python/vram_example/client.py b/examples/python/vram_example/client.py index dcd6cefb0..92c63fa6b 100755 --- a/examples/python/vram_example/client.py +++ b/examples/python/vram_example/client.py @@ -36,11 +36,6 @@ def transfer( layers, num_blocks, ): - # Ensure remote metadata has arrived. - ready = False - while not ready: - ready = agent.check_remote_metadata("server") - # Handshake to server. agent.send_notif("server", "SYN") @@ -58,11 +53,6 @@ def transfer( assert local_prep_handle != 0 assert remote_prep_handle != 0 - # Ensure remote metadata, again. - ready = False - while not ready: - ready = agent.check_remote_metadata("server") - start = time.monotonic() # Calculate transfer data block indices. Simply this prepares first @@ -103,7 +93,7 @@ def transfer( end = time.monotonic() ratio = len(indices) / len(addrs) logger.info( - "Throughput: %f Mib/sec, in %f sec", + "Throughput: %f MiB/sec, in %f sec", size_in_bytes * ratio / (end - start) / 1024 / 1024, (end - start), ) @@ -178,6 +168,12 @@ def main(): logger.info("Client sending to %s", args.ip) agent.fetch_remote_metadata("server", args.ip, args.port) + + # Check if remote server is available. + ready = False + while not ready: + ready = agent.check_remote_metadata("server") + agent.send_local_metadata(args.ip, args.port) # Prepare descriptor list.