# MNNVL All Reduce for large number of tokens (#2074)
## 📌 Description

This PR does two things:

* Adds a check for the number of tokens and raises an exception if the maximum token count is exceeded.
* Adds an optional parameter that lets users dial in an arbitrary workspace size.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Added an optional configurable workspace buffer size for all-reduce operations, with a sensible default to preserve backwards compatibility.
  * Runtime input validation now enforces 2D inputs and token-count limits, with clearer error messages guiding corrective actions.
* **Tests**
  * Expanded test coverage for workspace behavior: default sizing, explicit sizing, and negative tests for insufficient workspace.
  * Tests can now supply an explicit workspace size to validate allocation and reuse scenarios.
Commit `ba8f3ed` (1 parent: `37434ed`)

**File tree:** 2 files changed, +128 −24 lines
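
Before the diffs, a minimal sketch of how a caller might use the new parameter, following the sizing rule from the updated docstring (`3 * 2 * dtype.itemsize * hidden_dim * max_tokens`). This assumes the module import path used by the tests and an already-constructed tensor-parallel `Mapping`; it is an illustration, not part of the commit:

```python
import torch
from flashinfer.comm import trtllm_mnnvl_ar  # import path as used in the tests

hidden_dim = 4096
max_tokens = 256  # the largest number of tokens we ever expect to reduce
dtype = torch.bfloat16

# Workspace holds a [3, 2, tokens, hidden_dim] staging buffer, hence the 3 * 2 factor.
workspace_bytes = 3 * 2 * dtype.itemsize * hidden_dim * max_tokens

# `mapping` is the caller's tensor-parallel Mapping; construction omitted here.
mcast_buffer, buffer_flags, max_num_elements = (
    trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
        mapping, dtype, buffer_size_in_bytes=workspace_bytes
    )
)
```

Omitting `buffer_size_in_bytes` (or passing `None`) keeps the previous 12 MB default, so existing callers are unaffected.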

### `flashinfer/comm/trtllm_mnnvl_ar.py` (16 additions, 2 deletions)
```diff
@@ -122,7 +122,7 @@ def trtllm_mnnvl_rmsnorm(
 
 
 def get_allreduce_mnnvl_workspace(
-    mapping: Mapping, dtype: torch.dtype
+    mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None
 ) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
     """Get workspace buffers needed for multi-node NVLink all-reduce operation.
 
@@ -138,6 +138,7 @@ def get_allreduce_mnnvl_workspace(
     Args:
         mapping: Tensor parallel mapping configuration containing rank info
         dtype: Data type of the tensors being reduced
+        buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens
 
     Returns:
         Tuple containing:
@@ -152,7 +153,9 @@ def get_allreduce_mnnvl_workspace(
     # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720
     # max_num_elements must be a multiple of 286720
     lcm_hidden_dim = 286720
-    TARGET_WORKSPACE_SIZE_BYTES = 12_000_000
+    TARGET_WORKSPACE_SIZE_BYTES = (
+        buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000
+    )
     buffer_size_in_bytes = math.ceil(
         TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride)
     ) * (lcm_hidden_dim * stride)
@@ -223,6 +226,17 @@ def trtllm_mnnvl_all_reduce(
         [Optional] out: Output tensor to store the result (required if wait_for_results is True)
 
     """
+
+    if len(inp.shape) != 2:
+        raise ValueError(
+            f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}."
+        )
+
+    if inp.shape[0] > buffer_M:
+        raise ValueError(
+            f"The number of tokens in the input tensor {inp.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}."
+        )
+
     module = get_trtllm_mnnvl_comm_module()
     module.trtllm_mnnvl_all_reduce(
         inp,
```
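One subtlety in the hunk above: the requested size is not used verbatim; it is rounded up to a multiple of `lcm_hidden_dim * stride` so the buffer divides evenly for every supported hidden size. A quick arithmetic sketch of that rounding, assuming `stride` is the element size in bytes (2 for fp16/bf16), since `stride` is defined outside the excerpt:

```python
import math

lcm_hidden_dim = 286720  # LCM of the supported hidden sizes (2048, 4096, 5120, 7168, 8192)
stride = 2               # assumed: bytes per element for fp16/bf16
requested = 12_000_000   # the default TARGET_WORKSPACE_SIZE_BYTES

granule = lcm_hidden_dim * stride                   # 573,440 bytes
rounded = math.ceil(requested / granule) * granule  # 21 * 573,440
print(rounded)  # 12042240 -> the default request actually allocates ~12.04 MB
```

The same rounding applies to a user-supplied `buffer_size_in_bytes`, so the usable token capacity (`buffer_M`) can come out slightly larger than what the caller asked for.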

### `tests/comm/test_trtllm_mnnvl_allreduce.py` (112 additions, 22 deletions)
```diff
@@ -147,25 +147,27 @@ def func(
     )
 
 
-"""Main test function that runs on each MPI rank"""
+"""Helper function to run the core MNNVL AllReduce test logic"""
 
 
-@pytest.mark.parametrize(
-    "seq_lens",
-    [
-        [1],
-        [4],
-        [15],
-        [27, 11, 24],
-        [127],
-    ],
-)  # Test with different sequence length lists
-@pytest.mark.parametrize("fusion", [False, True])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192])
-def test_mnnvl_allreduce_full(
-    monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int
+def run_mnnvl_ar_full(
+    monkeypatch,
+    seq_lens: list[int],
+    fusion: bool,
+    dtype: torch.dtype,
+    hidden_size: int,
+    explicit_workspace_bytes: int | None = None,
 ):
+    """Core test logic for MNNVL AllReduce operations.
+
+    Args:
+        monkeypatch: pytest monkeypatch fixture
+        seq_lens: List of sequence lengths to test
+        fusion: Whether to test fused allreduce+rmsnorm or just allreduce
+        dtype: Data type for tensors
+        hidden_size: Hidden dimension size
+        explicit_workspace_bytes: If provided, use this workspace size instead of default
+    """
     monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1")  # force multi-node allreduce.
 
     # Get MPI info
@@ -211,7 +213,9 @@ def test_mnnvl_allreduce_full(
     # This workspace is sized for the maximum expected sequence length and can be reused within each list
     # Each parameterized list gets its own fresh workspace allocation
     mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = (
-        trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype)
+        trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
+            mapping, dtype, buffer_size_in_bytes=explicit_workspace_bytes
+        )
     )
 
     multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr()
@@ -291,18 +295,21 @@ def test_mnnvl_allreduce_full(
             rank_failed = True
             failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}"
             print(failure_message)
-            # Gather failure status from all ranks
+
+            # Gather failure status from all ranks for logging
             all_failures = MPI.COMM_WORLD.allgather(rank_failed)
 
-            # If any rank failed, fail the test
             if any(all_failures):
                 failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
                 if rank == 0:
                     print(f"Test failed on ranks: {failed_ranks}")
 
-                # Fail the test on all ranks
-                pytest.fail(f"Test failed on ranks {failed_ranks}")
-        trtllm_mnnvl_ar.mpi_barrier()
+            # Cleanup before re-raising
+            if "mcast_buffer_mnnvl" in locals():
+                del mcast_buffer_mnnvl
+
+            # Re-raise the original exception so it can be caught by pytest.raises in negative tests
+            raise
 
         finally:
             # Ensure cleanup happens for this list's workspace
@@ -311,3 +318,86 @@ def test_mnnvl_allreduce_full(
 
     # Final synchronization and check for failures across all ranks
     trtllm_mnnvl_ar.mpi_barrier()
+
+
+"""Test with default workspace size"""
+
+
+@pytest.mark.parametrize(
+    "seq_lens",
+    [
+        [1],
+        [4],
+        [15],
+        [27, 11, 24],
+        [127],
+    ],
+)
+@pytest.mark.parametrize("fusion", [False, True])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192])
+def test_mnnvl_allreduce_default_workspace(
+    monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int
+):
+    """Test MNNVL AllReduce with default workspace size."""
+    run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size)
+
+
+"""Test with explicit workspace size"""
+
+
+@pytest.mark.parametrize(
+    "seq_lens",
+    [
+        [1, 4, 180],
+    ],
+)
+@pytest.mark.parametrize("fusion", [False, True])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192])
+def test_mnnvl_allreduce_explicit_workspace(
+    monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int
+):
+    """Test MNNVL AllReduce with explicitly calculated workspace size."""
+    # Calculate workspace to fit the maximum sequence length
+    # buffer shape: [3, 2, buffer_tokens, hidden_dim]
+    explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens)
+    run_mnnvl_ar_full(
+        monkeypatch,
+        seq_lens,
+        fusion,
+        dtype,
+        hidden_size,
+        explicit_workspace_bytes=explicit_workspace_bytes,
+    )
+
+
+"""Negative test: workspace too small"""
+
+
+@pytest.mark.parametrize("fusion", [False, True])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [2048, 4096])
+def test_mnnvl_allreduce_workspace_too_small(
+    monkeypatch, fusion: bool, dtype: torch.dtype, hidden_size: int
+):
+    """Test that MNNVL AllReduce fails gracefully when workspace is too small."""
+    # Use a large sequence length that won't fit in a small workspace
+    seq_len = 180
+
+    # Create a workspace that's too small (only enough for 10 tokens)
+    small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10
+
+    # Expect a ValueError with a message about buffer_M being too small
+    with pytest.raises((ValueError, RuntimeError)) as exc_info:
+        run_mnnvl_ar_full(
+            monkeypatch,
+            [seq_len],
+            fusion,
+            dtype,
+            hidden_size,
+            explicit_workspace_bytes=small_workspace_bytes,
+        )
+
+    # Verify the error message contains the expected text
+    assert "greater than the buffer_M" in str(exc_info.value)
```
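
Both new workspace tests lean on the same sizing formula as the docstring: `3 * 2 * dtype.itemsize * hidden_size * tokens`. A worked instance for one parameter combination, just evaluating that formula (not additional test output):

```python
import torch

dtype = torch.bfloat16  # itemsize == 2 bytes
hidden_size = 4096
max_tokens = 180        # largest entry in seq_lens = [1, 4, 180]

# Buffer shape is [3, 2, tokens, hidden_dim], hence the 3 * 2 factor.
explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max_tokens
print(explicit_workspace_bytes)  # 8847360 bytes (~8.4 MiB)

# The negative test sizes the workspace for only 10 tokens, so a 180-token
# input trips the new buffer_M check in trtllm_mnnvl_all_reduce.
small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10
print(small_workspace_bytes)     # 491520 bytes
```

Note that `pytest.raises` can catch the failure only because the helper now re-raises the original exception instead of calling `pytest.fail`, which is exactly what the error-handling change in the hunk above enables.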
