Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions src/accelerate/utils/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,20 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
"bnb_4bit_quant_storage to a floating dtype (e.g. bf16)."
)

# FSDP2 requires uniform orig_dtype among trainable params in each group.
# Upcast to fp32 master weights; MixedPrecisionPolicy.param_dtype handles compute cast.
if accelerator.mixed_precision != "no" and not model_has_params4bit:
upcasted_params = []
for name, param in model.named_parameters():
if param.requires_grad and param.dtype != torch.float32:
upcasted_params.append(name)
param.data = param.data.to(torch.float32)
if accelerator.is_main_process and upcasted_params:
warnings.warn(
"FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints. "
f"This effects {len(upcasted_params)} parameters: {upcasted_params}..."
)

if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
Expand Down Expand Up @@ -748,22 +762,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
if hasattr(model, "tie_weights"):
model.tie_weights()

# There is no `dtype` attribution for nn.Module
# Set it to None if it doesn't exist and do the upcast always
model_dtype = getattr(model, "dtype", None)
if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32):
# We upcast the trainable parameters according to `deepspeed`'s implementation
# More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section
upcasted_params = []
for name, param in model.named_parameters():
if param.requires_grad and param.dtype != torch.float32:
upcasted_params.append(name)
param = param.to(torch.float32)
if accelerator.is_main_process and upcasted_params:
warnings.warn(
"FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints. "
f"This effects {len(upcasted_params)} parameters: {upcasted_params}..."
)
return model


Expand Down
168 changes: 168 additions & 0 deletions tests/fsdp/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,174 @@ def test_param_mapping_error_handling(self):

AcceleratorState._reset_state(True)

def test_fsdp2_uniform_dtype_upcast_bf16(self):
    """Verify that `fsdp2_prepare_model` leaves every parameter in fp32 when `mixed_precision="bf16"`.

    The toy model starts with a bf16 `Linear` and an fp32 `LayerNorm`, i.e. trainable parameters of
    mixed dtypes. `fully_shard`, `is_compiled_module` and `fsdp2_prepare_auto_wrap_policy` are patched
    out so that only the dtype handling of `fsdp2_prepare_model` is exercised.
    """
    from unittest.mock import Mock, patch

    from accelerate.utils.fsdp_utils import fsdp2_prepare_model

    # Mixed-dtype starting point: bf16 linear layer followed by an fp32 norm layer.
    linear = torch.nn.Linear(4, 4).to(torch.bfloat16)
    model = torch.nn.Sequential(linear, torch.nn.LayerNorm(4))
    assert {p.dtype for p in model.parameters()} == {torch.bfloat16, torch.float32}

    # Plugin stand-in carrying only the attributes configured by this test.
    plugin = Mock(
        mixed_precision_policy=Mock(param_dtype=torch.bfloat16),
        reshard_after_forward=True,
        cpu_offload=None,
        cpu_ram_efficient_loading=False,
        ignored_modules=None,
    )
    accelerator = Mock(
        mixed_precision="bf16",
        torch_device_mesh=None,
        device=torch.device("cpu"),
        is_main_process=True,
    )
    accelerator.state.fsdp_plugin = plugin

    with (
        patch("torch.distributed.fsdp.fully_shard"),
        patch("accelerate.utils.fsdp_utils.is_compiled_module", return_value=False),
        patch("accelerate.utils.fsdp_utils.fsdp2_prepare_auto_wrap_policy", return_value=None),
    ):
        prepared = fsdp2_prepare_model(accelerator, model)

    dtypes_after = {p.dtype for p in prepared.parameters()}
    assert dtypes_after == {torch.float32}, f"Expected all fp32 master weights, got {dtypes_after}"

def test_fsdp2_uniform_dtype_upcast_fp16(self):
    """Verify that `fsdp2_prepare_model` leaves every parameter in fp32 when `mixed_precision="fp16"`.

    The toy model starts with an fp16 `Linear` and an fp32 `LayerNorm`. `fully_shard`,
    `is_compiled_module` and `fsdp2_prepare_auto_wrap_policy` are patched out so that only the
    dtype handling of `fsdp2_prepare_model` is exercised.
    """
    from unittest.mock import Mock, patch

    from accelerate.utils.fsdp_utils import fsdp2_prepare_model

    # Mixed-dtype starting point: fp16 linear layer followed by an fp32 norm layer.
    linear = torch.nn.Linear(4, 4).to(torch.float16)
    model = torch.nn.Sequential(linear, torch.nn.LayerNorm(4))
    assert {p.dtype for p in model.parameters()} == {torch.float16, torch.float32}

    # Plugin stand-in carrying only the attributes configured by this test.
    plugin = Mock(
        mixed_precision_policy=Mock(param_dtype=torch.float16),
        reshard_after_forward=True,
        cpu_offload=None,
        cpu_ram_efficient_loading=False,
        ignored_modules=None,
    )
    accelerator = Mock(
        mixed_precision="fp16",
        torch_device_mesh=None,
        device=torch.device("cpu"),
        is_main_process=True,
    )
    accelerator.state.fsdp_plugin = plugin

    with (
        patch("torch.distributed.fsdp.fully_shard"),
        patch("accelerate.utils.fsdp_utils.is_compiled_module", return_value=False),
        patch("accelerate.utils.fsdp_utils.fsdp2_prepare_auto_wrap_policy", return_value=None),
    ):
        prepared = fsdp2_prepare_model(accelerator, model)

    dtypes_after = {p.dtype for p in prepared.parameters()}
    assert dtypes_after == {torch.float32}, f"Expected all fp32 master weights, got {dtypes_after}"

