Skip to content

Commit

Permalink
Avoid optional Tracker attributes and enable mypy (#9320)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Sep 6, 2021
1 parent ff1e691 commit 49c0485
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 133 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Progress tracking
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))


- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ module = [
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector.fx_validator",
"pytorch_lightning.trainer.connectors.logger_connector.logger_connector",
"pytorch_lightning.trainer.progress",
"pytorch_lightning.tuner.auto_gpu_select",
"pytorch_lightning.utilities.apply_func",
"pytorch_lightning.utilities.argparse",
Expand Down
128 changes: 78 additions & 50 deletions pytorch_lightning/trainer/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Optional
from typing import Type


@dataclass
class BaseProgress:
"""
Mixin that implements state-loading utiltiies for dataclasses.
Mixin that implements state-loading utilities for dataclasses.
"""

def state_dict(self) -> dict:
Expand All @@ -35,63 +35,83 @@ def from_state_dict(cls, state_dict: dict) -> "BaseProgress":


@dataclass
class Tracker(BaseProgress):
class ReadyCompletedTracker(BaseProgress):
"""
Track an event's progress.
Args:
ready: Intended to track the number of events ready to start.
started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
processed: Intended to be incremented after the event is processed.
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
Attributes set to ``None`` are treated as unused and are restricted.
"""

ready: Optional[int] = 0
started: Optional[int] = 0
processed: Optional[int] = 0
completed: Optional[int] = 0
ready: int = 0
completed: int = 0

def reset(self) -> None:
if self.ready is not None:
self.ready = 0
if self.started is not None:
self.started = 0
if self.processed is not None:
self.processed = 0
if self.completed is not None:
self.completed = 0

def __setattr__(self, key: str, value: int) -> None:
"""Restrict writing to attributes set to ``None``."""
if getattr(self, key, 0) is None:
raise AttributeError(f"The '{key}' attribute is meant to be unused")
return super().__setattr__(key, value)

def __repr__(self) -> str:
"""Custom implementation to hide ``None`` fields."""
args = [f"{k}={v}" for k, v in self.__dict__.items() if v is not None]
return f"{self.__class__.__name__}({', '.join(args)})"
"""Reset the state."""
self.ready = 0
self.completed = 0

def reset_on_restart(self) -> None:
"""
Reset the progress on restart.
If there is a failure before all attributes are increased,
we restore the attributes to the last fully completed value.
restore the attributes to the last fully completed value.
"""
# choose in case `processed` is unused
value = self.completed if self.processed is None else self.processed
self.ready = self.completed


@dataclass
class StartedTracker(ReadyCompletedTracker):
"""
Track an event's progress.
Args:
ready: Intended to track the number of events ready to start.
started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
"""

started: int = 0

def reset(self) -> None:
super().reset()
self.started = 0

def reset_on_restart(self) -> None:
super().reset_on_restart()
self.started = self.completed


@dataclass
class ProcessedTracker(StartedTracker):
"""
Track an event's progress.
Args:
ready: Intended to track the number of events ready to start.
started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
processed: Intended to be incremented after the event is processed.
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
"""

if self.ready is not None:
self.ready = value
if self.started is not None:
self.started = value
if self.processed is not None:
self.processed = value
if self.completed is not None:
self.completed = value
processed: int = 0

def reset(self) -> None:
super().reset()
self.processed = 0

def reset_on_restart(self) -> None:
# use `processed` in this case as the reset value
self.completed = self.processed
super().reset_on_restart()


@dataclass
Expand All @@ -104,18 +124,26 @@ class Progress(BaseProgress):
current: Intended to track the current progress of an event.
"""

total: Tracker = field(default_factory=Tracker)
current: Tracker = field(default_factory=Tracker)
total: ReadyCompletedTracker = field(default_factory=ProcessedTracker)
current: ReadyCompletedTracker = field(default_factory=ProcessedTracker)

def __post_init__(self) -> None:
if type(self.total) is not type(self.current): # noqa: E721
raise ValueError("The `total` and `current` instances should be of the same class")

def increment_ready(self) -> None:
self.total.ready += 1
self.current.ready += 1

def increment_started(self) -> None:
if not isinstance(self.total, StartedTracker):
raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `started` attribute")
self.total.started += 1
self.current.started += 1

def increment_processed(self) -> None:
if not isinstance(self.total, ProcessedTracker):
raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `processed` attribute")
self.total.processed += 1
self.current.processed += 1

Expand All @@ -124,9 +152,9 @@ def increment_completed(self) -> None:
self.current.completed += 1

@classmethod
def from_defaults(cls, **kwargs: Optional[int]) -> "Progress":
def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int) -> "Progress":
"""Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
return cls(total=Tracker(**kwargs), current=Tracker(**kwargs))
return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs))

