|
11 | 11 |
|
12 | 12 | import os
|
13 | 13 | import time
|
14 |
| -from collections import OrderedDict |
| 14 | +from collections import deque, OrderedDict |
15 | 15 | from dataclasses import dataclass, MISSING
|
16 | 16 | from pathlib import Path
|
17 | 17 | from typing import Dict, List, Optional
|
@@ -96,7 +96,9 @@ class ExperimentConfig:
|
96 | 96 |
|
97 | 97 | save_folder: Optional[str] = MISSING
|
98 | 98 | restore_file: Optional[str] = MISSING
|
99 |
| - checkpoint_interval: float = MISSING |
| 99 | + checkpoint_interval: int = MISSING |
| 100 | + checkpoint_at_end: bool = MISSING |
| 101 | + keep_checkpoints_num: Optional[int] = MISSING |
100 | 102 |
|
101 | 103 | def train_batch_size(self, on_policy: bool) -> int:
|
102 | 104 | """
|
@@ -280,6 +282,8 @@ def validate(self, on_policy: bool):
|
280 | 282 | f"checkpoint_interval ({self.checkpoint_interval}) "
|
281 | 283 | f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
|
282 | 284 | )
|
| 285 | + if self.keep_checkpoints_num is not None and self.keep_checkpoints_num <= 0: |
| 286 | + raise ValueError("keep_checkpoints_num must be greater than zero or null") |
283 | 287 | if self.max_n_frames is None and self.max_n_iters is None:
|
284 | 288 | raise ValueError("n_iters and total_frames are both not set")
|
285 | 289 |
|
@@ -483,6 +487,7 @@ def _setup_name(self):
|
483 | 487 | self.model_name = self.model_config.associated_class().__name__.lower()
|
484 | 488 | self.environment_name = self.task.env_name().lower()
|
485 | 489 | self.task_name = self.task.name.lower()
|
| 490 | + self._checkpointed_files = deque([]) |
486 | 491 |
|
487 | 492 | if self.config.restore_file is not None and self.config.save_folder is not None:
|
488 | 493 | raise ValueError(
|
@@ -667,6 +672,8 @@ def _collection_loop(self):
|
667 | 672 | self._save_experiment()
|
668 | 673 | pbar.update()
|
669 | 674 |
|
| 675 | + if self.config.checkpoint_at_end: |
| 676 | + self._save_experiment() |
670 | 677 | self.close()
|
671 | 678 |
|
672 | 679 | def close(self):
|
@@ -834,10 +841,16 @@ def load_state_dict(self, state_dict: Dict) -> None:
|
834 | 841 |
|
def _save_experiment(self) -> None:
    """Checkpoint the trainer state to ``<folder_name>/checkpoints``.

    Before writing, enforces the rolling-retention policy: when
    ``config.keep_checkpoints_num`` is set, the oldest tracked checkpoint
    files are deleted so that, counting the file written by this call, at
    most ``keep_checkpoints_num`` checkpoints remain tracked.
    """
    if self.config.keep_checkpoints_num is not None:
        # Evict oldest checkpoints until there is room for the new one.
        while len(self._checkpointed_files) >= self.config.keep_checkpoints_num:
            file_to_delete = self._checkpointed_files.popleft()
            # missing_ok=True: the same path can be tracked twice (a
            # checkpoint_at_end save landing on the same total_frames as
            # the last periodic save reuses the filename), and a user may
            # delete checkpoint files manually — neither case should
            # crash training with FileNotFoundError.
            file_to_delete.unlink(missing_ok=True)

    checkpoint_folder = self.folder_name / "checkpoints"
    # parents=False: folder_name itself is expected to already exist.
    checkpoint_folder.mkdir(parents=False, exist_ok=True)
    checkpoint_file = checkpoint_folder / f"checkpoint_{self.total_frames}.pt"
    torch.save(self.state_dict(), checkpoint_file)
    self._checkpointed_files.append(checkpoint_file)
841 | 854 |
|
842 | 855 | def _load_experiment(self) -> Experiment:
|
843 | 856 | """Load trainer from checkpoint"""
|
|
0 commit comments