diff --git a/aiter/dist/device_communicators/custom_all_reduce.py b/aiter/dist/device_communicators/custom_all_reduce.py index d01010ae53..bf15f70d43 100644 --- a/aiter/dist/device_communicators/custom_all_reduce.py +++ b/aiter/dist/device_communicators/custom_all_reduce.py @@ -55,6 +55,7 @@ def __init__( group: ProcessGroup, device: Union[int, str, torch.device], max_size=8192 * 1024 * 8 * 2, # In allreduce 2stage writemode, use 2x tmp buffer + enable_register_for_capturing: bool = True, ) -> None: """ Args: @@ -147,6 +148,7 @@ def __init__( # return self.disabled = False + self.enable_register_for_capturing = enable_register_for_capturing # buffers memory are owned by this Python class and passed to C++ # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate @@ -313,8 +315,8 @@ def custom_all_reduce( input, use_new=use_new, open_fp8_quant=open_fp8_quant, - registered_input=True, - registered_output=True + registered_input=self.enable_register_for_capturing, + registered_output=self.enable_register_for_capturing ) else: # if warm up, mimic the allocation pattern