def load_state_dict(self, state_dict: dict) -> None:
self.total.load_state_dict(state_dict["total"])
Expand All @@ -144,8 +172,8 @@ class DataLoaderProgress(Progress):
current: Tracks the current dataloader progress.
"""

total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)


@dataclass
Expand All @@ -159,8 +187,8 @@ class SchedulerProgress(Progress):
current: Tracks the current scheduler progress.
"""

total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)


@dataclass
Expand All @@ -173,8 +201,8 @@ class OptimizerProgress(BaseProgress):
zero_grad: Tracks ``optimizer.zero_grad`` calls.
"""

step: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None))
zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(processed=None))
step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker))
zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker))

def reset_on_epoch(self) -> None:
self.step.current.reset()
Expand Down
32 changes: 10 additions & 22 deletions tests/loops/test_loop_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,25 @@ def test_loops_state_dict_structure():
"current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
},
"epoch_loop.scheduler_progress": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer": {
"step": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
},
"step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"zero_grad": {
"total": {"ready": 0, "started": 0, "processed": None, "completed": 0},
"current": {"ready": 0, "started": 0, "processed": None, "completed": 0},
"total": {"ready": 0, "started": 0, "completed": 0},
"current": {"ready": 0, "started": 0, "completed": 0},
},
},
"optimizer_idx": 0,
},
"epoch_loop.val_loop.state_dict": {},
"epoch_loop.val_loop.dataloader_progress": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.val_loop.epoch_loop.state_dict": {},
"epoch_loop.val_loop.epoch_loop.batch_progress": {
Expand Down Expand Up @@ -102,10 +99,7 @@ def test_loops_state_dict_structure():
},
"validate_loop": {
"state_dict": {},
"dataloader_progress": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
},
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"epoch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
Expand All @@ -121,10 +115,7 @@ def test_loops_state_dict_structure():
},
"test_loop": {
"state_dict": {},
"dataloader_progress": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
},
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"epoch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
Expand All @@ -140,10 +131,7 @@ def test_loops_state_dict_structure():
},
"predict_loop": {
"state_dict": {},
"dataloader_progress": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
},
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"epoch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
Expand Down
24 changes: 5 additions & 19 deletions tests/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def val_dataloader(self):

total_dataloader = stop_epoch * n_dataloaders + stop_dataloader
expected = {
"total": {"ready": total_dataloader + 1, "started": None, "processed": None, "completed": total_dataloader},
"current": {"ready": stop_dataloader + 1, "started": None, "processed": None, "completed": stop_dataloader},
"total": {"ready": total_dataloader + 1, "completed": total_dataloader},
"current": {"ready": stop_dataloader + 1, "completed": stop_dataloader},
}
assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected

Expand Down Expand Up @@ -452,13 +452,8 @@ def configure_optimizers_multiple(self):
},
},
"epoch_loop.scheduler_progress": {
"total": {
"ready": nbe_sch_steps + be_sch_steps,
"started": None,
"processed": None,
"completed": nbe_sch_steps + be_sch_steps,
},
"current": {"ready": be_sch_steps, "started": None, "processed": None, "completed": be_sch_steps},
"total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
"current": {"ready": be_sch_steps, "completed": be_sch_steps},
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
Expand All @@ -468,28 +463,19 @@ def configure_optimizers_multiple(self):
"step": {
"total": {
"ready": nbe_total_opt_steps + be_total_opt_steps + has_opt_stepped_in_be,
"started": None,
"processed": None,
"completed": nbe_total_opt_steps + be_total_opt_steps,
},
"current": {
"ready": be_total_opt_steps + has_opt_stepped_in_be,
"started": None,
"processed": None,
"completed": be_total_opt_steps,
},
"current": {"ready": be_total_opt_steps + has_opt_stepped_in_be, "completed": be_total_opt_steps},
},
"zero_grad": {
"total": {
"ready": nbe_total_zero_grad + be_total_zero_grad,
"started": nbe_total_zero_grad + be_total_zero_grad,
"processed": None,
"completed": nbe_total_zero_grad + be_total_zero_grad,
},
"current": {
"ready": be_total_zero_grad,
"started": be_total_zero_grad,
"processed": None,
"completed": be_total_zero_grad,
},
},
Expand Down
Loading

0 comments on commit 49c0485

Please sign in to comment.