Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/lerobot/configs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,24 @@
TRAIN_CONFIG_NAME = "train_config.json"


@dataclass
class EarlyStoppingConfig:
"""Configuration for early stopping based on validation/eval loss."""

# Enable early stopping
enable: bool = False
# Stop training if no improvement for this many steps
patience_steps: int = 10000
# Minimum change to qualify as an improvement (relative)
min_delta: float = 0.0
# Metric to monitor: "val_loss" for validation dataset loss, "eval_success" for sim env success rate
monitor: str = "val_loss"
# Whether higher is better (True for success rate, False for loss)
higher_is_better: bool = False
# Whether to shuffle episodes before splitting into train and validation sets
shuffle_episodes: bool = True


@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
Expand Down Expand Up @@ -60,11 +78,17 @@ class TrainPipelineConfig(HubMixin):
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
save_freq: int = 20_000
# Fraction of episodes to use for validation (0.0 = no validation, e.g. 0.1 = 10% for validation).
# Validation loss is computed at each eval_freq step using inference (select_action).
validation_fraction: float = 0.0
use_policy_training_preset: bool = True
optimizer: OptimizerConfig | None = None
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
early_stopping: EarlyStoppingConfig = field(default_factory=EarlyStoppingConfig)
# Maximum number of checkpoints to keep (0 = keep all)
keep_last_n_checkpoints: int = 0
peft: PeftConfig | None = None

# RA-BC (Reward-Aligned Behavior Cloning) parameters
Expand Down
Loading