Commit f23a211

[None][fix] Fix possible mpi broadcast and gather issue on large object (#7507)

Signed-off-by: Dongxu Yang <[email protected]>

1 parent 236f71e commit f23a211
File tree: 1 file changed, +240 -10 lines changed
tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 240 additions & 10 deletions
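The change below replaces object-level broadcast/gather with chunked, raw-byte MPI collectives: the payload is pickled once and then moved in fixed-size chunks (4 MB by default). As a rough standalone illustration of what that chunking means for a large payload (not part of the commit; the object contents are made up):

import math
import pickle

CHUNK_SIZE = 4 * 1024 * 1024  # default chunk_size used by the new helpers

# Hypothetical large object, e.g. a big request/response structure
obj = {"blob": bytes(100 * 1024 * 1024)}

payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
num_chunks = math.ceil(len(payload) / CHUNK_SIZE)
print(f"{len(payload)} bytes -> {num_chunks} rounds of <= {CHUNK_SIZE} bytes")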
@@ -1,15 +1,22 @@
+import math
 import os
+import pickle  # nosec B403
 from abc import ABC, abstractmethod
 from typing import Optional
 
 import numpy as np
 import torch
 import torch.distributed as dist
 
-from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_broadcast,
-                                 mpi_comm, mpi_isend, mpi_isend_object,
-                                 mpi_recv, mpi_recv_object, mpi_send,
-                                 mpi_send_object)
+try:
+    from mpi4py import MPI
+except Exception:
+    MPI = None  # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
+
+from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
+                                 mpi_isend, mpi_isend_object, mpi_recv,
+                                 mpi_recv_object, mpi_send, mpi_send_object)
+from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
 from tensorrt_llm.mapping import Mapping
 

@@ -95,15 +102,236 @@ def allgather(self, obj, root=0):
         pass
 
 
+def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    """
+    Safely broadcasts potentially large objects by splitting them into fixed-size chunks,
+    using raw-byte MPI.Bcast to avoid pickle5's out-of-band buffer allocations.
+
+    Args:
+        comm: Communicator to broadcast over
+        obj: Python object to broadcast
+        root: Rank of the broadcasting process
+        chunk_size: Maximum size of each chunk in bytes (default: 4 MB)
+
+    Returns:
+        The broadcast object on all ranks
+    """
+    if not ENABLE_MULTI_DEVICE:
+        return obj
+    if ENABLE_MULTI_DEVICE and MPI is None:
+        raise RuntimeError(
+            "mpi4py is required when ENABLE_MULTI_DEVICE is True")
+    if chunk_size <= 0:
+        raise ValueError("chunk_size must be > 0")
+    rank = comm.Get_rank()
+
+    # ---- Serialization phase (root only) ----
+    # Header layout: [ok_flag, total_size, num_chunks] as int64
+    header = np.zeros(3, dtype=np.int64)
+    if rank == root:
+        try:
+            serialized = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+            total_size = len(serialized)
+            num_chunks = math.ceil(total_size /
+                                   chunk_size) if total_size > 0 else 0
+            header[:] = (1, total_size, num_chunks)
+        except Exception as e:
+            # Signal failure to all ranks, then raise
+            header[:] = (0, 0, 0)
+            comm.Bcast([header, MPI.INT64_T], root=root)
+            raise RuntimeError(f"Serialization failed: {str(e)}") from e
+    else:
+        serialized = None  # not used on non-root before Bcast
+
+    # ---- Metadata broadcast (Bcast the fixed-size header) ----
+    comm.Bcast([header, MPI.INT64_T], root=root)
+    ok_flag, total_size, num_chunks = int(header[0]), int(header[1]), int(
+        header[2])
+    if not ok_flag:
+        raise RuntimeError("Root rank failed during serialization")
+
+    # ---- Allocate receive buffer (non-root) or build a view (root) ----
+    # We broadcast raw bytes chunk by chunk.
+    if rank == root:
+        src_view = memoryview(serialized)
+        dst_buf = None
+        dst_view = None
+    else:
+        # Pre-allocate a contiguous byte buffer to receive the payload
+        dst_buf = bytearray(total_size)
+        dst_view = memoryview(dst_buf)
+        src_view = None  # not used on non-root
+
+    # ---- Chunked raw-byte broadcast with MPI.Bcast ----
+    # Each round sends exactly `cur` bytes of the global payload.
+    offset = 0
+    for i in range(num_chunks):
+        cur = min(chunk_size, total_size - offset)
+        if cur <= 0:
+            break  # safety guard for zero-size payloads
+
+        if rank == root:
+            # Root sends a slice of the source view
+            part = src_view[offset:offset + cur]
+            comm.Bcast([part, MPI.BYTE], root=root)
+        else:
+            # Non-root receives directly into the destination view
+            part = dst_view[offset:offset + cur]
+            comm.Bcast([part, MPI.BYTE], root=root)
+
+        offset += cur
+
+    # ---- Reconstruction and deserialization ----
+    # Validate the received byte count and unpickle.
+    if rank == root:
+        # Root already has `serialized`
+        if len(serialized) != total_size:
+            raise RuntimeError(
+                f"Data size mismatch at root: expected {total_size}, got {len(serialized)}"
+            )
+        try:
+            return pickle.loads(serialized)  # nosec B301
+        except Exception as e:
+            raise RuntimeError(f"Deserialization failed: {str(e)}") from e
+    else:
+        if len(dst_buf) != total_size:
+            raise RuntimeError(
+                f"Data size mismatch at rank {rank}: expected {total_size}, got {len(dst_buf)}"
+            )
+        try:
+            return pickle.loads(dst_buf)  # nosec B301
+        except Exception as e:
+            raise RuntimeError(f"Deserialization failed: {str(e)}") from e
+
+
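For reference, a minimal usage sketch of the safe_broadcast helper added above (not part of the commit). It assumes mpi4py is installed, a build with ENABLE_MULTI_DEVICE, a launch under mpirun, and that the module imports cleanly in your environment; the payload is made up:

# Run with e.g.: mpirun -np 4 python bcast_demo.py
from mpi4py import MPI

from tensorrt_llm._torch.distributed.communicator import safe_broadcast

comm = MPI.COMM_WORLD
if comm.Get_rank() == 0:
    obj = {"weights": bytes(64 * 1024 * 1024)}  # ~64 MB payload on root
else:
    obj = None  # non-root input is ignored by safe_broadcast

# Every rank ends up with the same deserialized object.
result = safe_broadcast(comm, obj, root=0, chunk_size=4 * 1024 * 1024)
print(comm.Get_rank(), len(result["weights"]))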
+def safe_gather(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    """
+    Safely gathers potentially large objects by splitting them into fixed-size chunks,
+    using raw-byte MPI.Gatherv. This variant uses Allgather on the payload lengths so
+    every rank can compute sizes/displacements/total locally, removing extra broadcasts.
+
+    Args:
+        comm: Communicator to gather over
+        obj: Python object to gather
+        root: Rank that receives the gathered objects
+        chunk_size: Per-round maximum bytes each rank contributes (default: 4 MB)
+
+    Returns:
+        On root: list of deserialized objects (len == comm.size)
+        On non-root: None
+    """
+    if not ENABLE_MULTI_DEVICE:
+        return [obj]
+    if ENABLE_MULTI_DEVICE and MPI is None:
+        raise RuntimeError(
+            "mpi4py is required when ENABLE_MULTI_DEVICE is True")
+    if chunk_size <= 0:
+        raise ValueError("chunk_size must be > 0")
+
+    rank = comm.Get_rank()
+    size = comm.Get_size()
+    if chunk_size <= 0:
+        raise ValueError("chunk_size must be > 0")
+
+    # -- Serialize locally --
+    try:
+        payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+        my_n = np.int64(len(payload))
+    except Exception as e:
+        # Keep collectives aligned: every rank must call Allgather exactly once
+        _ = comm.allgather(int(-1))
+        raise RuntimeError(f"Rank {rank} serialization failed: {e}") from e
+
+    # -- Allgather lengths so all ranks know sizes and can compute displacements --
+    # We allgather just the int64 length to minimize traffic.
+    lengths = np.array(comm.allgather(int(my_n)),
+                       dtype=np.int64)  # shape (size,)
+    if (lengths < 0).any():
+        raise RuntimeError("Serialization failed on at least one rank")
+    # Every rank computes displacements & total locally and identically:
+    displs = np.zeros(size, dtype=np.int64)
+    if size > 1:
+        displs[1:] = np.cumsum(lengths[:-1])
+    total = int(lengths.sum())
+
+    # -- Prepare buffers --
+    sendbuf_full = np.frombuffer(payload, dtype=np.uint8, count=len(payload))
+    if rank == root:
+        recvbuf = np.empty(total,
+                           dtype=np.uint8)  # single contiguous receive buffer
+    else:
+        recvbuf = None
+
+    # -- Chunked Gatherv loop --
+    # IMPORTANT: All ranks must execute the same number of Gatherv rounds.
+    # Using a deterministic schedule based only on (lengths, chunk_size):
+    #   num_rounds = ceil(max(lengths) / chunk_size)
+    max_len = int(lengths.max()) if size > 0 else 0
+    num_rounds = (max_len + chunk_size - 1) // chunk_size if max_len > 0 else 0
+
+    for r in range(num_rounds):
+        # Each rank contributes up to chunk_size bytes from its remaining payload
+        # this round. Round-local offset is r * chunk_size.
+        round_offs = r * chunk_size
+        # Per-rank count this round:
+        #   count = max(0, min(chunk_size, length - round_offs))
+        remaining = lengths - round_offs
+        remaining = np.maximum(remaining, 0)
+        counts64 = np.minimum(remaining, chunk_size).astype(np.int64)
+
+        # Target displacements this round are base displs + round_offs (where count > 0)
+        round_displs64 = displs + np.minimum(np.maximum(lengths, 0), round_offs)
+
+        # Many MPI implementations expect 32-bit ints for counts/displs in Gatherv
+        counts32 = counts64.astype(np.int32)
+        displs32 = round_displs64.astype(np.int32)
+
+        # Local slice to send this round (may be zero-length)
+        send_start = min(round_offs, int(my_n))
+        send_len = int(counts32[rank])
+        send_part = sendbuf_full[send_start:send_start + send_len]
+
+        if rank == root:
+            comm.Gatherv([send_part, MPI.BYTE],
+                         [recvbuf, counts32, displs32, MPI.BYTE],
+                         root=root)
+        else:
+            comm.Gatherv([send_part, MPI.BYTE], None, root=root)
+
+    # Note: ranks with zero data (my_n == 0) still participate in every Gatherv
+    # round with count=0. This is required to keep the collectives matched.
+
+    # -- Reconstruct on root --
+    if rank == root:
+        out = []
+        for i in range(size):
+            sz = int(lengths[i])
+            if sz == 0:
+                # A zero-length payload is represented as None on the root side
+                out.append(None)
+                continue
+            start = int(displs[i])
+            blob = recvbuf[start:start + sz].tobytes()
+            try:
+                out.append(pickle.loads(blob))  # nosec B301
+            except Exception as e:
+                raise RuntimeError(
+                    f"Deserialization failed for rank {i}: {e}") from e
+        return out
+
+    return None
+
+
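The deterministic round schedule in safe_gather is what keeps the collectives matched: every rank derives the same per-round counts and displacements from the allgathered lengths, so ranks with small or empty payloads still issue the same number of Gatherv calls, just with count 0. A standalone sketch of that arithmetic (numpy only; the example lengths are made up):

import numpy as np

chunk_size = 4 * 1024 * 1024
# Example payload sizes for 3 ranks: 10 MB, 2 MB, and an empty payload
lengths = np.array([10 * 1024 * 1024, 2 * 1024 * 1024, 0], dtype=np.int64)

displs = np.zeros(len(lengths), dtype=np.int64)
displs[1:] = np.cumsum(lengths[:-1])
num_rounds = (int(lengths.max()) + chunk_size - 1) // chunk_size  # -> 3

for r in range(num_rounds):
    round_offs = r * chunk_size
    counts = np.minimum(np.maximum(lengths - round_offs, 0), chunk_size)
    round_displs = displs + np.minimum(lengths, round_offs)
    print(r, counts.tolist(), round_displs.tolist())
# Round 0 sends [4 MB, 2 MB, 0], round 1 sends [4 MB, 0, 0], round 2 sends [2 MB, 0, 0]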
 class MPIDist(Distributed):
 
     def __init__(self, mapping: Mapping):
         super().__init__(mapping)
         self.create_tp_comm()
         self.create_pp_comm()
 
-    def broadcast(self, obj, root=0):
-        return mpi_broadcast(obj, root)
+    def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+        comm = mpi_comm()
+        return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
     def allgather(self, obj):
         return mpi_allgather(obj)
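Analogously, a minimal usage sketch of safe_gather (not part of the commit; same assumptions as the broadcast sketch above — mpi4py, ENABLE_MULTI_DEVICE, launched under mpirun). The per-rank payloads are made up:

# Run with e.g.: mpirun -np 4 python gather_demo.py
from mpi4py import MPI

from tensorrt_llm._torch.distributed.communicator import safe_gather

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Each rank contributes a differently sized payload
obj = {"rank": rank, "data": bytes(rank * 8 * 1024 * 1024)}

gathered = safe_gather(comm, obj, root=0)
if rank == 0:
    # List of per-rank objects, ordered by rank
    print([len(item["data"]) for item in gathered])
else:
    assert gathered is None  # non-root ranks get None back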
@@ -143,11 +371,13 @@ def create_pp_comm(self):
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
 
-    def tp_gather(self, obj):
-        return self.tp_comm.gather(obj)
+    def tp_gather(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+        comm = self.tp_comm
+        return safe_gather(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_broadcast(self, obj, root=0):
-        return self.tp_comm.bcast(obj, root)
+    def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+        comm = self.tp_comm
+        return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
     def pp_allgather(self, obj):
         return self.pp_comm.allgather(obj)
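Taken together, the wrapper methods now accept an optional chunk_size. A hedged sketch of the call pattern on MPIDist (not part of the commit; the Mapping arguments are placeholders and the surrounding launch/setup is application-specific):

from mpi4py import MPI

from tensorrt_llm._torch.distributed.communicator import MPIDist
from tensorrt_llm.mapping import Mapping

# Placeholder mapping: 4-way tensor parallelism, one rank per process.
rank = MPI.COMM_WORLD.Get_rank()
mapping = Mapping(world_size=4, rank=rank, tp_size=4)
dist = MPIDist(mapping)

big_obj = {"payload": bytes(128 * 1024 * 1024)}

# broadcast/tp_broadcast/tp_gather now chunk the pickled payload (4 MB default).
same_obj = dist.broadcast(big_obj, root=0)
gathered = dist.tp_gather(big_obj, root=0, chunk_size=8 * 1024 * 1024)
replica = dist.tp_broadcast(big_obj, root=0)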
