Skip to content

Commit 5260fdd

Browse files
committed
[Feature] keep_checkpoints_num and checkpoint_at_end (facebookresearch#102)
* amend * amend * amend (cherry picked from commit a930915)
1 parent f6dad33 commit 5260fdd

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

benchmarl/conf/experiment/base_experiment.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,8 @@ restore_file: null
9898
# Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch).
9999
# Set it to 0 to disable checkpointing
100100
checkpoint_interval: 0
101+
# Whether to checkpoint when the experiment is done
102+
checkpoint_at_end: False
103+
# How many checkpoints to keep. As new checkpoints are taken, older checkpoints are deleted so that only this number of
104+
# checkpoints is retained. The checkpoint at the end is included in this number. Set to `null` to keep all checkpoints.
105+
keep_checkpoints_num: 3

benchmarl/experiment/experiment.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import os
1313
import time
14-
from collections import OrderedDict
14+
from collections import deque, OrderedDict
1515
from dataclasses import dataclass, MISSING
1616
from pathlib import Path
1717
from typing import Dict, List, Optional
@@ -96,7 +96,9 @@ class ExperimentConfig:
9696

9797
save_folder: Optional[str] = MISSING
9898
restore_file: Optional[str] = MISSING
99-
checkpoint_interval: float = MISSING
99+
checkpoint_interval: int = MISSING
100+
checkpoint_at_end: bool = MISSING
101+
keep_checkpoints_num: Optional[int] = MISSING
100102

101103
def train_batch_size(self, on_policy: bool) -> int:
102104
"""
@@ -280,6 +282,8 @@ def validate(self, on_policy: bool):
280282
f"checkpoint_interval ({self.checkpoint_interval}) "
281283
f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
282284
)
285+
if self.keep_checkpoints_num is not None and self.keep_checkpoints_num <= 0:
286+
raise ValueError("keep_checkpoints_num must be greater than zero or null")
283287
if self.max_n_frames is None and self.max_n_iters is None:
284288
raise ValueError("n_iters and total_frames are both not set")
285289

@@ -483,6 +487,7 @@ def _setup_name(self):
483487
self.model_name = self.model_config.associated_class().__name__.lower()
484488
self.environment_name = self.task.env_name().lower()
485489
self.task_name = self.task.name.lower()
490+
self._checkpointed_files = deque([])
486491

487492
if self.config.restore_file is not None and self.config.save_folder is not None:
488493
raise ValueError(
@@ -667,6 +672,8 @@ def _collection_loop(self):
667672
self._save_experiment()
668673
pbar.update()
669674

675+
if self.config.checkpoint_at_end:
676+
self._save_experiment()
670677
self.close()
671678

672679
def close(self):
@@ -834,10 +841,16 @@ def load_state_dict(self, state_dict: Dict) -> None:
834841

835842
def _save_experiment(self) -> None:
836843
"""Checkpoint trainer"""
844+
if self.config.keep_checkpoints_num is not None:
845+
while len(self._checkpointed_files) >= self.config.keep_checkpoints_num:
846+
file_to_delete = self._checkpointed_files.popleft()
847+
file_to_delete.unlink(missing_ok=False)
848+
837849
checkpoint_folder = self.folder_name / "checkpoints"
838850
checkpoint_folder.mkdir(parents=False, exist_ok=True)
839851
checkpoint_file = checkpoint_folder / f"checkpoint_{self.total_frames}.pt"
840852
torch.save(self.state_dict(), checkpoint_file)
853+
self._checkpointed_files.append(checkpoint_file)
841854

842855
def _load_experiment(self) -> Experiment:
843856
"""Load trainer from checkpoint"""

0 commit comments

Comments
 (0)