diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py
index 83eb7157495..6abef56bc93 100644
--- a/tensorrt_llm/_torch/distributed/communicator.py
+++ b/tensorrt_llm/_torch/distributed/communicator.py
@@ -1,4 +1,6 @@
+import math
 import os
+import pickle  # nosec B403
 from abc import ABC, abstractmethod
 from typing import Optional
 
@@ -6,10 +8,15 @@
 import torch
 import torch.distributed as dist
 
-from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_broadcast,
-                                 mpi_comm, mpi_isend, mpi_isend_object,
-                                 mpi_recv, mpi_recv_object, mpi_send,
-                                 mpi_send_object)
+try:
+    from mpi4py import MPI
+except Exception:
+    MPI = None  # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
+
+from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
+                                 mpi_isend, mpi_isend_object, mpi_recv,
+                                 mpi_recv_object, mpi_send, mpi_send_object)
+from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
 
 from tensorrt_llm.mapping import Mapping
 
@@ -95,14 +102,233 @@ def allgather(self, obj, root=0):
         pass
 
 
+def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    """
+    Safely broadcasts potentially large objects by splitting them into fixed-size chunks,
+    using raw-byte MPI.Bcast to avoid pickle5's out-of-band buffer allocations.
+
+    Args:
+        comm: MPI communicator to broadcast over
+        obj: Python object to broadcast
+        root: Rank of the broadcasting process
+        chunk_size: Maximum size of each chunk in bytes (default: 4MB)
+
+    Returns:
+        The broadcasted object on all ranks
+    """
+    if not ENABLE_MULTI_DEVICE:
+        return obj
+    if ENABLE_MULTI_DEVICE and MPI is None:
+        raise RuntimeError(
+            "mpi4py is required when ENABLE_MULTI_DEVICE is True")
+    if chunk_size <= 0:
+        raise ValueError("chunk_size must be > 0")
+    rank = comm.Get_rank()
+
+    # ---- Serialization phase (root only) ----
+    # Header layout: [ok_flag, total_size, num_chunks] as int64
+    header = np.zeros(3, dtype=np.int64)
+    if rank == root:
+        try:
+            serialized = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+            total_size = len(serialized)
+            num_chunks = math.ceil(total_size /
+                                   chunk_size) if total_size > 0 else 0
+            header[:] = (1, total_size, num_chunks)
+        except Exception as e:
+            # Signal failure to all ranks, then raise
+            header[:] = (0, 0, 0)
+            comm.Bcast([header, MPI.INT64_T], root=root)
+            raise RuntimeError(f"Serialization failed: {str(e)}") from e
+    else:
+        serialized = None  # not used on non-root before Bcast
+
+    # ---- Metadata broadcast (Bcast the fixed-size header) ----
+    comm.Bcast([header, MPI.INT64_T], root=root)
+    ok_flag, total_size, num_chunks = int(header[0]), int(header[1]), int(
+        header[2])
+    if not ok_flag:
+        raise RuntimeError("Root rank failed during serialization")
+
+    # ---- Allocate receive buffer (non-root) or build a view (root) ----
+    # We broadcast raw bytes chunk by chunk.
+    if rank == root:
+        src_view = memoryview(serialized)
+        dst_buf = None
+        dst_view = None
+    else:
+        # Pre-allocate a contiguous byte buffer to receive the payload
+        dst_buf = bytearray(total_size)
+        dst_view = memoryview(dst_buf)
+        src_view = None  # not used on non-root
+
+    # ---- Chunked raw-byte broadcast with MPI.Bcast ----
+    # Each round sends exactly `cur` bytes of the global payload.
+    offset = 0
+    for i in range(num_chunks):
+        cur = min(chunk_size, total_size - offset)
+        if cur <= 0:
+            break  # safety guard for zero-size payloads
+
+        if rank == root:
+            # Root sends a slice of the source view
+            part = src_view[offset:offset + cur]
+            comm.Bcast([part, MPI.BYTE], root=root)
+        else:
+            # Non-root receives directly into the destination view
+            part = dst_view[offset:offset + cur]
+            comm.Bcast([part, MPI.BYTE], root=root)
+
+        offset += cur
+
+    # ---- Reconstruction and deserialization ----
+    # Validate the received byte count and unpickle.
+    if rank == root:
+        # Root already has `serialized`
+        if len(serialized) != total_size:
+            raise RuntimeError(
+                f"Data size mismatch at root: expected {total_size}, got {len(serialized)}"
+            )
+        try:
+            return pickle.loads(serialized)  # nosec B301
+        except Exception as e:
+            raise RuntimeError(f"Deserialization failed: {str(e)}") from e
+    else:
+        if len(dst_buf) != total_size:
+            raise RuntimeError(
+                f"Data size mismatch at rank {rank}: expected {total_size}, got {len(dst_buf)}"
+            )
+        try:
+            return pickle.loads(dst_buf)  # nosec B301
+        except Exception as e:
+            raise RuntimeError(f"Deserialization failed: {str(e)}") from e
+
+
+def safe_gather(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    """
+    Safely gathers potentially large objects by splitting them into fixed-size chunks,
+    using raw-byte MPI.Gatherv. This variant uses Allgather on lengths so every
+    rank can compute sizes/displacements/total locally, removing extra broadcasts.
+
+    Args:
+        comm: MPI communicator to gather over
+        obj: Python object to gather
+        root: Rank that receives the gathered objects
+        chunk_size: Per-round max bytes each rank contributes (default: 4MB)
+
+    Returns:
+        On root: list of deserialized objects (len == comm.size)
+        On non-root: None
+    """
+    if not ENABLE_MULTI_DEVICE:
+        return [obj]
+    if ENABLE_MULTI_DEVICE and MPI is None:
+        raise RuntimeError(
+            "mpi4py is required when ENABLE_MULTI_DEVICE is True")
+    if chunk_size <= 0:
+        raise ValueError("chunk_size must be > 0")
+
+    rank = comm.Get_rank()
+    size = comm.Get_size()
+
+    # -- Serialize locally --
+    try:
+        payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+        my_n = np.int64(len(payload))
+    except Exception as e:
+        # Keep collectives aligned: every rank must call Allgather exactly once
+        _ = comm.allgather(int(-1))
+        raise RuntimeError(f"Rank {rank} serialization failed: {e}") from e
+
+    # -- Allgather lengths so all ranks know sizes and can compute displacements --
+    # We allgather just the int64 length to minimize traffic.
+    lengths = np.array(comm.allgather(int(my_n)),
+                       dtype=np.int64)  # shape (size,)
+    if (lengths < 0).any():
+        raise RuntimeError("Serialization failed on at least one rank")
+    # Every rank computes displacements & total locally and identically:
+    displs = np.zeros(size, dtype=np.int64)
+    if size > 1:
+        displs[1:] = np.cumsum(lengths[:-1])
+    total = int(lengths.sum())
+
+    # -- Prepare buffers --
+    sendbuf_full = np.frombuffer(payload, dtype=np.uint8, count=len(payload))
+    if rank == root:
+        recvbuf = np.empty(total,
+                           dtype=np.uint8)  # single contiguous receive buffer
+    else:
+        recvbuf = None
+
+    # -- Chunked Gatherv loop --
+    # IMPORTANT: All ranks must execute the same number of Gatherv rounds.
+    # Using a deterministic schedule based only on (lengths, chunk_size):
+    # num_rounds = ceil(max(lengths)/chunk_size)
+    max_len = int(lengths.max()) if size > 0 else 0
+    num_rounds = (max_len + chunk_size - 1) // chunk_size if max_len > 0 else 0
+
+    for r in range(num_rounds):
+        # Each rank contributes up to chunk_size bytes from its remaining payload
+        # this round. Round-local offset is r * chunk_size.
+        round_offs = r * chunk_size
+        # Per-rank count this round:
+        # count = max(0, min(chunk_size, length - round_offs))
+        remaining = lengths - round_offs
+        remaining = np.maximum(remaining, 0)
+        counts64 = np.minimum(remaining, chunk_size).astype(np.int64)
+
+        # Target displacements this round are base displs + round_offs (where count>0)
+        round_displs64 = displs + np.minimum(np.maximum(lengths, 0), round_offs)
+
+        # Many MPI impls expect 32-bit ints for counts/displs in Gatherv
+        counts32 = counts64.astype(np.int32)
+        displs32 = round_displs64.astype(np.int32)
+
+        # Local slice to send this round (may be zero-length)
+        send_start = min(round_offs, int(my_n))
+        send_len = int(counts32[rank])
+        send_part = sendbuf_full[send_start:send_start + send_len]
+
+        if rank == root:
+            comm.Gatherv([send_part, MPI.BYTE],
+                         [recvbuf, counts32, displs32, MPI.BYTE],
+                         root=root)
+        else:
+            comm.Gatherv([send_part, MPI.BYTE], None, root=root)
+
+    # Note: ranks with zero data (my_n == 0) still participate in every Gatherv
+    # round with count=0. This is required to keep the collectives matched.
+
+    # -- Reconstruct on root --
+    if rank == root:
+        out = []
+        for i in range(size):
+            sz = int(lengths[i])
+            if sz == 0:
+                # This rank contributed no payload; represent it as None.
+                out.append(None)
+                continue
+            start = int(displs[i])
+            blob = recvbuf[start:start + sz].tobytes()
+            try:
+                out.append(pickle.loads(blob))  # nosec B301
+            except Exception as e:
+                raise RuntimeError(
+                    f"Deserialization failed for rank {i}: {e}") from e
+        return out
+
+    return None
+
+
 class MPIDist(Distributed):
 
     def __init__(self, mapping: Mapping):
         super().__init__(mapping)
         self.create_tp_comm()
 
-    def broadcast(self, obj, root=0):
-        return mpi_broadcast(obj, root)
+    def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+        comm = mpi_comm()
+        return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
     def allgather(self, obj):
         return mpi_allgather(obj)
@@ -138,11 +364,13 @@ def create_tp_comm(self):
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
 
-    def tp_gather(self, obj):
-        return self.tp_comm.gather(obj)
+    def tp_gather(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+        comm = self.tp_comm
+        return safe_gather(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_broadcast(self, obj, root=0):
-        return self.tp_comm.bcast(obj, root)
+    def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+        comm = self.tp_comm
+        return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
 
 class TorchDist(Distributed):
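Usage sketch (not part of the patch): a minimal driver, assuming a multi-device build (ENABLE_MULTI_DEVICE) with mpi4py available, launched e.g. as `mpirun -n 2 python demo_safe_collectives.py`. The file name and payload sizes are illustrative; in-tree callers reach the helpers through `MPIDist.broadcast`, `MPIDist.tp_gather`, and `MPIDist.tp_broadcast` as changed above.

    # demo_safe_collectives.py -- illustrative only, not part of this change
    from tensorrt_llm._torch.distributed.communicator import (safe_broadcast,
                                                               safe_gather)
    from tensorrt_llm._utils import mpi_comm

    comm = mpi_comm()
    rank = comm.Get_rank()

    # Broadcast a payload larger than one 4 MB chunk so several Bcast rounds run.
    obj = {"weights": bytes(10 * 1024 * 1024)} if rank == 0 else None
    obj = safe_broadcast(comm, obj, root=0)
    assert len(obj["weights"]) == 10 * 1024 * 1024

    # Gather one object per rank back to rank 0 in fixed-size Gatherv rounds.
    gathered = safe_gather(comm, {"rank": rank, "pad": bytes(rank * 2**20)},
                           root=0,
                           chunk_size=1024 * 1024)
    if rank == 0:
        assert [g["rank"] for g in gathered] == list(range(comm.Get_size()))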