Skip to content

Commit

Permalink
Fix parameter count in ModelSummary when parameters are DTensors (#20163
Browse files Browse the repository at this point in the history
)
  • Loading branch information
awaelchli authored Aug 5, 2024
1 parent 3de60f4 commit 345450b
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 13 deletions.
13 changes: 12 additions & 1 deletion src/lightning/fabric/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from lightning_utilities.core.imports import package_available
from torch import Tensor
from torch.utils.data import Dataset, DistributedSampler, Sampler
from typing_extensions import Self, override
from typing_extensions import Self, TypeGuard, override

from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.data import _num_cpus_available
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.fabric.utilities.types import _PATH, ReduceOp

Expand All @@ -30,6 +31,8 @@ class group: # type: ignore


if TYPE_CHECKING:
from torch.distributed._tensor import DTensor

from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.strategies import Strategy

Expand Down Expand Up @@ -427,3 +430,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.barrier()
if self.group is not None:
torch.distributed.destroy_process_group(self.group)


def _is_dtensor(tensor: Tensor) -> TypeGuard["DTensor"]:
if _TORCH_GREATER_EQUAL_2_4:
from torch.distributed._tensor import DTensor

return isinstance(tensor, DTensor)
return False
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))

- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))


## [2.3.0] - 2024-06-13
Expand Down
14 changes: 7 additions & 7 deletions src/lightning/pytorch/utilities/model_summary/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torch.utils.hooks import RemovableHandle

import lightning.pytorch as pl
from lightning.fabric.utilities.distributed import _is_dtensor
from lightning.pytorch.utilities.model_helpers import _ModuleMode
from lightning.pytorch.utilities.rank_zero import WarningCache

Expand Down Expand Up @@ -135,7 +136,7 @@ def layer_type(self) -> str:
@property
def num_parameters(self) -> int:
"""Returns the number of parameters in this module."""
return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._module.parameters())

@property
def training(self) -> bool:
Expand Down Expand Up @@ -264,13 +265,11 @@ def total_training_modes(self) -> Dict[str, int]:

@property
def total_parameters(self) -> int:
return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters())

@property
def trainable_parameters(self) -> int:
return sum(
p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad
)
return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad)

@property
def total_layer_params(self) -> int:
Expand Down Expand Up @@ -470,10 +469,11 @@ def get_human_readable_count(number: int) -> str:
return f"{number:,.1f} {labels[index]}"


def _is_lazy_weight_tensor(p: Tensor) -> bool:
def _tensor_has_shape(p: Tensor) -> bool:
from torch.nn.parameter import UninitializedParameter

if isinstance(p, UninitializedParameter):
# DTensor is a subtype of `UninitializedParameter`, but the shape is known
if isinstance(p, UninitializedParameter) and not _is_dtensor(p):
warning_cache.warn(
"The total number of parameters detected may be inaccurate because the model contains"
" an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array`"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
NOT_APPLICABLE,
LayerSummary,
ModelSummary,
_is_lazy_weight_tensor,
_tensor_has_shape,
get_human_readable_count,
)

Expand All @@ -40,7 +40,7 @@ class DeepSpeedLayerSummary(LayerSummary):
@override
def num_parameters(self) -> int:
"""Returns the number of parameters in this module."""
return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())

@property
def average_shard_parameters(self) -> int:
Expand All @@ -49,7 +49,7 @@ def average_shard_parameters(self) -> int:
def partitioned_size(p: Parameter) -> int:
return p.partitioned_size() if RequirementCache("deepspeed<0.6.6") else p.partition_numel()

return sum(partitioned_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
return sum(partitioned_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())


class DeepSpeedSummary(ModelSummary):
Expand All @@ -71,13 +71,13 @@ def summarize(self) -> Dict[str, DeepSpeedLayerSummary]: # type: ignore[overrid
@property
@override
def total_parameters(self) -> int:
return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters())

@property
@override
def trainable_parameters(self) -> int:
return sum(
deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0
deepspeed_param_size(p) if not _tensor_has_shape(p) else 0
for p in self._model.parameters()
if p.requires_grad
)
Expand Down
14 changes: 14 additions & 0 deletions tests/tests_fabric/utilities/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from functools import partial
from pathlib import Path
from unittest import mock
from unittest.mock import Mock

import lightning.fabric
import pytest
import torch
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
Expand All @@ -15,6 +17,7 @@
_gather_all_tensors,
_InfiniteBarrier,
_init_dist_connection,
_is_dtensor,
_set_num_threads_if_needed,
_suggested_max_num_threads,
_sync_ddp,
Expand Down Expand Up @@ -234,3 +237,14 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
atexit_mock.reset_mock()
_init_dist_connection(LightningEnvironment(), "gloo")
atexit_mock.register.assert_not_called()


@RunIf(min_torch="2.4")
def test_is_dtensor(monkeypatch):
from torch.distributed._tensor import DTensor

assert _is_dtensor(Mock(spec=DTensor))
assert not _is_dtensor(torch.zeros(2, 2))

monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False)
assert not _is_dtensor(Mock(spec=DTensor))
13 changes: 13 additions & 0 deletions tests/tests_pytorch/utilities/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from collections import OrderedDict
from typing import Any
from unittest import mock

import pytest
import torch
Expand Down Expand Up @@ -345,6 +346,18 @@ def test_lazy_model_summary():
assert summary.trainable_parameters == 0


@mock.patch("lightning.pytorch.utilities.model_summary.model_summary._is_dtensor", return_value=True)
def test_dtensor_model_summary(_):
"""Test that the model summary can work with layers that have DTensor parameters."""
# We mock the `_is_dtensor` to pretend parameters are DTensors, because testing with real DTensors
# would require setting up distributed
dtensor_model = UnorderedModel()
summary = ModelSummary(dtensor_model)
assert summary.total_layer_params > 0
assert summary.total_parameters > 0
assert summary.trainable_parameters > 0


@pytest.mark.parametrize("max_depth", [-1, 0, 1, 3, 999])
def test_max_depth_param(max_depth):
"""Test that only the modules up to the desired depth are shown."""
Expand Down

0 comments on commit 345450b

Please sign in to comment.