diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index a8adba99e7..5db65115a3 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -1343,18 +1343,130 @@ def post_patch_model( # See: https://github.com/unslothai/unsloth/issues/3713. use_reentrant = not is_distributed() if not use_reentrant: - # Under DDP, avoid the offloaded/re-entrant checkpoint patch. + # Under DDP, switch to non-reentrant checkpointing with CPU + # activation offloading via saved_tensors_hooks. This preserves + # memory savings from offloading while being DDP-safe. unpatch_unsloth_gradient_checkpointing() unpatch_unsloth_smart_gradient_checkpointing() - # Force native checkpoint to default to non-reentrant for downstream calls. _orig_checkpoint = torch_checkpoint.checkpoint - def _nonre_checkpoint(function, *args, **kwargs): + def _nonre_offloaded_checkpoint(function, *args, **kwargs): kwargs["use_reentrant"] = False - return _orig_checkpoint(function, *args, **kwargs) - torch_checkpoint.checkpoint = _nonre_checkpoint - hf_modeling_utils.checkpoint = _nonre_checkpoint + def _pack(x): + if isinstance(x, torch.Tensor) and x.device.type != "cpu": + return ("offload", x.device, x.to("cpu", non_blocking = True)) + return ("pass", x) + + def _unpack(packed): + if packed[0] == "offload": + return packed[2].to(packed[1], non_blocking = True) + return packed[1] + + with torch.autograd.graph.saved_tensors_hooks(_pack, _unpack): + return _orig_checkpoint(function, *args, **kwargs) + + torch_checkpoint.checkpoint = _nonre_offloaded_checkpoint + hf_modeling_utils.checkpoint = _nonre_offloaded_checkpoint + + # Fix TiledMLP for DDP: its backward calls torch.autograd.backward() + # per sequence chunk, triggering DDP hooks multiple times. Patch it + # to use functional torch.autograd.grad() for all but the last chunk, + # then .backward() for the final chunk (fires DDP hooks exactly once). + from unsloth_zoo.tiled_mlp import TiledMLP, torch_amp_custom_bwd + from unsloth_zoo.gradient_checkpointing import ( + set_device_states as _set_dev_states, + ) + + @staticmethod + @torch_amp_custom_bwd + def _ddp_safe_backward(ctx, grad_output, *args): + rng_devices = [] + x = ctx.saved_tensors[0] + B, S, H = x.shape + if ctx.preserve_rng_state and ctx.had_device_in_fwd: + rng_devices = ctx.fwd_devices + with torch.random.fork_rng( + devices = rng_devices, + enabled = ctx.preserve_rng_state, + device_type = ctx.device_type, + ): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_device_in_fwd: + _set_dev_states( + ctx.fwd_devices, + ctx.fwd_device_states, + device_type = ctx.device_type, + ) + + mlp_params = [ + p for p in ctx.mlp_module.parameters() if p.requires_grad + ] + x_gradients = torch.zeros_like( + x, memory_format = torch.preserve_format + ) + x_flat = x.view(-1, H) + x_splits = torch.split(x_flat, ctx.split_sizes, dim = 0) + start_idx = 0 + extra_outputs = [] + n_chunks = len(x_splits) + + for i, x_split in enumerate(x_splits): + x_split = x_split.unsqueeze(0) + split_size = x_split.numel() + x_grad_slice = ( + x_gradients.view(-1) + .narrow( + dim = 0, + start = start_idx, + length = split_size, + ) + .view_as(x_split) + ) + grad_shard = ( + grad_output.view(-1) + .narrow( + dim = 0, + start = start_idx, + length = split_size, + ) + .view_as(x_split) + ) + + x_split.requires_grad_(True) + with torch.enable_grad(): + outputs = TiledMLP.handle_output( + ctx.mlp_forward(x_split), + extra_outputs, + ) + + if i < n_chunks - 1: + # Functional grad: no DDP hooks fired + grads = torch.autograd.grad( + outputs, + [x_split] + mlp_params, + grad_shard, + allow_unused = True, + ) + x_grad_slice.copy_(grads[0]) + for p, g in zip(mlp_params, grads[1:]): + if g is not None: + if p.grad is None: + p.grad = g + else: + p.grad.add_(g) + else: + # Last chunk: .backward() fires DDP hooks once + x_split.grad = x_grad_slice + torch.autograd.backward(outputs, grad_shard) + + start_idx += split_size + + return None, None, x_gradients, None, None, None + + TiledMLP.backward = _ddp_safe_backward + logger.info("Unsloth: Patched TiledMLP backward for DDP compatibility") model = prepare_model_for_training( model,