
Commit a2670e8

Incorporate 2056; Add test for legacy APIs

1 parent 6344177 · commit a2670e8

3 files changed: +229 −36 lines

flashinfer/comm/mnnvl.py
Lines changed: 8 additions & 1 deletion

@@ -1133,6 +1133,7 @@ def __init__(
         group_rank: int,
         device: torch.device,
         mn_nvlink: bool = True,
+        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
     ):
         """
         Constructor for McastGpuBuffer.
@@ -1143,9 +1144,15 @@ def __init__(
             group_rank: The rank of the local process within the group
             device: The CUDA device for buffer allocation
             mn_nvlink: Flag indicating if multi-node NVLink is used
+            comm_backend_for_handle_transfer: The communicator to use for handle transfer
         """
         self.mcast_device_memory = McastDeviceMemory(
-            buf_size, group_size, group_rank, device.index, mn_nvlink
+            buf_size,
+            group_size,
+            group_rank,
+            device.index,
+            mn_nvlink,
+            comm_backend_for_handle_transfer,
         )
         self.buf_size = buf_size
         self.local_device = device
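
The new optional argument threads a handle-transfer communicator through to McastDeviceMemory. A minimal usage sketch, assuming the import path shown in trtllm_mnnvl_ar.py below; the size and rank values are illustrative, not from this commit:

# Sketch only: argument values are illustrative; omitting the last argument
# (or passing None) keeps the prior behavior per the Optional[...] default.
import torch
from flashinfer.comm.mnnvl import McastGPUBuffer, MPIBackend

buf = McastGPUBuffer(
    8 * 1024 * 1024,          # buf_size in bytes
    8,                        # group_size: number of ranks in the group
    0,                        # group_rank: rank of this process
    torch.device("cuda", 0),  # device for buffer allocation
    True,                     # mn_nvlink: multi-node NVLink in use
    MPIBackend(),             # comm_backend_for_handle_transfer (new)
)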

flashinfer/comm/trtllm_mnnvl_ar.py
Lines changed: 19 additions & 9 deletions

@@ -17,7 +17,7 @@

 from ..jit import gen_trtllm_mnnvl_comm_module
 from ..utils import register_custom_op
-from .mnnvl import McastGPUBuffer
+from .mnnvl import McastGPUBuffer, CommBackend, MPIBackend


 def mpi_barrier():
@@ -41,14 +41,18 @@ def is_one_shot(


 # Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size
-# TODO(Refactor): Consider moving this to a configuration class or file
 MNNVL_ONE_SHOT_THRESHOLD = 64 * 1024 * 8 * 2


 class MNNVLAllreduceFusionWorkspace:
     NUM_LAMPORT_BUFFERS = 3

-    def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None):
+    def __init__(
+        self,
+        mapping: Mapping,
+        buffer_size_in_bytes: Optional[int] = None,
+        comm_backend: Optional[CommBackend] = None,
+    ):
         """
         Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD.

@@ -64,7 +68,8 @@ def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None)
         buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (
             8 * (1024**2)
         )
-
+        if comm_backend is None:
+            comm_backend = MPIBackend()
         if buffer_size_in_bytes > (2**32 - 1):
             raise ValueError(
                 f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)."
@@ -83,14 +88,14 @@ def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None)
             mapping.tp_rank,
             torch.device("cuda", mapping.local_rank),
             mapping.is_multi_node(),
+            comm_backend,
         )

         # We use FP32 for sentinel value regardless of the real dtype
         self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32)
         # Wait until the initialization is done
         torch.cuda.synchronize()
-        # FIXME: We are assuming using the COMM_WORLD.
-        mpi_barrier()
+        comm_backend.barrier()

         # This is a buffer to maintain the state of this allreduce Op
         # Should have the same lifetime with self._buffer
@@ -391,7 +396,10 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm(
     "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllreduceFusionWorkspace class to manage the workspace instead"
 )
 def get_allreduce_mnnvl_workspace(
-    mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None
+    mapping: Mapping,
+    dtype: torch.dtype,
+    comm_backend_for_handle_transfer: Optional[CommBackend] = None,
+    buffer_size_in_bytes: Optional[int] = None,
 ) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
     """Get workspace buffers needed for multi-node NVLink all-reduce operation.

@@ -428,7 +436,9 @@ def get_allreduce_mnnvl_workspace(
     ) * (lcm_hidden_dim * stride)

     # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout.
-    workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes)
+    workspace = MNNVLAllreduceFusionWorkspace(
+        mapping, buffer_size_in_bytes, comm_backend_for_handle_transfer
+    )

     mcast_buffer = workspace.mcast_buffer_handle
     buffer_flags = workspace.buffer_flags
@@ -497,7 +507,7 @@ def trtllm_mnnvl_all_reduce(
     )
     module = get_trtllm_mnnvl_comm_module()
     module.trtllm_mnnvl_allreduce_fusion(
-        input,
+        inp,
         multicast_buffer_ptr,
         buffer_ptrs_dev,
         0,  # Allreduce kernel itself does not use this local pointer; still this could be risky but it is only used for legacy code compatibility.
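
The net effect of this file's changes: both the workspace class and the deprecated wrapper now accept an injectable communicator, and the post-init barrier goes through it instead of hard-coding MPI COMM_WORLD. A minimal sketch of the new constructor, assuming `mapping` is already built as elsewhere in this module:

# Sketch only: `mapping` construction is assumed; passing comm_backend=None
# falls back to MPIBackend() per the diff above.
from flashinfer.comm.mnnvl import MPIBackend
from flashinfer.comm.trtllm_mnnvl_ar import MNNVLAllreduceFusionWorkspace

workspace = MNNVLAllreduceFusionWorkspace(
    mapping,                    # tensor-parallel topology (assumed built)
    buffer_size_in_bytes=None,  # None: module-chosen default (not shown here)
    comm_backend=MPIBackend(),  # used for handle transfer and barrier()
)

Note that in the deprecated get_allreduce_mnnvl_workspace, the new comm_backend_for_handle_transfer parameter is inserted before buffer_size_in_bytes, so callers that passed the buffer size positionally must switch to the keyword form, as the updated test below does.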

tests/comm/test_trtllm_mnnvl_allreduce.py
Lines changed: 202 additions & 26 deletions

@@ -101,6 +101,131 @@ def func(
     )


+@torch.inference_mode()
+def row_linear_residual_norm_fusion_forward_legacy(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    norm_weight: torch.Tensor,
+    eps: float,
+    hidden_size: int,
+    dtype: torch.dtype,
+    mapping: Mapping,
+    fusion: bool,
+    reference_output: tuple[torch.Tensor, ...],
+    multicast_ptr: int,
+    buffer_ptrs_dev: int,
+    unicast_ptr: int,
+    max_num_elements_mnnvl: int,
+    buffer_flags_mnnvl: torch.Tensor,
+):
+    tensor_parallel_size = mapping.tp_size
+    tensor_parallel_rank = mapping.tp_rank
+    MPI.COMM_WORLD.barrier()
+
+    def func(
+        input,
+        residual,
+        norm_weight,
+        eps,
+        enable_fusion,
+        multicast_ptr,
+        buffer_ptrs_dev,
+        unicast_ptr,
+        max_num_elements_mnnvl,
+    ):
+        # For both fused and unfused cases:
+        shape = input.shape
+        input = input.view(-1, shape[-1])
+        buffer_M = max_num_elements_mnnvl // hidden_size
+
+        if enable_fusion:
+            use_pdl = True
+
+            prenorm_output = torch.empty_like(residual)
+            normed_output = torch.empty_like(residual)
+
+            trtllm_mnnvl_ar.mpi_barrier()
+
+            trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm(
+                prenorm_output,
+                normed_output,
+                input,
+                multicast_ptr,
+                buffer_ptrs_dev,
+                unicast_ptr,
+                buffer_M,
+                buffer_flags_mnnvl,
+                tensor_parallel_size,
+                tensor_parallel_rank,
+                norm_weight,
+                eps,
+                residual,
+                use_pdl,
+            )
+
+            return normed_output.view(shape), prenorm_output.view(shape)
+
+        else:
+            output = torch.empty_like(input)
+
+            trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce(
+                input,
+                multicast_ptr,
+                buffer_ptrs_dev,
+                buffer_M,
+                buffer_flags_mnnvl,
+                tensor_parallel_size,
+                tensor_parallel_rank,
+                True,  # wait_for_results
+                False,  # launch_with_pdl
+                output,  # Need to provide output tensor since we are writing them out.
+            )
+            return (output.view(shape),)
+
+    output = func(
+        x.clone(),
+        residual.clone(),
+        norm_weight,
+        eps,
+        fusion,
+        multicast_ptr,
+        buffer_ptrs_dev,
+        unicast_ptr,
+        max_num_elements_mnnvl,
+    )
+
+    assert output[0].shape == reference_output[0].shape
+
+    if tensor_parallel_rank == 0:
+        print("output[0] (first 10 values):", output[0].flatten()[:10])
+        print(
+            "reference_output[0] (first 10 values):",
+            reference_output[0].flatten()[:10],
+        )
+
+        if fusion:
+            print("output[1] (first 10 values):", output[1].flatten()[:10])
+            print(
+                "reference_output[1] (first 10 values):",
+                reference_output[1].flatten()[:10],
+            )
+
+    torch.testing.assert_close(
+        output[0],
+        reference_output[0],
+        rtol=0.05,
+        atol=0.15,
+    )
+
+    if fusion:
+        torch.testing.assert_close(
+            output[1],
+            reference_output[1],
+            rtol=0.05,
+            atol=0.15,
+        )
+
+
 """Helper function to run the core MNNVL AllReduce test logic"""


@@ -146,7 +271,13 @@ def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion


 def run_mnnvl_ar_full(
-    monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int
+    monkeypatch,
+    seq_lens: list[int],
+    fusion: bool,
+    dtype: torch.dtype,
+    hidden_size: int,
+    legacy_explicit_workspace_bytes: int = None,
+    legacy_api: bool = False,
 ):
     """Core test logic for MNNVL AllReduce operations.

@@ -195,16 +326,30 @@ def run_mnnvl_ar_full(
     failure_message = ""

     try:
-        required_workspace_bytes = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes(
-            mapping.tp_size,
-            max(seq_lens),
-            hidden_size,
-            dtype,
-            trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO,
-        )
-        workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace(
-            mapping, required_workspace_bytes
-        )
+        if legacy_api:
+            mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = (
+                trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
+                    mapping, dtype, buffer_size_in_bytes=legacy_explicit_workspace_bytes
+                )
+            )
+
+            multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr()
+            buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev()
+            unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr(
+                mapping.tp_rank
+            )
+
+        else:
+            required_workspace_bytes = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes(
+                mapping.tp_size,
+                max(seq_lens),
+                hidden_size,
+                dtype,
+                trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO,
+            )
+            workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace(
+                mapping, required_workspace_bytes
+            )

        test_data = []
        for seq_len in seq_lens:
@@ -221,18 +366,34 @@ def run_mnnvl_ar_full(
            print(
                f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}"
            )
-
-            # Run the test with the same workspace
-            row_linear_residual_norm_fusion_forward(
-                x,
-                residual,
-                norm_weight,
-                eps,
-                mapping,
-                fusion,
-                reference_output,
-                workspace,
-            )
+            if legacy_api:
+                row_linear_residual_norm_fusion_forward_legacy(
+                    x,
+                    residual,
+                    norm_weight,
+                    eps,
+                    hidden_size,
+                    dtype,
+                    mapping,
+                    fusion,
+                    reference_output,
+                    multicast_ptr,
+                    buffer_ptrs_dev,
+                    unicast_ptr,
+                    max_num_elements_mnnvl,
+                    buffer_flags_mnnvl,
+                )
+            else:
+                row_linear_residual_norm_fusion_forward(
+                    x,
+                    residual,
+                    norm_weight,
+                    eps,
+                    mapping,
+                    fusion,
+                    reference_output,
+                    workspace,
+                )

            # Synchronize before next test
            trtllm_mnnvl_ar.mpi_barrier()
@@ -283,8 +444,23 @@ def run_mnnvl_ar_full(
 @pytest.mark.parametrize("fusion", [False, True])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [2880, 5120, 7168, 8192])
-def test_mnnvl_allreduce_default_workspace(
+def test_mnnvl_allreduce_refactored(
+    monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int
+):
+    """Test MNNVL AllReduce with refactored API."""
+    run_mnnvl_ar_full(
+        monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=False
+    )
+
+
+@pytest.mark.parametrize("seq_lens", [[1], [4], [15], [27, 11, 24], [127]])
+@pytest.mark.parametrize("fusion", [False, True])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192])
+def test_mnnvl_allreduce_legacy(
     monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int
 ):
-    """Test MNNVL AllReduce with default workspace size."""
-    run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size)
+    """Test MNNVL AllReduce with legacy API."""
+    run_mnnvl_ar_full(
+        monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=True
+    )
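
Within this commit, the only method exercised on the injected backend is barrier() (handle transfer happens inside McastDeviceMemory, which is not part of this diff). As a purely hypothetical illustration of a non-MPI backend, assuming CommBackend can be subclassed and that barrier() is the relevant hook:

# Hypothetical sketch: the real CommBackend interface may require more
# methods than barrier(); only barrier() is visible in this diff.
import torch.distributed as dist

from flashinfer.comm.mnnvl import CommBackend


class TorchDistBackend(CommBackend):
    """Synchronize via an existing torch.distributed process group."""

    def __init__(self, group=None):
        self.group = group  # None means the default (WORLD) process group

    def barrier(self) -> None:
        dist.barrier(group=self.group)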
