3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060))


-


8 changes: 7 additions & 1 deletion src/lightning/fabric/strategies/deepspeed.py
@@ -100,6 +100,7 @@ def __init__(
precision: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
exclude_frozen_parameters: bool = False,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
billion parameter models. `For more information: https://pytorch-
@@ -229,6 +230,8 @@ def __init__(
when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
per worker.

exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.

"""
if not _DEEPSPEED_AVAILABLE:
raise ImportError(
@@ -289,6 +292,7 @@

self.remote_device = remote_device
self.load_full_weights = load_full_weights
self.exclude_frozen_parameters = exclude_frozen_parameters

# default FP16 parameters.
self.loss_scale = loss_scale
@@ -444,7 +448,9 @@ def save_checkpoint(
# there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
state = self._convert_stateful_objects_in_state(state, filter={})
# use deepspeed's internal checkpointing function to handle partitioned weights across processes
engine.save_checkpoint(path, client_state=state, tag="checkpoint")
engine.save_checkpoint(
path, client_state=state, tag="checkpoint", exclude_frozen_parameters=self.exclude_frozen_parameters
)

@override
def load_checkpoint(
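The Fabric-side change above adds an `exclude_frozen_parameters` flag to the strategy constructor and forwards it to DeepSpeed's `engine.save_checkpoint`. A minimal usage sketch with Fabric (illustrative only: the two-layer model and checkpoint path are made up, and a CUDA device with DeepSpeed installed is assumed):

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

# "Frozen" parameters are simply parameters with requires_grad=False; with
# exclude_frozen_parameters=True they are omitted from the saved DeepSpeed shards.
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 2))
for param in model[0].parameters():  # freeze the first layer
    param.requires_grad = False

fabric = Fabric(
    accelerator="cuda",
    devices=1,
    strategy=DeepSpeedStrategy(stage=2, exclude_frozen_parameters=True),
)
fabric.launch()

optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)

# ... training loop elided ...

# Only the trainable parameters end up in the checkpoint directory.
fabric.save("checkpoints/run1", state={"model": model, "optimizer": optimizer})
```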
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Torch-Tensorrt integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))


- Added `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060))


### Changed

- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
11 changes: 10 additions & 1 deletion src/lightning/pytorch/strategies/deepspeed.py
@@ -122,6 +122,7 @@ def __init__(
precision_plugin: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
exclude_frozen_parameters: bool = False,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
billion parameter models. `For more information: https://pytorch-
@@ -253,6 +254,8 @@ def __init__(
when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
per worker.

exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.

"""
if not _DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
@@ -311,6 +314,7 @@

self.remote_device = remote_device
self.load_full_weights = load_full_weights
self.exclude_frozen_parameters = exclude_frozen_parameters

# default FP16 parameters.
self.loss_scale = loss_scale
@@ -648,7 +652,12 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op
# dump states as a checkpoint dictionary object
_exclude_keys = ["state_dict", "optimizer_states"]
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint, tag="checkpoint")
self.deepspeed_engine.save_checkpoint(
filepath,
client_state=checkpoint,
tag="checkpoint",
exclude_frozen_parameters=self.exclude_frozen_parameters,
)

@override
def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
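The PyTorch strategy mirrors the Fabric change: the flag is stored on the strategy and passed through to `deepspeed_engine.save_checkpoint`. A hedged sketch of enabling it through the `Trainer` (the partially frozen `LightningModule` is assumed rather than shown; see the test model further below):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

# Checkpoints written via trainer.save_checkpoint() (including those produced by
# checkpoint callbacks) will omit parameters with requires_grad=False, which can
# shrink checkpoints considerably when most of the model is frozen for fine-tuning.
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    strategy=DeepSpeedStrategy(stage=2, exclude_frozen_parameters=True),
)
trainer.fit(model)  # `model`: a LightningModule with some frozen layers, e.g. TestModelWithFrozenParams below
trainer.save_checkpoint("deepspeed_ckpt")  # DeepSpeed writes a directory of shards here
```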
29 changes: 27 additions & 2 deletions tests/tests_fabric/strategies/test_deepspeed.py
@@ -194,15 +194,19 @@ def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
model.modules.return_value = [model]
strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"})
# the client_state should not contain any deepspeed engine or deepspeed optimizer
model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
model.save_checkpoint.assert_called_with(
tmp_path, client_state={"test": "data"}, tag="checkpoint", exclude_frozen_parameters=False
)

# Model and optimizer
optimizer = Mock()
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
# the client_state should not contain any deepspeed engine or deepspeed optimizer
model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
model.save_checkpoint.assert_called_with(
tmp_path, client_state={"test": "data"}, tag="checkpoint", exclude_frozen_parameters=False
)


@RunIf(deepspeed=True)
@@ -219,6 +223,27 @@ def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})


@RunIf(deepspeed=True)
@pytest.mark.parametrize("exclude_frozen_parameters", [True, False])
def test_deepspeed_save_checkpoint_exclude_frozen_parameters(exclude_frozen_parameters):
"""Test that the DeepSpeed strategy can save checkpoints with the `exclude_frozen_parameters` argument."""
from deepspeed import DeepSpeedEngine

strategy = DeepSpeedStrategy(exclude_frozen_parameters=exclude_frozen_parameters)
assert strategy.exclude_frozen_parameters is exclude_frozen_parameters

model = Mock(spec=DeepSpeedEngine, optimizer=None)
model.modules.return_value = [model]
strategy.save_checkpoint(path="test_path", state={"model": model, "extra": "data"})

model.save_checkpoint.assert_called_with(
"test_path",
client_state={"extra": "data"},
tag="checkpoint",
exclude_frozen_parameters=exclude_frozen_parameters,
)


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_validate_path(tmp_path):
"""Test that we validate the checkpoint path for a DeepSpeed checkpoint and give suggestions for user error."""
40 changes: 40 additions & 0 deletions tests/tests_pytorch/strategies/test_deepspeed.py
@@ -562,6 +562,46 @@ def test_deepspeed_multigpu_single_file(tmp_path):
trainer.test(model, ckpt_path=checkpoint_path)


@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
def test_deepspeed_strategy_exclude_frozen_parameters_integration(tmp_path):
"""Test end-to-end integration of exclude_frozen_parameters with actual model training and checkpointing."""

class TestModelWithFrozenParams(BoringModel):
def __init__(self):
super().__init__()
self.frozen_layer = torch.nn.Linear(32, 32)

def configure_model(self) -> None:
super().configure_model()
# Freeze the additional layer parameters
for param in self.frozen_layer.parameters():
param.requires_grad = False

def forward(self, x):
x = self.frozen_layer(x)
return super().forward(x)

model = TestModelWithFrozenParams()

trainer = Trainer(
default_root_dir=tmp_path,
strategy=DeepSpeedStrategy(exclude_frozen_parameters=True),
accelerator="gpu",
devices=1,
fast_dev_run=True,
precision="16-mixed",
enable_progress_bar=False,
enable_model_summary=False,
)

trainer.fit(model)
checkpoint_path = os.path.join(tmp_path, "checkpoint_exclude_frozen.ckpt")
trainer.save_checkpoint(checkpoint_path)

# Verify checkpoint was created
assert os.path.exists(checkpoint_path)


class ModelParallelClassificationModel(LightningModule):
def __init__(self, lr: float = 0.01, num_blocks: int = 5):
super().__init__()
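The integration test above only asserts that the checkpoint directory exists. As an illustrative follow-up (not part of this PR), one possible way to check that the frozen layer was actually excluded is DeepSpeed's checkpoint consolidation helper; this sketch assumes a ZeRO stage the helper can reconstruct and reuses the `frozen_layer` name, the `checkpoint_path` variable, and the `tag="checkpoint"` convention from the code above:

```python
# Hypothetical verification sketch, not part of the test suite in this PR.
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Lightning saves DeepSpeed checkpoints under the tag "checkpoint" (see save_checkpoint above);
# checkpoint_path is the directory written by trainer.save_checkpoint in the test.
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_path, tag="checkpoint")

# With exclude_frozen_parameters=True the frozen layer's weights should be absent.
assert not any("frozen_layer" in name for name in state_dict)
```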