diff --git a/nemo_automodel/components/distributed/parallelizer.py b/nemo_automodel/components/distributed/parallelizer.py
index 10b1d5e7f..21754e59f 100644
--- a/nemo_automodel/components/distributed/parallelizer.py
+++ b/nemo_automodel/components/distributed/parallelizer.py
@@ -146,6 +146,18 @@ def parallelize(
         for i, layer in enumerate(layers):
             if hasattr(layer, "mlp"):
                 layers[i].mlp = checkpoint_wrapper(layer.mlp)
+            if hasattr(layer, "self_attn"):
+                layers[i].self_attn = checkpoint_wrapper(layers[i].self_attn)  # type: ignore
+
+            if hasattr(layer, "input_layernorm"):
+                layers[i].input_layernorm = checkpoint_wrapper(
+                    layers[i].input_layernorm  # type: ignore
+                )
+
+            if hasattr(layer, "post_attention_layernorm"):
+                layers[i].post_attention_layernorm = checkpoint_wrapper(
+                    layers[i].post_attention_layernorm  # type: ignore
+                )
 
     # Set up mixed precision policy
     if not mp_policy:
diff --git a/tests/unit_tests/distributed/test_parallelization_strategies.py b/tests/unit_tests/distributed/test_parallelization_strategies.py
index 41b80ba33..ac2a66e3e 100644
--- a/tests/unit_tests/distributed/test_parallelization_strategies.py
+++ b/tests/unit_tests/distributed/test_parallelization_strategies.py
@@ -287,9 +287,12 @@ def test_parallelize_with_activation_checkpointing(self, strategy, mock_device_m
         """Test parallelization with activation checkpointing enabled."""
         mesh, dp_replicate_mesh, dp_shard_mesh, tp_mesh = mock_device_mesh
 
-        # Mock layers with MLP
+        # Mock layers with all the attributes that get checkpointed
         mock_layer = MagicMock()
         mock_layer.mlp = nn.Linear(10, 10)
+        mock_layer.self_attn = MagicMock()
+        mock_layer.input_layernorm = MagicMock()
+        mock_layer.post_attention_layernorm = MagicMock()
         mock_distributed_env["extract_layers"].return_value = [mock_layer]
 
         model = MockModel()
@@ -301,8 +304,17 @@ def test_parallelize_with_activation_checkpointing(self, strategy, mock_device_m
             activation_checkpointing=True,
         )
 
-        # Should apply checkpoint wrapper to MLP layers
-        mock_distributed_env["checkpoint_wrapper"].assert_called_with(mock_layer.mlp)
+        # Should apply checkpoint wrapper to all expected layer components
+        checkpoint_wrapper_mock = mock_distributed_env["checkpoint_wrapper"]
+
+        # Check that checkpoint_wrapper was called with all expected attributes
+        expected_calls = [
+            call(mock_layer.mlp),
+            call(mock_layer.self_attn),
+            call(mock_layer.input_layernorm),
+            call(mock_layer.post_attention_layernorm),
+        ]
+        checkpoint_wrapper_mock.assert_has_calls(expected_calls, any_order=False)
 
     def test_parallelize_with_custom_mesh_names(self, strategy, mock_device_mesh, mock_distributed_env):
         """Test parallelization with custom mesh names."""
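
For context, the `parallelizer.py` change extends activation checkpointing from the MLP alone to the attention block and both layer norms. Below is a minimal standalone sketch of the same pattern, assuming Hugging-Face-style decoder-layer attribute names and the public `checkpoint_wrapper` from PyTorch; `wrap_layer_submodules` and `_CHECKPOINTED_ATTRS` are illustrative names, not the repo's actual API:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

# Hypothetical helper mirroring the diff's loop; the attribute names match
# the common Hugging Face decoder-layer layout assumed by parallelize().
_CHECKPOINTED_ATTRS = (
    "mlp",
    "self_attn",
    "input_layernorm",
    "post_attention_layernorm",
)


def wrap_layer_submodules(layers: list[nn.Module]) -> None:
    # Replace each present submodule with an activation-checkpointed wrapper,
    # so its activations are recomputed during backward instead of stored.
    for layer in layers:
        for name in _CHECKPOINTED_ATTRS:
            if hasattr(layer, name):
                setattr(layer, name, checkpoint_wrapper(getattr(layer, name)))
```

Wrapping the norms as well trades a small amount of recompute for a further reduction in stored activations, on top of the larger savings from checkpointing the attention block.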
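
On the test side, `assert_has_calls(expected_calls, any_order=False)` requires the expected calls to appear consecutively and in order within the mock's call history, which pins down the wrapping order (`mlp`, then `self_attn`, then the two norms). A self-contained illustration of that semantics, using string stand-ins for the wrapped modules:

```python
from unittest.mock import MagicMock, call

# Identity stand-in for checkpoint_wrapper: records calls, returns its input.
wrapper = MagicMock(side_effect=lambda module: module)

for module in ("mlp", "self_attn", "input_layernorm", "post_attention_layernorm"):
    wrapper(module)

# Passes: the four calls occur in exactly this consecutive order.
# With any_order=False (the default), reordering or interleaving other
# calls between them would make the assertion fail.
wrapper.assert_has_calls(
    [
        call("mlp"),
        call("self_attn"),
        call("input_layernorm"),
        call("post_attention_layernorm"),
    ],
    any_order=False,
)
```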