Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 118 additions & 6 deletions unsloth/models/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,18 +1343,130 @@ def post_patch_model(
# See: https://github.com/unslothai/unsloth/issues/3713.
use_reentrant = not is_distributed()
if not use_reentrant:
# Under DDP, avoid the offloaded/re-entrant checkpoint patch.
# Under DDP, switch to non-reentrant checkpointing with CPU
# activation offloading via saved_tensors_hooks. This preserves
# memory savings from offloading while being DDP-safe.
unpatch_unsloth_gradient_checkpointing()
unpatch_unsloth_smart_gradient_checkpointing()
# Force native checkpoint to default to non-reentrant for downstream calls.
_orig_checkpoint = torch_checkpoint.checkpoint

def _nonre_checkpoint(function, *args, **kwargs):
def _nonre_offloaded_checkpoint(function, *args, **kwargs):
kwargs["use_reentrant"] = False
return _orig_checkpoint(function, *args, **kwargs)

torch_checkpoint.checkpoint = _nonre_checkpoint
hf_modeling_utils.checkpoint = _nonre_checkpoint
def _pack(x):
if isinstance(x, torch.Tensor) and x.device.type != "cpu":
return ("offload", x.device, x.to("cpu", non_blocking = True))
return ("pass", x)

def _unpack(packed):
if packed[0] == "offload":
return packed[2].to(packed[1], non_blocking = True)
return packed[1]

with torch.autograd.graph.saved_tensors_hooks(_pack, _unpack):
return _orig_checkpoint(function, *args, **kwargs)

torch_checkpoint.checkpoint = _nonre_offloaded_checkpoint
hf_modeling_utils.checkpoint = _nonre_offloaded_checkpoint

# Fix TiledMLP for DDP: its backward calls torch.autograd.backward()
# per sequence chunk, triggering DDP hooks multiple times. Patch it
# to use functional torch.autograd.grad() for all but the last chunk,
# then .backward() for the final chunk (fires DDP hooks exactly once).
from unsloth_zoo.tiled_mlp import TiledMLP, torch_amp_custom_bwd
from unsloth_zoo.gradient_checkpointing import (
set_device_states as _set_dev_states,
)

@staticmethod
@torch_amp_custom_bwd
def _ddp_safe_backward(ctx, grad_output, *args):
rng_devices = []
x = ctx.saved_tensors[0]
B, S, H = x.shape
if ctx.preserve_rng_state and ctx.had_device_in_fwd:
rng_devices = ctx.fwd_devices
with torch.random.fork_rng(
devices = rng_devices,
enabled = ctx.preserve_rng_state,
device_type = ctx.device_type,
):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_device_in_fwd:
_set_dev_states(
ctx.fwd_devices,
ctx.fwd_device_states,
device_type = ctx.device_type,
)

mlp_params = [
p for p in ctx.mlp_module.parameters() if p.requires_grad
]
x_gradients = torch.zeros_like(
x, memory_format = torch.preserve_format
)
x_flat = x.view(-1, H)
x_splits = torch.split(x_flat, ctx.split_sizes, dim = 0)
start_idx = 0
extra_outputs = []
n_chunks = len(x_splits)

for i, x_split in enumerate(x_splits):
x_split = x_split.unsqueeze(0)
split_size = x_split.numel()
x_grad_slice = (
x_gradients.view(-1)
.narrow(
dim = 0,
start = start_idx,
length = split_size,
)
.view_as(x_split)
)
grad_shard = (
grad_output.view(-1)
.narrow(
dim = 0,
start = start_idx,
length = split_size,
)
.view_as(x_split)
)

x_split.requires_grad_(True)
with torch.enable_grad():
outputs = TiledMLP.handle_output(
ctx.mlp_forward(x_split),
extra_outputs,
)

if i < n_chunks - 1:
# Functional grad: no DDP hooks fired
grads = torch.autograd.grad(
outputs,
[x_split] + mlp_params,
grad_shard,
allow_unused = True,
)
x_grad_slice.copy_(grads[0])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For robustness, check that grads[0] is not None before calling .copy_(). x_split is expected to receive a gradient because requires_grad is set to True on it, but torch.autograd.grad with allow_unused=True returns None for any input that does not contribute to the outputs. If grads[0] were None (e.g., an unexpected graph structure in which x_split is not used), .copy_(None) would raise an error. With the guard in place the slice is simply left untouched; since x_gradients is created with torch.zeros_like, that slice stays zero, which is the correct gradient for an unused input.

Suggested change
x_grad_slice.copy_(grads[0])
if grads[0] is not None:
x_grad_slice.copy_(grads[0])

for p, g in zip(mlp_params, grads[1:]):
if g is not None:
if p.grad is None:
p.grad = g
else:
p.grad.add_(g)
else:
# Last chunk: .backward() fires DDP hooks once
x_split.grad = x_grad_slice
torch.autograd.backward(outputs, grad_shard)

start_idx += split_size
Comment on lines +1415 to +1464
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The body of this for loop is fairly complex and makes _ddp_safe_backward long (over 80 lines). To improve readability and maintainability, consider extracting the loop body into a separate helper function that processes a single chunk (re-running the forward pass, computing the gradients, and accumulating them), so that the main function reads as a short loop over chunks.


return None, None, x_gradients, None, None, None

TiledMLP.backward = _ddp_safe_backward
logger.info("Unsloth: Patched TiledMLP backward for DDP compatibility")

model = prepare_model_for_training(
model,
Expand Down
Loading