Fix resuming from checkpoint for fault-tolerant training in case of no failure #9371

Merged
31 commits merged into master from bugfix/epoch-resume on Sep 10, 2021
Changes from 29 commits
Commits
31 commits
cfb4dca
w
awaelchli Sep 8, 2021
b77fd31
comment
awaelchli Sep 8, 2021
69ba327
update fix
awaelchli Sep 8, 2021
e1be811
update fix
awaelchli Sep 8, 2021
9a40fd4
move progress update
awaelchli Sep 8, 2021
b5bc8ee
add comments
awaelchli Sep 8, 2021
d8e2fee
fix test after resetting the progress on a successful run
awaelchli Sep 8, 2021
3d59020
fix a test
awaelchli Sep 8, 2021
f41198a
changelog
awaelchli Sep 8, 2021
585210a
add state dict test
awaelchli Sep 8, 2021
2330d54
add comment
awaelchli Sep 8, 2021
2846526
remove repro script
awaelchli Sep 8, 2021
11a587b
udpate
awaelchli Sep 9, 2021
d6f501f
Merge branch 'master' into bugfix/epoch-resume
awaelchli Sep 9, 2021
387bcfc
update
awaelchli Sep 9, 2021
6a6d3d4
fix tbtt test
awaelchli Sep 9, 2021
34f3ebc
drop old change
awaelchli Sep 9, 2021
5c84846
update tests
awaelchli Sep 9, 2021
8d97f7f
add more tests
awaelchli Sep 9, 2021
8f73fa8
add docstring to test
awaelchli Sep 9, 2021
198a779
remove repro
awaelchli Sep 9, 2021
4961a98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
8177a86
update changelog
awaelchli Sep 9, 2021
04b0043
rm todo test
awaelchli Sep 9, 2021
c28dd59
add torch 1.7.0 requirement to test case
awaelchli Sep 9, 2021
d9be028
reset redundant test changes
awaelchli Sep 9, 2021
4af0626
remove failed check
awaelchli Sep 9, 2021
9fab253
keep optimizer restart check
awaelchli Sep 9, 2021
3db5c5a
update test with optimizer idx assertion
awaelchli Sep 9, 2021
65ab234
Merge branch 'master' into bugfix/epoch-resume
awaelchli Sep 10, 2021
8045fd5
nit
awaelchli Sep 10, 2021
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -23,8 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Progress tracking
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598)
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
* Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))


- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
3 changes: 2 additions & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -97,7 +97,8 @@ def reset(self) -> None:
# track epoch output
self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]

if not self.restarting:
ended = self._num_training_batches_reached()
if not self.restarting or ended:
self.batch_progress.current.reset()
self.scheduler_progress.current.reset()
self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
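
The two added lines above are the core of the fix: the `current` progress counters are now also cleared when the loop restarts from a state in which the epoch had already finished. A minimal, self-contained sketch of that idea (the `Progress` and `EpochLoop` classes below are illustrative stand-ins, not the actual Lightning implementations):

```python
from dataclasses import dataclass, field


@dataclass
class Progress:
    total: int = 0    # counts over the whole training run, never reset
    current: int = 0  # counts within the running epoch

    def increment(self) -> None:
        self.total += 1
        self.current += 1

    def reset_current(self) -> None:
        self.current = 0


@dataclass
class EpochLoop:
    num_batches: int
    restarting: bool = False
    batch_progress: Progress = field(default_factory=Progress)

    def _num_training_batches_reached(self) -> bool:
        return self.batch_progress.current >= self.num_batches

    def reset(self) -> None:
        # the fix: also clear the "current" counters when the previous run had
        # already finished the epoch, so the resumed run starts the epoch fresh
        ended = self._num_training_batches_reached()
        if not self.restarting or ended:
            self.batch_progress.reset_current()

    def run(self) -> None:
        self.reset()
        while not self._num_training_batches_reached():
            self.batch_progress.increment()
        self.restarting = False


loop = EpochLoop(num_batches=3)
loop.run()              # a full epoch runs, then a checkpoint is taken
loop.restarting = True  # simulate restoring that checkpoint
loop.run()              # with the fix, three new batches actually run
assert loop.batch_progress.total == 6
```

With the old condition (`if not self.restarting:` alone), the second `run()` in this sketch would keep `current == 3`, treat the epoch as already finished, and process no batches at all.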
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -263,7 +263,7 @@ def teardown(self) -> None:

def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
# FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
# TODO: update has_completed to its proper value
state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
return state_dict
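
For context, the unchanged line above bundles the train dataloader's fault-tolerant state into the fit loop's checkpoint, and the reworded comment notes that the `has_completed` flag is still hard-coded. A rough sketch of that composition pattern (hypothetical `FaultTolerantLoaderStub` and `FitLoopStub` names, and an assumed meaning for `has_completed`):

```python
from typing import Any, Dict


class FaultTolerantLoaderStub:
    """Hypothetical stand-in for a dataloader that can export resumable iteration state."""

    def __init__(self) -> None:
        self.batches_fetched = 0

    def state_dict(self, has_completed: bool = False) -> Dict[str, Any]:
        # Assumption: `has_completed` marks whether the batch being processed when
        # the state is captured should count as finished. The diff passes False
        # and leaves picking the proper value as a TODO.
        return {"batches_fetched": self.batches_fetched, "has_completed": has_completed}


class FitLoopStub:
    """Sketch of the checkpoint composition: loop progress plus dataloader state."""

    def __init__(self, loader: FaultTolerantLoaderStub) -> None:
        self.train_dataloader = loader
        self.progress = {"epochs_completed": 0}

    def on_save_checkpoint(self) -> Dict[str, Any]:
        state_dict: Dict[str, Any] = {"progress": dict(self.progress)}
        state_dict["dataloader_state_dict"] = self.train_dataloader.state_dict(has_completed=False)
        return state_dict
```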

2 changes: 1 addition & 1 deletion pytorch_lightning/loops/optimizer/optimizer_loop.py
@@ -65,7 +65,7 @@ def connect(self, **kwargs: "Loop") -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

def reset(self) -> None:
if not self.restarting:
if not self.restarting or self.done:
self.optim_progress.optimizer_idx = 0
self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

233 changes: 233 additions & 0 deletions tests/loops/test_loops.py
@@ -22,6 +22,7 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loops import Loop, TrainingBatchLoop
from pytorch_lightning.trainer.progress import BaseProgress
from tests.helpers import BoringModel
@@ -513,3 +514,235 @@ def configure_optimizers_multiple(self):
assert state_dict != checkpoint["loops"]["fit_loop"]
assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
@RunIf(min_torch="1.7.0")
def test_loop_state_on_complete_run(n_optimizers, tmpdir):
n_epochs = 3
n_batches = 3
accumulate_grad_batches = 1

class TestModel(BoringModel):
def __init__(self):
super().__init__()
if n_optimizers > 1:
self.configure_optimizers = self.configure_optimizers_multiple

def training_step(self, batch, batch_idx, optimizer_idx=0):
return super().training_step(batch, batch_idx)

def configure_optimizers_multiple(self):
optimizers = [torch.optim.Adam(self.layer.parameters(), lr=0.1) for _ in range(n_optimizers)]

lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1)
lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizers[1], step_size=1)
# no scheduler for optimizer_2
lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}]

return optimizers, lr_schedulers

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=n_epochs,
limit_train_batches=n_batches,
limit_val_batches=0,
accumulate_grad_batches=accumulate_grad_batches,
progress_bar_refresh_rate=0,
logger=False,
checkpoint_callback=True,
)
trainer.fit(model)

ckpt_path = trainer.checkpoint_callback.best_model_path
assert os.path.exists(ckpt_path)
checkpoint = torch.load(ckpt_path)

n_sch_steps_total = n_epochs
n_sch_steps_current = 1
if n_optimizers > 1:
n_sch_steps_total = n_epochs + n_epochs * n_batches
n_sch_steps_current = n_batches + 1

expected = {
"state_dict": ANY,
"epoch_progress": {
"total": {
"ready": n_epochs,
"started": n_epochs,
"processed": n_epochs,
# TODO: the following "-1" offset will be fixed by
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
"completed": n_epochs - 1,
},
"current": {
"ready": n_epochs,
"started": n_epochs,
"processed": n_epochs,
# TODO: the following "-1" offset will be fixed by
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
"completed": n_epochs - 1,
},
},
"epoch_loop.state_dict": ANY,
"epoch_loop.batch_progress": {
"total": {
"ready": n_epochs * n_batches,
"started": n_epochs * n_batches,
"processed": n_epochs * n_batches,
"completed": n_epochs * n_batches,
},
"current": {
"ready": n_batches,
"started": n_batches,
"processed": n_batches,
"completed": n_batches,
},
},
"epoch_loop.scheduler_progress": {
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
"current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer_idx": n_optimizers,
"optimizer": {
"step": {
"total": {
"ready": n_epochs * n_batches * n_optimizers,
"completed": n_epochs * n_batches * n_optimizers,
},
"current": {
"ready": n_batches * n_optimizers,
"completed": n_batches * n_optimizers,
},
},
"zero_grad": {
"total": {
"ready": n_epochs * n_batches * n_optimizers,
"started": n_epochs * n_batches * n_optimizers,
"completed": n_epochs * n_batches * n_optimizers,
},
"current": {
"ready": n_batches * n_optimizers,
"started": n_batches * n_optimizers,
"completed": n_batches * n_optimizers,
},
},
},
},
"epoch_loop.val_loop.state_dict": ANY,
"epoch_loop.val_loop.dataloader_progress": ANY,
"epoch_loop.val_loop.epoch_loop.state_dict": ANY,
"epoch_loop.val_loop.epoch_loop.batch_progress": ANY,
"epoch_loop.val_loop._results": ANY,
"epoch_loop._results": ANY,
}
assert checkpoint["loops"]["fit_loop"] == expected
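
The scheduler-progress numbers asserted above follow directly from the scheduler setup in this test: with a single optimizer there is one scheduler stepping once per epoch, while with multiple optimizers the test configures an epoch-interval and a step-interval scheduler and leaves the remaining optimizers without one. A quick sanity check of that arithmetic for the parametrization used here (plain Python, nothing Lightning-specific assumed):

```python
n_epochs, n_batches = 3, 3

for n_optimizers in (1, 3, 5):
    if n_optimizers == 1:
        # a single epoch-interval scheduler: one step per epoch
        total, current = n_epochs, 1
    else:
        # epoch-interval scheduler (one step per epoch) plus step-interval
        # scheduler (one step per batch); the other optimizers have no scheduler
        total = n_epochs + n_epochs * n_batches
        current = 1 + n_batches
    print(f"{n_optimizers} optimizers -> total={total}, current={current}")
    # 1 optimizer    -> total=3,  current=1
    # 3/5 optimizers -> total=12, current=4
```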


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
def test_fit_loop_reset(tmpdir):
"""Test that the reset logic in fit- and epoch loop is aware of whether the loop is restarting from a completed
loop or from a mid-epoch checkpoint."""

# generate checkpoints at end of epoch and mid-epoch
model = BoringModel()
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir,
every_n_train_steps=2,
save_top_k=-1,
)
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=4,
num_sanity_val_steps=0,
max_epochs=2,
callbacks=[checkpoint_callback],
logger=False,
weights_summary=None,
)
trainer.fit(model)

# reset state loaded from a checkpoint from mid-epoch
mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=1.ckpt"))
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
optimizer_loop = epoch_loop.batch_loop.optimizer_loop
assert not fit_loop.restarting
assert not epoch_loop.restarting
assert not optimizer_loop.restarting

fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])

def mid_epoch_reset_assertions():
assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0

assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 2
assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 2
assert epoch_loop.batch_progress.current.completed == 2

# resetting from a mid-epoch checkpoint should not change progress counters
mid_epoch_reset_assertions()
assert optimizer_loop.optim_progress.optimizer_idx == 1
fit_loop.reset()
epoch_loop.reset()
optimizer_loop.reset()
mid_epoch_reset_assertions()
assert optimizer_loop.optim_progress.optimizer_idx == 0

# reset state loaded from a checkpoint from the end of an epoch
end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=3.ckpt"))
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
fit_loop.restarting = False
epoch_loop.restarting = False
optimizer_loop.restarting = False

fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])

assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0

assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 4
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 4
assert epoch_loop.batch_progress.current.completed == 4

assert optimizer_loop.optim_progress.optimizer_idx == 1

# resetting from an end-of-epoch checkpoint should reset the current counters to 0
fit_loop.reset()
epoch_loop.reset()
optimizer_loop.reset()

assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0

assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 4
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 0
assert epoch_loop.batch_progress.current.completed == 0

assert optimizer_loop.optim_progress.optimizer_idx == 0
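
Finally, the user-facing scenario behind the fix, sketched as a minimal resume flow. This is an assumption-laden usage example rather than part of the PR: `resume_from_checkpoint` is the Trainer argument of this release line, `PL_FAULT_TOLERANT_TRAINING` is the opt-in flag used by the tests above, and the paths and model are purely illustrative.

```python
import os

from pytorch_lightning import Trainer
from tests.helpers import BoringModel

# fault-tolerant training is opt-in via this environment variable (see the tests above)
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

trainer = Trainer(default_root_dir="demo_logs", max_epochs=2, limit_train_batches=4)
trainer.fit(BoringModel())  # completes without any failure and writes a checkpoint

# Resuming from that checkpoint after a *successful* run is exactly the case this
# PR addresses: the loops now reset their "current" counters instead of carrying
# over state from the already-finished epoch.
ckpt_path = trainer.checkpoint_callback.best_model_path
resumed = Trainer(default_root_dir="demo_logs", max_epochs=4, resume_from_checkpoint=ckpt_path)
resumed.fit(BoringModel())
```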