Commit c6ed147

Address review comments.

1 parent 5be2697 commit c6ed147

File tree

  csrc/trtllm_mnnvl_allreduce.cu
  flashinfer/comm/trtllm_mnnvl_ar.py

2 files changed: +57 -28 lines changed

csrc/trtllm_mnnvl_allreduce.cu

Lines changed: 0 additions & 2 deletions
@@ -26,8 +26,6 @@ using tvm::ffi::Optional;
     } \
   }()
 
-// FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this
-// level
 void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr,
                                    int64_t buffer_ptrs_dev, int64_t buffer_ptr_local,
                                    TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank,
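
Editor's note: the removed FIXME is settled one level up rather than here. The C++ FFI entry point keeps its plain bool one-shot flag, and the new strategy enum lives only in the Python wrapper below. A minimal sketch of that mapping, using names from the Python diff (the helper itself is illustrative, not part of the API):

    from flashinfer.comm.trtllm_mnnvl_ar import MNNVLAllreduceFusionStrategy

    def to_oneshot_flag(strategy: MNNVLAllreduceFusionStrategy) -> bool:
        # AUTO is resolved to ONESHOT/TWOSHOT via select_strategy() before
        # launch, so only a concrete strategy should reach this point.
        return strategy == MNNVLAllreduceFusionStrategy.ONESHOT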

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 57 additions & 26 deletions
@@ -33,11 +33,14 @@ class MNNVLAllreduceFusionStrategy(Enum):
     AUTO = 99
 
     @staticmethod
-    def is_one_shot(
+    def select_strategy(
         tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype
-    ) -> bool:
+    ) -> "MNNVLAllreduceFusionStrategy":
         elem_size = torch.tensor([], dtype=dtype).element_size()
-        return num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD
+        if num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD:
+            return MNNVLAllreduceFusionStrategy.ONESHOT
+        else:
+            return MNNVLAllreduceFusionStrategy.TWOSHOT
 
 
 # Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size
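
Editor's note: select_strategy replaces the old is_one_shot predicate, so AUTO now resolves to a concrete enum member by comparing the total reduced payload (num_tokens * hidden_dim * tp_size * elem_size) against MNNVL_ONE_SHOT_THRESHOLD. A standalone sketch of the heuristic; the threshold value below is an assumed stand-in, not the module's actual constant:

    import torch

    MNNVL_ONE_SHOT_THRESHOLD = 64 * (1024**2)  # assumed value for illustration

    def select_strategy_sketch(tp_size: int, num_tokens: int, hidden_dim: int,
                               dtype: torch.dtype) -> str:
        # Payload is the total bytes gathered across all ranks.
        elem_size = torch.tensor([], dtype=dtype).element_size()
        payload = num_tokens * hidden_dim * tp_size * elem_size
        return "ONESHOT" if payload <= MNNVL_ONE_SHOT_THRESHOLD else "TWOSHOT"

    # 16 tokens x 4096 hidden x tp_size 8 in bf16 is 1 MiB -> one-shot
    assert select_strategy_sketch(8, 16, 4096, torch.bfloat16) == "ONESHOT"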
@@ -54,15 +57,15 @@ def __init__(
         comm_backend: Optional[CommBackend] = None,
     ):
         """
-        Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD.
+        Initialize the MNNVL Allreduce Fusion Workspace. comm_backend will be used for creating the workspace and synchronization. If not provided, MPIBackend will be used, which uses COMM_WORLD for synchronization.
 
         Args:
             mapping: Mapping configuration containing rank info
             buffer_size_in_bytes: The size in bytes for each lamport buffer. The actual allocation size will be NUM_LAMPORT_BUFFERS * buffer_size_in_bytes.
         """
         if buffer_size_in_bytes is None:
-            # Default to 16MB workspace size if not provided
-            buffer_size_in_bytes = 16 * (1024**2)
+            # Default to 512MB workspace size if not provided
+            buffer_size_in_bytes = 512 * (1024**2)
         else:
             # Round up to the nearest multiple of 8MB
             buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (
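
Editor's note: two sizing changes land here: the default workspace grows from 16 MB to 512 MB, and explicit sizes are still rounded up to the nearest 8 MB multiple. The rounding arithmetic, as a self-contained sketch:

    import math

    EIGHT_MB = 8 * (1024**2)

    def round_up_to_8mb(buffer_size_in_bytes: int) -> int:
        # Same ceil-to-multiple arithmetic as in __init__ above.
        return math.ceil(buffer_size_in_bytes / EIGHT_MB) * EIGHT_MB

    assert round_up_to_8mb(1) == EIGHT_MB                   # tiny request -> 8 MB
    assert round_up_to_8mb(EIGHT_MB) == EIGHT_MB            # exact multiple kept
    assert round_up_to_8mb(9 * (1024**2)) == 2 * EIGHT_MB   # 9 MB -> 16 MB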
@@ -112,7 +115,28 @@ def __init__(
         self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank)
         self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr()
 
+    @functools.cache
+    def is_buffer_size_sufficient(
+        self,
+        tp_size: int,
+        num_tokens: int,
+        hidden_dim: int,
+        dtype: torch.dtype,
+        strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
+    ) -> bool:
+        """
+        Check whether the workspace buffer is large enough for a given problem size.
+        """
+        required_buffer_size = self.get_required_buffer_size_bytes(
+            tp_size, num_tokens, hidden_dim, dtype, strategy
+        )
+        if required_buffer_size > self.buffer_size_bytes:
+            return False
+        else:
+            return True
+
     @staticmethod
+    @functools.cache
     def get_required_buffer_size_bytes(
         tp_size: int,
         num_tokens: int,
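
Editor's note: is_buffer_size_sufficient is a thin comparison over get_required_buffer_size_bytes, and both are now memoized with functools.cache; on an instance method the cache is keyed on (self, args) and holds a strong reference to the workspace for the life of the process. A hypothetical pre-flight helper showing how a caller might combine the two new methods (the helper name and flow are illustrative, not part of the API):

    import torch
    from flashinfer.comm.trtllm_mnnvl_ar import MNNVLAllreduceFusionStrategy

    def preflight_check(workspace, num_tokens: int, hidden_dim: int,
                        dtype: torch.dtype) -> MNNVLAllreduceFusionStrategy:
        # Resolve AUTO to a concrete strategy, then verify the workspace fits.
        strategy = MNNVLAllreduceFusionStrategy.select_strategy(
            workspace.tp_size, num_tokens, hidden_dim, dtype
        )
        if not workspace.is_buffer_size_sufficient(
            workspace.tp_size, num_tokens, hidden_dim, dtype, strategy
        ):
            raise ValueError("workspace buffer too small for this problem size")
        return strategy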
@@ -124,12 +148,12 @@ def get_required_buffer_size_bytes(
         Calculate the required buffer size for a given problem size.
         """
         elem_size = torch.tensor([], dtype=dtype).element_size()
-        is_one_shot = MNNVLAllreduceFusionStrategy.is_one_shot(
-            tp_size, num_tokens, hidden_dim, dtype
-        )
-        if strategy == MNNVLAllreduceFusionStrategy.ONESHOT or (
-            strategy == MNNVLAllreduceFusionStrategy.AUTO and is_one_shot
-        ):
+        if strategy == MNNVLAllreduceFusionStrategy.AUTO:
+            strategy = MNNVLAllreduceFusionStrategy.select_strategy(
+                tp_size, num_tokens, hidden_dim, dtype
+            )
+
+        if strategy == MNNVLAllreduceFusionStrategy.ONESHOT:
             # For one-shot, each rank needs to store num_tokens * tp_size tokens
             buffer_size = num_tokens * hidden_dim * tp_size * elem_size
         else:
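
Editor's note: for one-shot, every rank stages the tokens of all tp_size ranks, so the per-buffer requirement is num_tokens * hidden_dim * tp_size * elem_size. A worked example in bf16 (2-byte elements):

    # 128 tokens x 8192 hidden x tp_size 8 x 2 bytes = 16,777,216 bytes = 16 MiB,
    # comfortably inside the new 512 MB default lamport buffer.
    num_tokens, hidden_dim, tp_size, elem_size = 128, 8192, 8, 2
    assert num_tokens * hidden_dim * tp_size * elem_size == 16 * (1024**2)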
@@ -268,12 +292,18 @@ def trtllm_mnnvl_allreduce(
 
     module = get_trtllm_mnnvl_comm_module()
 
-    use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or (
-        strategy == MNNVLAllreduceFusionStrategy.AUTO
-        and MNNVLAllreduceFusionStrategy.is_one_shot(
+    if strategy == MNNVLAllreduceFusionStrategy.AUTO:
+        strategy = MNNVLAllreduceFusionStrategy.select_strategy(
             workspace.tp_size, input.shape[0], input.shape[1], input.dtype
         )
-    )
+
+    if not workspace.is_buffer_size_sufficient(
+        workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy
+    ):
+        raise ValueError(
+            f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes."
+        )
+
     module.trtllm_mnnvl_allreduce_fusion(
         input,
         workspace.mc_ptr,
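
Editor's note: with the sufficiency check inlined, trtllm_mnnvl_allreduce now fails fast with a ValueError reporting both the available and required byte counts, instead of silently overrunning the lamport buffers. A hedged caller-side sketch; keyword names beyond those visible in this diff are assumptions, and x, out, and workspace are assumed to be set up elsewhere:

    from flashinfer.comm.trtllm_mnnvl_ar import (
        MNNVLAllreduceFusionStrategy,
        trtllm_mnnvl_allreduce,
    )

    try:
        trtllm_mnnvl_allreduce(
            input=x,
            workspace=workspace,
            output=out,
            strategy=MNNVLAllreduceFusionStrategy.ONESHOT,  # skip AUTO heuristic
            launch_with_pdl=False,
        )
    except ValueError as err:
        # Raised when workspace.buffer_size_bytes < the required byte count.
        print(f"allreduce workspace too small: {err}")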
@@ -284,7 +314,7 @@ def trtllm_mnnvl_allreduce(
         workspace.rank,
         False,  # No RMSNorm Fusion
         launch_with_pdl,
-        use_oneshot,
+        strategy == MNNVLAllreduceFusionStrategy.ONESHOT,
         output,
         None,
         None,
@@ -358,15 +388,16 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm(
 
     module = get_trtllm_mnnvl_comm_module()
 
-    use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or (
-        strategy == MNNVLAllreduceFusionStrategy.AUTO
-        and MNNVLAllreduceFusionStrategy.is_one_shot(
-            workspace.tp_size,
-            input.shape[0],
-            input.shape[1],
-            input.dtype,
+    if strategy == MNNVLAllreduceFusionStrategy.AUTO:
+        strategy = MNNVLAllreduceFusionStrategy.select_strategy(
+            workspace.tp_size, input.shape[0], input.shape[1], input.dtype
+        )
+    if not workspace.is_buffer_size_sufficient(
+        workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy
+    ):
+        raise ValueError(
+            f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes."
         )
-    )
 
     module.trtllm_mnnvl_allreduce_fusion(
         input,
@@ -378,7 +409,7 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm(
         workspace.rank,
         True,  # RMSNorm Fusion
         launch_with_pdl,
-        use_oneshot,
+        strategy == MNNVLAllreduceFusionStrategy.ONESHOT,
         output,
         residual_out,
         residual_in,
