Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions aiter/dist/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024 * 8 * 2, # In allreduce 2stage writemode, use 2x tmp buffer
enable_register_for_capturing: bool = True,
) -> None:
"""
Args:
Expand Down Expand Up @@ -147,6 +148,7 @@ def __init__(
# return

self.disabled = False
self.enable_register_for_capturing = enable_register_for_capturing
# buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
Expand Down Expand Up @@ -313,8 +315,8 @@ def custom_all_reduce(
input,
use_new=use_new,
open_fp8_quant=open_fp8_quant,
registered_input=True,
registered_output=True
registered_input=self.enable_register_for_capturing,
registered_output=self.enable_register_for_capturing
)
else:
# if warm up, mimic the allocation pattern
Expand Down
Loading