[Feature] keep_checkpoints_num and checkpoint_at_end #102

Merged on Jun 18, 2024 (3 commits)
5 changes: 5 additions & 0 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -98,3 +98,8 @@ restore_file: null
# Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch).
# Set it to 0 to disable checkpointing
checkpoint_interval: 0
# Whether to checkpoint when the experiment is done
checkpoint_at_end: False
# How many checkpoints to keep. As new checkpoints are taken, the oldest ones are deleted so that only this number of
# checkpoints remains. The checkpoint at the end is included in this count. Set to `null` to keep all checkpoints.
keep_checkpoints_num: 3
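
Together with checkpoint_interval, the two new options can also be overridden programmatically. A minimal sketch, assuming the usual ExperimentConfig.get_from_yaml() entry point for loading these YAML defaults (the numeric value below is illustrative):

from benchmarl.experiment import ExperimentConfig

# Load the defaults from base_experiment.yaml, then override the checkpoint options.
experiment_config = ExperimentConfig.get_from_yaml()
experiment_config.checkpoint_interval = 6_000   # illustrative; must be a multiple of collected_frames_per_batch
experiment_config.checkpoint_at_end = True      # write one last checkpoint when training finishes
experiment_config.keep_checkpoints_num = 3      # keep only the 3 most recent checkpoints (None keeps all)
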
17 changes: 15 additions & 2 deletions benchmarl/experiment/experiment.py
@@ -11,7 +11,7 @@

import os
import time
from collections import OrderedDict
from collections import deque, OrderedDict
from dataclasses import dataclass, MISSING
from pathlib import Path
from typing import Dict, List, Optional
@@ -96,7 +96,9 @@ class ExperimentConfig:

save_folder: Optional[str] = MISSING
restore_file: Optional[str] = MISSING
checkpoint_interval: float = MISSING
checkpoint_interval: int = MISSING
checkpoint_at_end: bool = MISSING
keep_checkpoints_num: Optional[int] = MISSING

def train_batch_size(self, on_policy: bool) -> int:
"""
@@ -280,6 +282,8 @@ def validate(self, on_policy: bool):
f"checkpoint_interval ({self.checkpoint_interval}) "
f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
)
if self.keep_checkpoints_num is not None and self.keep_checkpoints_num <= 0:
raise ValueError("keep_checkpoints_num must be greater than zero or null")
if self.max_n_frames is None and self.max_n_iters is None:
raise ValueError("n_iters and total_frames are both not set")

@@ -483,6 +487,7 @@ def _setup_name(self):
self.model_name = self.model_config.associated_class().__name__.lower()
self.environment_name = self.task.env_name().lower()
self.task_name = self.task.name.lower()
self._checkpointed_files = deque([])

if self.config.restore_file is not None and self.config.save_folder is not None:
raise ValueError(
@@ -668,6 +673,8 @@ def _collection_loop(self):
pbar.update()
sampling_start = time.time()

if self.config.checkpoint_at_end:
self._save_experiment()
self.close()

def close(self):
@@ -835,10 +842,16 @@ def load_state_dict(self, state_dict: Dict) -> None:

def _save_experiment(self) -> None:
"""Checkpoint trainer"""
if self.config.keep_checkpoints_num is not None:
while len(self._checkpointed_files) >= self.config.keep_checkpoints_num:
file_to_delete = self._checkpointed_files.popleft()
file_to_delete.unlink(missing_ok=False)

checkpoint_folder = self.folder_name / "checkpoints"
checkpoint_folder.mkdir(parents=False, exist_ok=True)
checkpoint_file = checkpoint_folder / f"checkpoint_{self.total_frames}.pt"
torch.save(self.state_dict(), checkpoint_file)
self._checkpointed_files.append(checkpoint_file)

def _load_experiment(self) -> Experiment:
"""Load trainer from checkpoint"""