@@ -214,6 +214,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
214214        instead of reacquiring the references each iteration, then it will not 
215215        see FSDP's newly created views, and autograd will not work correctly. 
216216
217+     .. note:: 
218+         With ``limit_all_gathers=True``, you may see a gap in the FSDP 
219+         pre-forward where the CPU thread is not issuing any kernels. This is 
220+         intentional and shows the rate limiter in effect. Synchronizing the CPU 
221+         thread in that way prevents over-allocating memory for subsequent 
222+         all-gathers, and it should not actually delay GPU kernel execution. 
223+ 
217224    Args: 
218225        module (nn.Module): 
219226            This is the module to be wrapped with FSDP. 
@@ -334,12 +341,16 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
334341            bound workloads. This should only be used for static graph models 
335342            since the forward order is fixed based on the first iteration's 
336343            execution. (Default: ``False``) 
337-         limit_all_gathers (bool): If ``False``, then FSDP allows the CPU 
338-             thread to schedule all-gathers without any extra synchronization. 
339-             If ``True``, then FSDP explicitly synchronizes the CPU thread to 
340-             prevent too many in-flight all-gathers. This ``bool`` only affects 
341-             the sharded strategies that schedule all-gathers. Enabling this can 
342-             help lower the number of CUDA malloc retries. 
344+         limit_all_gathers (bool): If ``True``, then FSDP explicitly 
345+             synchronizes the CPU thread so that GPU memory is used by only 
346+             *two* consecutive FSDP instances at a time (the current instance 
347+             running computation and the next instance whose all-gather is 
348+             prefetched). If ``False``, then FSDP allows the CPU thread to issue 
349+             all-gathers without any extra synchronization. (Default: ``True``) 
350+             We often refer to this feature as the "rate limiter". This flag 
351+             should only be set to ``False`` for specific CPU-bound workloads 
352+             with low memory pressure, in which case the CPU thread can 
353+             aggressively issue all kernels without concern for GPU memory usage. 
343354        use_orig_params (bool): Setting this to ``True`` has FSDP use 
344355            ``module`` 's original parameters. FSDP exposes those original 
345356            parameters to the user via :meth:`nn.Module.named_parameters` 
@@ -382,7 +393,7 @@ def __init__(
382393        device_id: Optional[Union[int, torch.device]] = None,
383394        sync_module_states: bool = False,
384395        forward_prefetch: bool = False,
385-         limit_all_gathers: bool = False,
396+         limit_all_gathers: bool = True,
386397        use_orig_params: bool = False,
387398        ignored_states: Union[
388399            Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
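For context on how the changed default plays out at a call site, here is a minimal usage sketch (not part of this diff): the toy model, the `build_model` helper, and the device handling are illustrative assumptions, and it presumes the process group has already been initialized (e.g. under torchrun).

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def build_model(local_rank: int) -> FSDP:
    # Illustrative toy model; assumes dist.init_process_group("nccl") was
    # already called (e.g. by torchrun) and that local_rank maps to a GPU.
    model = nn.Sequential(
        nn.Linear(1024, 1024),
        nn.ReLU(),
        nn.Linear(1024, 1024),
    )
    return FSDP(
        model,
        device_id=torch.device("cuda", local_rank),
        # Rate limiter, now on by default: the CPU thread is synchronized so
        # that at most two consecutive FSDP instances hold unsharded
        # parameters in GPU memory at a time.
        limit_all_gathers=True,
        # For a CPU-bound workload with low memory pressure, one could pass
        # limit_all_gathers=False instead to skip the extra synchronization.
    )
```

Keeping the new default simply means omitting the argument; the explicit keyword above is only there to show where the opt-out would go.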