
Commit a390685

Rounding up workspace size according to allocation (page size).
1 parent c6ed147 commit a390685

2 files changed: +37 -12 lines

flashinfer/comm/mnnvl.py

Lines changed: 14 additions & 5 deletions

@@ -882,6 +882,14 @@ def get_world_size(self) -> int:
         """Get the total number of devices in the group"""
         return self.group_size

+    def get_allocation_size(self) -> int:
+        """Get the total allocation size (including signal pad)"""
+        return self.allocation_size
+
+    def get_usable_buffer_size(self) -> int:
+        """Get the usable buffer size (excluding signal pad)"""
+        return self.allocation_size - self.SIGNAL_PAD_SIZE
+
     def _init_ipc_socket(self):
         if self.group_rank == 0:
             # Gnerate the opId

@@ -921,7 +929,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
         alloc_granularity = checkCudaErrors(
             cuda.cuMemGetAllocationGranularity(
                 allocation_prop,
-                cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM,
+                cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED,
             )
         )

@@ -1124,8 +1132,8 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype):
         else:
             raise ValueError(f"Unsupported dtype: {dtype}")

-        # Calculate number of elements that fit in allocation_size
-        num_elements = self.allocation_size // dsize
+        # Calculate number of elements that fit in allocation_size; We don't want to include the signal pad.
+        num_elements = (self.allocation_size - self.SIGNAL_PAD_SIZE) // dsize

         checkCudaErrors(
             memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements)

@@ -1153,7 +1161,7 @@ def __init__(
         Constructor for McastGpuBuffer.

         Args:
-            buf_size: The total size of the buffer in bytes
+            buf_size: The requested size of the buffer in bytes. The actual usable size may differ due to alignment requirements.
             group_size: The number of ranks in the communication group
             group_rank: The rank of the local process within the group
             device: The CUDA device for buffer allocation

@@ -1168,7 +1176,8 @@ def __init__(
             mn_nvlink,
             comm_backend_for_handle_transfer,
         )
-        self.buf_size = buf_size
+        # Update buf_size to reflect the actual usable buffer size after allocation
+        self.buf_size = self.mcast_device_memory.get_usable_buffer_size()
         self.local_device = device

     def lamport_initialize(self, rank: int, dtype: torch.dtype):
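
For reference, the sketch below illustrates the round-up pattern the mnnvl.py changes rely on: query the driver's recommended allocation granularity, round the requested size up to a multiple of it, and report the usable size (allocation minus signal pad) back to callers. It is a minimal standalone sketch rather than the FlashInfer implementation: the SIGNAL_PAD_SIZE value, the round_up helper, whether the pad is added before rounding, and the allocation-property setup are all illustrative assumptions.

# Minimal sketch of rounding a requested buffer size up to the driver's
# recommended allocation granularity (illustrative; not the FlashInfer code).
from cuda import cuda

SIGNAL_PAD_SIZE = 16 * 1024  # assumed placeholder; the real constant lives in mnnvl.py


def round_up(value: int, granularity: int) -> int:
    # Round value up to the next multiple of granularity.
    return (value + granularity - 1) // granularity * granularity


def compute_sizes(buf_size: int, device_id: int):
    # Describe a pinned allocation on a specific device (assumed property setup).
    prop = cuda.CUmemAllocationProp()
    prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
    prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    prop.location.id = device_id

    # Ask the driver for the recommended granularity (the flag this commit switches to).
    err, granularity = cuda.cuMemGetAllocationGranularity(
        prop,
        cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED,
    )
    assert err == cuda.CUresult.CUDA_SUCCESS

    # Round the requested size plus the signal pad up to the granularity; the
    # usable size reported back to callers excludes the signal pad.
    allocation_size = round_up(buf_size + SIGNAL_PAD_SIZE, granularity)
    usable_size = allocation_size - SIGNAL_PAD_SIZE
    return allocation_size, usable_size

Switching from CU_MEM_ALLOC_GRANULARITY_MINIMUM to CU_MEM_ALLOC_GRANULARITY_RECOMMENDED trades a potentially larger allocation for the granularity the driver recommends for performance, which is why callers now read back the usable size instead of assuming their requested size.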

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 23 additions & 7 deletions

@@ -61,11 +61,11 @@ def __init__(

         Args:
             mapping: Mapping configuration containing rank info
-            buffer_size_in_bytes: The size in bytes for each lamport buffer. The actual allocation size will be NUM_LAMPORT_BUFFERS * buffer_size_in_bytes.
+            buffer_size_in_bytes: The requested size in bytes for each lamport buffer. The actual allocation size may be larger due to alignment requirements. The actual usable size will be NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer.
         """
         if buffer_size_in_bytes is None:
-            # Default to 512MB workspace size if not provided
-            buffer_size_in_bytes = 512 * (1024**2)
+            # Default to 16MB workspace size if not provided
+            buffer_size_in_bytes = 16 * (1024**2)
         else:
             # Round up to the nearest multiple of 8MB
             buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (

@@ -78,22 +78,38 @@ def __init__(
                 f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)."
             )

-        self.buffer_size_bytes = buffer_size_in_bytes
-        self.workspace_size_bytes = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS
+        # Calculate total requested workspace size
+        requested_workspace_size = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS
+
         self.rank = mapping.tp_rank
         self.tp_size = mapping.tp_size
         logging.debug(
-            f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with size {buffer_size_in_bytes} bytes."
+            f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with requested size {buffer_size_in_bytes} bytes per buffer."
         )
+
+        # Allocate the workspace
         self.mcast_buffer_handle = McastGPUBuffer(
-            self.workspace_size_bytes,
+            requested_workspace_size,
             mapping.tp_size,
             mapping.tp_rank,
             torch.device("cuda", mapping.local_rank),
             mapping.is_multi_node(),
             comm_backend,
         )

+        # Get the actual usable buffer size after allocation (buf_size is updated by McastGPUBuffer)
+        allocated_size = self.mcast_buffer_handle.buf_size
+        # We want the buffer size to be aligned to 16B which is the granularity for buffer management.
+        self.buffer_size_bytes = (
+            math.floor(allocated_size / self.NUM_LAMPORT_BUFFERS) // 16 * 16
+        )
+        # This workspace size is used for checking the buffer. We need to set it to the actual size in use. The buffer free logic does not rely on this size.
+        self.workspace_size_bytes = self.buffer_size_bytes * self.NUM_LAMPORT_BUFFERS
+
+        logging.debug(
+            f"[MNNVL Allreduce] Actual allocated size: {allocated_size} bytes, Actual buffer size per lamport buffer: {self.buffer_size_bytes} bytes, total workspace: {self.workspace_size_bytes} bytes."
+        )
+
         # We use FP32 for sentinel value regardless of the real dtype
         self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32)
         # Wait until the initialization is done
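
For reference, the sketch below mirrors the size bookkeeping introduced above in trtllm_mnnvl_ar.py: the requested per-buffer size defaults to 16MB or is rounded up to a multiple of 8MB, and after allocation the actually usable workspace is split into NUM_LAMPORT_BUFFERS slots, each aligned down to 16B. The NUM_LAMPORT_BUFFERS value here is an assumed placeholder (the real constant is a class attribute), and the helper functions are illustrative rather than the FlashInfer API.

# Minimal sketch of the request-side and post-allocation size arithmetic
# (illustrative; not the FlashInfer implementation).
import math
from typing import Optional, Tuple

NUM_LAMPORT_BUFFERS = 3  # assumed placeholder; the real constant is a class attribute


def requested_workspace_size(buffer_size_in_bytes: Optional[int] = None) -> int:
    # Default to 16MB per lamport buffer, otherwise round the request up to the
    # nearest multiple of 8MB, then multiply by the number of lamport buffers.
    if buffer_size_in_bytes is None:
        buffer_size_in_bytes = 16 * (1024**2)
    else:
        buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (
            8 * (1024**2)
        )
    return buffer_size_in_bytes * NUM_LAMPORT_BUFFERS


def usable_sizes(allocated_size: int) -> Tuple[int, int]:
    # Derive the per-buffer size from what was actually allocated, aligned down
    # to 16B (the buffer-management granularity), and the workspace actually in use.
    buffer_size_bytes = allocated_size // NUM_LAMPORT_BUFFERS // 16 * 16
    workspace_size_bytes = buffer_size_bytes * NUM_LAMPORT_BUFFERS
    return buffer_size_bytes, workspace_size_bytes


# A 10MB request rounds up to 16MB per buffer, i.e. a 48MB workspace request.
print(requested_workspace_size(10 * (1024**2)))  # 50331648
# If the driver hands back a slightly larger usable region, the per-buffer size
# is floored back to a 16B-aligned value.
print(usable_sizes(48 * (1024**2) + 4096))  # (16778576, 50335728)

Flooring each lamport buffer to a 16B multiple keeps buffer offsets compatible with the 16B management granularity, so the workspace size recorded for checks matches the size actually in use.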
