Skip to content

Commit a930915

Browse files
[Feature] keep_checkpoints_num and checkpoint_at_end (facebookresearch#102)
* amend * amend * amend
1 parent 2a337e8 commit a930915

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

benchmarl/conf/experiment/base_experiment.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,8 @@ restore_file: null
9898
# Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch).
9999
# Set it to 0 to disable checkpointing
100100
checkpoint_interval: 0
101+
# Whether to checkpoint when the experiment is done
102+
checkpoint_at_end: False
103+
# How many checkpoints to keep. As new checkpoints are taken, temporally older checkpoints are deleted to keep this number of
104+
# checkpoints. The checkpoint at the end is included in this number. Set to `null` to keep all checkpoints.
105+
keep_checkpoints_num: 3

benchmarl/experiment/experiment.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import os
1313
import time
14-
from collections import OrderedDict
14+
from collections import deque, OrderedDict
1515
from dataclasses import dataclass, MISSING
1616
from pathlib import Path
1717
from typing import Dict, List, Optional
@@ -96,7 +96,9 @@ class ExperimentConfig:
9696

9797
save_folder: Optional[str] = MISSING
9898
restore_file: Optional[str] = MISSING
99-
checkpoint_interval: float = MISSING
99+
checkpoint_interval: int = MISSING
100+
checkpoint_at_end: bool = MISSING
101+
keep_checkpoints_num: Optional[int] = MISSING
100102

101103
def train_batch_size(self, on_policy: bool) -> int:
102104
"""
@@ -280,6 +282,8 @@ def validate(self, on_policy: bool):
280282
f"checkpoint_interval ({self.checkpoint_interval}) "
281283
f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
282284
)
285+
if self.keep_checkpoints_num is not None and self.keep_checkpoints_num <= 0:
286+
raise ValueError("keep_checkpoints_num must be greater than zero or null")
283287
if self.max_n_frames is None and self.max_n_iters is None:
284288
raise ValueError("n_iters and total_frames are both not set")
285289

@@ -483,6 +487,7 @@ def _setup_name(self):
483487
self.model_name = self.model_config.associated_class().__name__.lower()
484488
self.environment_name = self.task.env_name().lower()
485489
self.task_name = self.task.name.lower()
490+
self._checkpointed_files = deque([])
486491

487492
if self.config.restore_file is not None and self.config.save_folder is not None:
488493
raise ValueError(
@@ -668,6 +673,8 @@ def _collection_loop(self):
668673
pbar.update()
669674
sampling_start = time.time()
670675

676+
if self.config.checkpoint_at_end:
677+
self._save_experiment()
671678
self.close()
672679

673680
def close(self):
@@ -835,10 +842,16 @@ def load_state_dict(self, state_dict: Dict) -> None:
835842

836843
def _save_experiment(self) -> None:
837844
"""Checkpoint trainer"""
845+
if self.config.keep_checkpoints_num is not None:
846+
while len(self._checkpointed_files) >= self.config.keep_checkpoints_num:
847+
file_to_delete = self._checkpointed_files.popleft()
848+
file_to_delete.unlink(missing_ok=False)
849+
838850
checkpoint_folder = self.folder_name / "checkpoints"
839851
checkpoint_folder.mkdir(parents=False, exist_ok=True)
840852
checkpoint_file = checkpoint_folder / f"checkpoint_{self.total_frames}.pt"
841853
torch.save(self.state_dict(), checkpoint_file)
854+
self._checkpointed_files.append(checkpoint_file)
842855

843856
def _load_experiment(self) -> Experiment:
844857
"""Load trainer from checkpoint"""

0 commit comments

Comments
 (0)