def test_fsdp2_no_dtype_cast_when_no_mixed_precision(self):
    """Verify that `fsdp2_prepare_model` keeps the model's original dtypes when `mixed_precision="no"`.

    The toy model mixes a bf16 `Linear` with an fp32 `LayerNorm`; both dtypes must still be present
    afterwards. `fully_shard`, `is_compiled_module` and `fsdp2_prepare_auto_wrap_policy` are patched out.
    """
    from unittest.mock import Mock, patch

    from accelerate.utils.fsdp_utils import fsdp2_prepare_model

    # Mixed-dtype starting point: bf16 linear layer followed by an fp32 norm layer.
    linear = torch.nn.Linear(4, 4).to(torch.bfloat16)
    model = torch.nn.Sequential(linear, torch.nn.LayerNorm(4))
    assert {p.dtype for p in model.parameters()} == {torch.bfloat16, torch.float32}

    # No mixed-precision policy is configured on the plugin in this scenario.
    plugin = Mock(
        mixed_precision_policy=None,
        reshard_after_forward=True,
        cpu_offload=None,
        cpu_ram_efficient_loading=False,
        ignored_modules=None,
    )
    accelerator = Mock(
        mixed_precision="no",
        torch_device_mesh=None,
        device=torch.device("cpu"),
        is_main_process=True,
    )
    accelerator.state.fsdp_plugin = plugin

    with (
        patch("torch.distributed.fsdp.fully_shard"),
        patch("accelerate.utils.fsdp_utils.is_compiled_module", return_value=False),
        patch("accelerate.utils.fsdp_utils.fsdp2_prepare_auto_wrap_policy", return_value=None),
    ):
        prepared = fsdp2_prepare_model(accelerator, model)

    dtypes_after = {p.dtype for p in prepared.parameters()}
    assert dtypes_after == {torch.bfloat16, torch.float32}, f"Expected mixed dtypes preserved, got {dtypes_after}"

def test_fsdp2_no_dtype_cast_with_params4bit(self):
    """Verify that `fsdp2_prepare_model` does not cast dtypes when a parameter's class is `Params4bit`.

    One weight of the toy model is temporarily given a class named `Params4bit` (a stand-in for a
    quantized parameter). With `mixed_precision="bf16"`, the bf16/fp32 mix must be left untouched.
    The parameter's real class is restored afterwards regardless of the test outcome.
    """
    from unittest.mock import Mock, patch

    from accelerate.utils.fsdp_utils import fsdp2_prepare_model

    # Mixed-dtype starting point: bf16 linear layer followed by an fp32 norm layer.
    linear = torch.nn.Linear(4, 4).to(torch.bfloat16)
    model = torch.nn.Sequential(linear, torch.nn.LayerNorm(4))
    assert {p.dtype for p in model.parameters()} == {torch.bfloat16, torch.float32}

    # Swap the weight's class for a dynamically created subclass whose name is "Params4bit".
    weight = model[0].weight
    original_class = weight.__class__
    weight.__class__ = type("Params4bit", (torch.nn.Parameter,), {})

    # Plugin stand-in carrying only the attributes configured by this test.
    plugin = Mock(
        mixed_precision_policy=Mock(param_dtype=torch.bfloat16),
        reshard_after_forward=True,
        cpu_offload=None,
        cpu_ram_efficient_loading=False,
        ignored_modules=None,
    )
    accelerator = Mock(
        mixed_precision="bf16",
        torch_device_mesh=None,
        device=torch.device("cpu"),
        is_main_process=True,
    )
    accelerator.state.fsdp_plugin = plugin

    try:
        with (
            patch("torch.distributed.fsdp.fully_shard"),
            patch("accelerate.utils.fsdp_utils.is_compiled_module", return_value=False),
            patch("accelerate.utils.fsdp_utils.fsdp2_prepare_auto_wrap_policy", return_value=None),
        ):
            prepared = fsdp2_prepare_model(accelerator, model)

        dtypes_after = {p.dtype for p in prepared.parameters()}
        assert dtypes_after == {torch.bfloat16, torch.float32}, (
            f"Expected mixed dtypes preserved (Params4bit skip), got {dtypes_after}"
        )
    finally:
        # Undo the class swap so the parameter does not keep the fake type.
        model[0].weight.__class__ = original_class


@run_first
# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
Expand Down
Loading