Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions nemo_automodel/components/distributed/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ def parallelize(
for i, layer in enumerate(layers):
if hasattr(layer, "mlp"):
layers[i].mlp = checkpoint_wrapper(layer.mlp)
if hasattr(layer, "self_attn"):
layers[i].self_attn = checkpoint_wrapper(layers[i].self_attn) # type: ignore

if hasattr(layer, "input_layernorm"):
layers[i].input_layernorm = checkpoint_wrapper(
layers[i].input_layernorm # type: ignore
)

if hasattr(layer, "post_attention_layernorm"):
layers[i].post_attention_layernorm = checkpoint_wrapper(
layers[i].post_attention_layernorm # type: ignore
)

# Set up mixed precision policy
if not mp_policy:
Expand Down
18 changes: 15 additions & 3 deletions tests/unit_tests/distributed/test_parallelization_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,12 @@ def test_parallelize_with_activation_checkpointing(self, strategy, mock_device_m
"""Test parallelization with activation checkpointing enabled."""
mesh, dp_replicate_mesh, dp_shard_mesh, tp_mesh = mock_device_mesh

# Mock layers with MLP
# Mock layers with all the attributes that get checkpointed
mock_layer = MagicMock()
mock_layer.mlp = nn.Linear(10, 10)
mock_layer.self_attn = MagicMock()
mock_layer.input_layernorm = MagicMock()
mock_layer.post_attention_layernorm = MagicMock()
mock_distributed_env["extract_layers"].return_value = [mock_layer]

model = MockModel()
Expand All @@ -301,8 +304,17 @@ def test_parallelize_with_activation_checkpointing(self, strategy, mock_device_m
activation_checkpointing=True,
)

# Should apply checkpoint wrapper to MLP layers
mock_distributed_env["checkpoint_wrapper"].assert_called_with(mock_layer.mlp)
# Should apply checkpoint wrapper to all expected layer components
checkpoint_wrapper_mock = mock_distributed_env["checkpoint_wrapper"]

# Check that checkpoint_wrapper was called with all expected attributes
expected_calls = [
call(mock_layer.mlp),
call(mock_layer.self_attn),
call(mock_layer.input_layernorm),
call(mock_layer.post_attention_layernorm),
]
checkpoint_wrapper_mock.assert_has_calls(expected_calls, any_order=False)

def test_parallelize_with_custom_mesh_names(self, strategy, mock_device_mesh, mock_distributed_env):
"""Test parallelization with custom mesh names."""
Expand Down
Loading