-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Fix DDP "marked ready twice" for VLMs with CPU offload + TiledMLP #4240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8c02182
5fb4397
30c54c8
e32db1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1343,18 +1343,130 @@ def post_patch_model( | |
| # See: https://github.com/unslothai/unsloth/issues/3713. | ||
| use_reentrant = not is_distributed() | ||
| if not use_reentrant: | ||
| # Under DDP, avoid the offloaded/re-entrant checkpoint patch. | ||
| # Under DDP, switch to non-reentrant checkpointing with CPU | ||
| # activation offloading via saved_tensors_hooks. This preserves | ||
| # memory savings from offloading while being DDP-safe. | ||
| unpatch_unsloth_gradient_checkpointing() | ||
| unpatch_unsloth_smart_gradient_checkpointing() | ||
| # Force native checkpoint to default to non-reentrant for downstream calls. | ||
| _orig_checkpoint = torch_checkpoint.checkpoint | ||
|
|
||
| def _nonre_checkpoint(function, *args, **kwargs): | ||
| def _nonre_offloaded_checkpoint(function, *args, **kwargs): | ||
| kwargs["use_reentrant"] = False | ||
| return _orig_checkpoint(function, *args, **kwargs) | ||
|
|
||
| torch_checkpoint.checkpoint = _nonre_checkpoint | ||
| hf_modeling_utils.checkpoint = _nonre_checkpoint | ||
| def _pack(x): | ||
| if isinstance(x, torch.Tensor) and x.device.type != "cpu": | ||
| return ("offload", x.device, x.to("cpu", non_blocking = True)) | ||
| return ("pass", x) | ||
|
|
||
| def _unpack(packed): | ||
| if packed[0] == "offload": | ||
| return packed[2].to(packed[1], non_blocking = True) | ||
| return packed[1] | ||
|
|
||
| with torch.autograd.graph.saved_tensors_hooks(_pack, _unpack): | ||
| return _orig_checkpoint(function, *args, **kwargs) | ||
|
|
||
| torch_checkpoint.checkpoint = _nonre_offloaded_checkpoint | ||
| hf_modeling_utils.checkpoint = _nonre_offloaded_checkpoint | ||
|
|
||
| # Fix TiledMLP for DDP: its backward calls torch.autograd.backward() | ||
| # per sequence chunk, triggering DDP hooks multiple times. Patch it | ||
| # to use functional torch.autograd.grad() for all but the last chunk, | ||
| # then .backward() for the final chunk (fires DDP hooks exactly once). | ||
| from unsloth_zoo.tiled_mlp import TiledMLP, torch_amp_custom_bwd | ||
| from unsloth_zoo.gradient_checkpointing import ( | ||
| set_device_states as _set_dev_states, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| @torch_amp_custom_bwd | ||
| def _ddp_safe_backward(ctx, grad_output, *args): | ||
| rng_devices = [] | ||
| x = ctx.saved_tensors[0] | ||
| B, S, H = x.shape | ||
| if ctx.preserve_rng_state and ctx.had_device_in_fwd: | ||
| rng_devices = ctx.fwd_devices | ||
| with torch.random.fork_rng( | ||
| devices = rng_devices, | ||
| enabled = ctx.preserve_rng_state, | ||
| device_type = ctx.device_type, | ||
| ): | ||
| if ctx.preserve_rng_state: | ||
| torch.set_rng_state(ctx.fwd_cpu_state) | ||
| if ctx.had_device_in_fwd: | ||
| _set_dev_states( | ||
| ctx.fwd_devices, | ||
| ctx.fwd_device_states, | ||
| device_type = ctx.device_type, | ||
| ) | ||
|
|
||
| mlp_params = [ | ||
| p for p in ctx.mlp_module.parameters() if p.requires_grad | ||
| ] | ||
| x_gradients = torch.zeros_like( | ||
| x, memory_format = torch.preserve_format | ||
| ) | ||
| x_flat = x.view(-1, H) | ||
| x_splits = torch.split(x_flat, ctx.split_sizes, dim = 0) | ||
| start_idx = 0 | ||
| extra_outputs = [] | ||
| n_chunks = len(x_splits) | ||
|
|
||
| for i, x_split in enumerate(x_splits): | ||
| x_split = x_split.unsqueeze(0) | ||
| split_size = x_split.numel() | ||
| x_grad_slice = ( | ||
| x_gradients.view(-1) | ||
| .narrow( | ||
| dim = 0, | ||
| start = start_idx, | ||
| length = split_size, | ||
| ) | ||
| .view_as(x_split) | ||
| ) | ||
| grad_shard = ( | ||
| grad_output.view(-1) | ||
| .narrow( | ||
| dim = 0, | ||
| start = start_idx, | ||
| length = split_size, | ||
| ) | ||
| .view_as(x_split) | ||
| ) | ||
|
|
||
| x_split.requires_grad_(True) | ||
| with torch.enable_grad(): | ||
| outputs = TiledMLP.handle_output( | ||
| ctx.mlp_forward(x_split), | ||
| extra_outputs, | ||
| ) | ||
|
|
||
| if i < n_chunks - 1: | ||
| # Functional grad: no DDP hooks fired | ||
| grads = torch.autograd.grad( | ||
| outputs, | ||
| [x_split] + mlp_params, | ||
| grad_shard, | ||
| allow_unused = True, | ||
| ) | ||
| x_grad_slice.copy_(grads[0]) | ||
| for p, g in zip(mlp_params, grads[1:]): | ||
| if g is not None: | ||
| if p.grad is None: | ||
| p.grad = g | ||
| else: | ||
| p.grad.add_(g) | ||
| else: | ||
| # Last chunk: .backward() fires DDP hooks once | ||
| x_split.grad = x_grad_slice | ||
| torch.autograd.backward(outputs, grad_shard) | ||
|
|
||
| start_idx += split_size | ||
|
Comment on lines
+1415
to
+1464
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The logic inside this |
||
|
|
||
| return None, None, x_gradients, None, None, None | ||
|
|
||
| TiledMLP.backward = _ddp_safe_backward | ||
| logger.info("Unsloth: Patched TiledMLP backward for DDP compatibility") | ||
|
|
||
| model = prepare_model_for_training( | ||
| model, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For robustness, it's good practice to check that
`grads[0]` is not `None` before calling `.copy_()`. While `x_split` is expected to have a gradient since `requires_grad` is set to `True`, `torch.autograd.grad` with `allow_unused=True` can return `None` for inputs that don't receive gradients. If `grads[0]` were `None` for some reason (e.g., an unexpected graph structure where `x_split` is not used), this would raise an error. Adding this check makes the code more robust.