
Commit 63d1fb2

Andrew Gu authored and pytorchmergebot committed
[FSDP] Default limit_all_gathers=True (pytorch#104900)
This PR defaults to `limit_all_gathers=True`. I included a `record_function()` around the rate limiter synchronization to address user confusion about the gap in the pre-forward.

(Screenshot: profiler trace showing the pre-forward gap attributed to the rate limiter.)

Pull Request resolved: pytorch#104900
Approved by: https://github.com/fegin
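For readers of the diff, here is a minimal usage sketch (not part of the commit) of the behavior change: after this PR, constructing an FSDP instance without arguments enables the rate limiter, and passing `limit_all_gathers=False` restores the previous behavior. The snippet assumes a CUDA device and an already-initialized process group (e.g. launched via torchrun).

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Sketch only: assumes torch.distributed is initialized and CUDA is available.
fsdp_default = FSDP(nn.Linear(1024, 1024).cuda())  # limit_all_gathers=True is now the default
fsdp_opt_out = FSDP(nn.Linear(1024, 1024).cuda(), limit_all_gathers=False)  # previous behavior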
1 parent: 7c3c3dd

3 files changed: +30 −13 lines

test/distributed/fsdp/test_fsdp_overlap.py

Lines changed: 8 additions & 5 deletions
@@ -58,13 +58,16 @@ def get_time(self):
 
 
 def _create_model(compute_cycles, has_params: bool):
+    # Use `limit_all_gathers=False` since the timing being tested relies on the
+    # CPU running ahead of the GPU
     model = FSDP(
         nn.Sequential(
-            FSDP(Layer(compute_cycles, has_params)),
-            FSDP(Layer(compute_cycles, has_params)),
-            FSDP(Layer(compute_cycles, has_params)),
-            FSDP(Layer(compute_cycles, has_params)),
-        )
+            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
+            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
+            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
+            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
+        ),
+        limit_all_gathers=False,
     ).cuda()
     return model
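As an aside (not part of the commit), the comment in this test refers to the CPU "running ahead" of the GPU: CUDA kernel launches are asynchronous, so the CPU thread returns from a launch long before the kernel finishes. A rough illustration using `torch.cuda._sleep`, the internal busy-wait helper PyTorch tests use (an internal API, noted here as an assumption rather than a supported interface):

import time
import torch

start = time.time()
torch.cuda._sleep(200_000_000)      # enqueue a long busy-wait kernel (internal test helper)
cpu_elapsed = time.time() - start   # CPU returns almost immediately
torch.cuda.synchronize()            # block until the GPU actually finishes
gpu_elapsed = time.time() - start
print(f"CPU ran ahead of the GPU by ~{gpu_elapsed - cpu_elapsed:.3f}s")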

torch/distributed/fsdp/_runtime_utils.py

Lines changed: 4 additions & 1 deletion
@@ -358,7 +358,10 @@ def _unshard(
     if state.limit_all_gathers:
         event = state._free_event_queue.dequeue_if_needed()
         if event:
-            event.synchronize()
+            with torch.profiler.record_function(
+                "FullyShardedDataParallel.rate_limiter"
+            ):
+                event.synchronize()
     with state._device_handle.stream(unshard_stream):
         for handle in handles:
             handle.unshard()
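With the added `record_function`, the rate-limiter wait should show up as a named region, "FullyShardedDataParallel.rate_limiter", in profiler traces. A hedged sketch of capturing such a trace; `train_step`, `model`, and `batches` are hypothetical placeholders for an FSDP training loop, not names from this commit:

import torch
from torch.profiler import ProfilerActivity, profile

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for batch in batches:           # hypothetical data iterable
        train_step(model, batch)    # hypothetical FSDP forward/backward/optimizer step
prof.export_chrome_trace("fsdp_trace.json")  # look for "FullyShardedDataParallel.rate_limiter"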

torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 18 additions & 7 deletions
@@ -214,6 +214,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         instead of reacquiring the references each iteration, then it will not
         see FSDP's newly created views, and autograd will not work correctly.
 
+    .. note::
+        With ``limit_all_gathers=True``, you may see a gap in the FSDP
+        pre-forward where the CPU thread is not issuing any kernels. This is
+        intentional and shows the rate limiter in effect. Synchronizing the CPU
+        thread in that way prevents over-allocating memory for subsequent
+        all-gathers, and it should not actually delay GPU kernel execution.
+
     Args:
         module (nn.Module):
             This is the module to be wrapped with FSDP.
@@ -334,12 +341,16 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
             bound workloads. This should only be used for static graph models
            since the forward order is fixed based on the first iteration's
            execution. (Default: ``False``)
-        limit_all_gathers (bool): If ``False``, then FSDP allows the CPU
-            thread to schedule all-gathers without any extra synchronization.
-            If ``True``, then FSDP explicitly synchronizes the CPU thread to
-            prevent too many in-flight all-gathers. This ``bool`` only affects
-            the sharded strategies that schedule all-gathers. Enabling this can
-            help lower the number of CUDA malloc retries.
+        limit_all_gathers (bool): If ``True``, then FSDP explicitly
+            synchronizes the CPU thread to ensure GPU memory usage from only
+            *two* consecutive FSDP instances (the current instance running
+            computation and the next instance whose all-gather is prefetched).
+            If ``False``, then FSDP allows the CPU thread to issue all-gathers
+            without any extra synchronization. (Default: ``True``) We often
+            refer to this feature as the "rate limiter". This flag should only
+            be set to ``False`` for specific CPU-bound workloads with low
+            memory pressure in which case the CPU thread can aggressively issue
+            all kernels without concern for the GPU memory usage.
         use_orig_params (bool): Setting this to ``True`` has FSDP use
             ``module`` 's original parameters. FSDP exposes those original
             parameters to the user via :meth:`nn.Module.named_parameters`
@@ -382,7 +393,7 @@ def __init__(
         device_id: Optional[Union[int, torch.device]] = None,
         sync_module_states: bool = False,
         forward_prefetch: bool = False,
-        limit_all_gathers: bool = False,
+        limit_all_gathers: bool = True,
         use_orig_params: bool = False,
         ignored_states: Union[
             Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
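To make the `limit_all_gathers` documentation above concrete, here is a conceptual sketch of the rate-limiter idea, not FSDP's actual implementation: keep a small queue of CUDA events, and before issuing the next all-gather, block the CPU thread on the oldest event once roughly two all-gathers are in flight. The class name, method names, and the constant below are illustrative.

from collections import deque

import torch

class FreeEventQueue:
    """Illustrative bounded queue of CUDA events (names are hypothetical)."""

    def __init__(self, max_in_flight: int = 2):
        self._queue = deque()
        self._max_in_flight = max_in_flight

    def enqueue(self, event: torch.cuda.Event) -> None:
        # Called after each all-gather, with an event recorded on its stream.
        self._queue.append(event)

    def dequeue_if_needed(self):
        # Return the oldest event only once the in-flight limit is reached.
        if len(self._queue) >= self._max_in_flight:
            return self._queue.popleft()
        return None

# Usage shape mirroring the _runtime_utils.py diff above:
queue = FreeEventQueue()
event = queue.dequeue_if_needed()
if event is not None:
    with torch.profiler.record_function("FullyShardedDataParallel.rate_limiter"):
        event.synchronize()  # CPU thread waits; GPU work already queued keeps running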

0 commit comments