diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py
index 7a5eee77d9d..8a4a3680978 100644
--- a/src/lerobot/configs/train.py
+++ b/src/lerobot/configs/train.py
@@ -33,6 +33,24 @@
 TRAIN_CONFIG_NAME = "train_config.json"
 
 
+@dataclass
+class EarlyStoppingConfig:
+    """Configuration for early stopping based on validation/eval loss."""
+
+    # Enable early stopping
+    enable: bool = False
+    # Stop training if no improvement for this many steps
+    patience_steps: int = 10000
+    # Minimum relative change to qualify as an improvement
+    min_delta: float = 0.0
+    # Metric to monitor: "val_loss" for validation dataset loss, "eval_success" for sim env success rate
+    monitor: str = "val_loss"
+    # Whether higher is better (True for success rate, False for loss)
+    higher_is_better: bool = False
+    # Whether to shuffle episodes before splitting into train and validation sets
+    shuffle_episodes: bool = True
+
+
 @dataclass
 class TrainPipelineConfig(HubMixin):
     dataset: DatasetConfig
@@ -60,11 +78,17 @@ class TrainPipelineConfig(HubMixin):
     save_checkpoint: bool = True
     # Checkpoint is saved every `save_freq` training iterations and after the last training step.
     save_freq: int = 20_000
+    # Fraction of episodes held out for validation (e.g. 0.1 = 10%; 0.0 disables validation).
+    # Validation loss is computed every `eval_freq` steps using inference (select_action).
+    validation_fraction: float = 0.0
     use_policy_training_preset: bool = True
     optimizer: OptimizerConfig | None = None
     scheduler: LRSchedulerConfig | None = None
     eval: EvalConfig = field(default_factory=EvalConfig)
     wandb: WandBConfig = field(default_factory=WandBConfig)
+    early_stopping: EarlyStoppingConfig = field(default_factory=EarlyStoppingConfig)
+    # Maximum number of checkpoints to keep (0 = keep all)
+    keep_last_n_checkpoints: int = 0
     peft: PeftConfig | None = None
 
     # RA-BC (Reward-Aligned Behavior Cloning) parameters
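For reference, the new fields compose as follows. This is an illustrative sketch only (the values are not defaults recommended by this patch); in practice they would arrive through the training script's dotted CLI arguments, e.g. `--early_stopping.enable=true --validation_fraction=0.1`:

    # Illustrative values, assuming EarlyStoppingConfig as defined above.
    early_stopping = EarlyStoppingConfig(
        enable=True,
        patience_steps=20_000,  # stop after 20k steps without improvement
        min_delta=0.01,         # require >1% relative improvement over the best value
        monitor="val_loss",
        higher_is_better=False,
    )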
+ """ + if self.best_value is None: + self.best_value = value + self.best_step = step + return False + + # Check if this is an improvement + if self.higher_is_better: + improved = value > self.best_value * (1 + self.min_delta) + else: + improved = value < self.best_value * (1 - self.min_delta) + + if improved: + self.best_value = value + self.best_step = step + self.steps_without_improvement = 0 + else: + self.steps_without_improvement = step - self.best_step + + return self.steps_without_improvement >= self.patience_steps + + def get_status(self) -> dict: + """Get current early stopping status for logging.""" + return { + "best_value": self.best_value, + "best_step": self.best_step, + "steps_without_improvement": self.steps_without_improvement, + } + + +def cleanup_old_checkpoints(output_dir, keep_last_n: int, current_step: int) -> None: + """Remove old checkpoints, keeping only the last N plus the 'last' symlink target.""" + if keep_last_n <= 0: + return + + checkpoints_dir = output_dir / "checkpoints" + if not checkpoints_dir.exists(): + return + + # Find all step checkpoint directories (format: NNNNNN or step_NNNNNN) + checkpoint_dirs = [] + for d in checkpoints_dir.iterdir(): + if d.is_dir() and d.name != "last": + # Try to extract step number from directory name + try: + # Handle formats like "005000" or "step_005000" + name = d.name.replace("step_", "") + step_num = int(name) + checkpoint_dirs.append((step_num, d)) + except ValueError: + continue + + # Sort by step number (oldest first) + checkpoint_dirs.sort(key=lambda x: x[0]) + + # Keep the last N checkpoints + if len(checkpoint_dirs) > keep_last_n: + to_remove = checkpoint_dirs[:-keep_last_n] + last_symlink = checkpoints_dir / "last" + last_resolved = last_symlink.resolve() if last_symlink.exists() else None + + for _step_num, checkpoint_dir in to_remove: + # Never remove the checkpoint pointed to by the 'last' symlink + if last_resolved and checkpoint_dir.resolve() == last_resolved: + logging.info(f"Skipping removal of current 'last' checkpoint: {checkpoint_dir}") + continue + + logging.info(f"Removing old checkpoint: {checkpoint_dir}") + shutil.rmtree(checkpoint_dir) + + def update_policy( train_metrics: MetricsTracker, policy: PreTrainedPolicy, @@ -148,6 +240,80 @@ def update_policy( return train_metrics, output_dict +def validate_dataset_loss( + policy: PreTrainedPolicy, + dataloader: torch.utils.data.DataLoader, + preprocessor, + accelerator: Accelerator, +) -> dict[str, float]: + """Validate using actual inference (select_action) for model-agnostic L1/L2 loss. + + This mimics real inference by calling select_action and comparing predicted + actions with ground truth, providing metrics that are comparable across + different policy architectures. + + Note: The dataloader should have batch_size=1 because select_action in many + policies (e.g. ACT, SmolVLA) uses internal temporal queues that are not + compatible with batching during inference. 
+ """ + from lerobot.utils.constants import ACTION + + policy.eval() + + l1_acc = 0.0 + l2_acc = 0.0 + count = 0 + + # Unwrap policy from accelerator wrapper to access reset/select_action methods + unwrapped_policy = accelerator.unwrap_model(policy) + + with torch.no_grad(), accelerator.autocast(): + for batch in dataloader: + batch = preprocessor(batch) + + # Extract ground truth actions before select_action potentially pops them + gt_actions = batch[ACTION].clone() + + # Reset policy to clear any internal queues/state + unwrapped_policy.reset() + + # Run actual inference - this is the model-agnostic path + # select_action returns shape (batch_size, action_dim) + pred_actions = unwrapped_policy.select_action(batch) + + # Handle different action shapes: + # - gt_actions may be (B, action_dim) or (B, horizon, action_dim) + # - pred_actions is typically (B, action_dim) for single step + # Ground truth has temporal dimension, compare with first timestep + gt_first = gt_actions[:, 0, :] if gt_actions.dim() == 3 else gt_actions + + # Ensure shapes match for comparison + if pred_actions.shape != gt_first.shape: + # Truncate to minimum size if needed + min_dim = min(pred_actions.shape[-1], gt_first.shape[-1]) + pred_actions = pred_actions[..., :min_dim] + gt_first = gt_first[..., :min_dim] + + # Compute L1 and L2 loss + l1_loss = torch.nn.functional.l1_loss(pred_actions, gt_first) + l2_loss = torch.nn.functional.mse_loss(pred_actions, gt_first) + + l1_acc += l1_loss.item() + l2_acc += l2_loss.item() + count += 1 + + policy.train() + + if count == 0: + return {} + + return { + "loss": l1_acc / count, # Use L1 as primary loss for early stopping compatibility + "l1_loss": l1_acc / count, + "l2_loss": l2_acc / count, + } + + @parser.wrap() def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): """ @@ -221,6 +387,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): if not is_main_process: dataset = make_dataset(cfg) + # Create a separate validation dataset without augmentations if we have validation episodes + val_dataset = None + if cfg.validation_fraction > 0: + # Temporarily disable image transforms for validation dataset + original_transforms_enable = cfg.dataset.image_transforms.enable + cfg.dataset.image_transforms.enable = False + val_dataset = make_dataset(cfg) + # Restore original setting + cfg.dataset.image_transforms.enable = original_transforms_enable + # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. @@ -337,30 +513,85 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training - if hasattr(cfg.policy, "drop_n_last_frames"): - shuffle = False - sampler = EpisodeAwareSampler( + # Split episodes into train/val if validation_fraction > 0 + val_episode_indices = [] + train_episode_indices = None # None means use all episodes (default behavior) + + if cfg.validation_fraction > 0: + num_episodes = dataset.num_episodes + num_val_episodes = int(num_episodes * cfg.validation_fraction) + if num_val_episodes == 0: + logging.warning( + "Validation fraction is too small to yield any episodes. Using 0 validation episodes." 
 def update_policy(
     train_metrics: MetricsTracker,
     policy: PreTrainedPolicy,
@@ -148,6 +240,80 @@ def update_policy(
     return train_metrics, output_dict
 
 
+def validate_dataset_loss(
+    policy: PreTrainedPolicy,
+    dataloader: torch.utils.data.DataLoader,
+    preprocessor,
+    accelerator: Accelerator,
+) -> dict[str, float]:
+    """Validate using actual inference (select_action) for model-agnostic L1/L2 loss.
+
+    This mimics real inference by calling select_action and comparing predicted
+    actions with ground truth, providing metrics that are comparable across
+    different policy architectures.
+
+    Note: The dataloader should have batch_size=1 because select_action in many
+    policies (e.g. ACT, SmolVLA) uses internal temporal queues that are not
+    compatible with batching during inference.
+    """
+    from lerobot.utils.constants import ACTION
+
+    policy.eval()
+
+    l1_acc = 0.0
+    l2_acc = 0.0
+    count = 0
+
+    # Unwrap policy from the accelerator wrapper to access reset/select_action methods
+    unwrapped_policy = accelerator.unwrap_model(policy)
+
+    with torch.no_grad(), accelerator.autocast():
+        for batch in dataloader:
+            batch = preprocessor(batch)
+
+            # Extract ground truth actions before select_action potentially pops them
+            gt_actions = batch[ACTION].clone()
+
+            # Reset policy to clear any internal queues/state
+            unwrapped_policy.reset()
+
+            # Run actual inference - this is the model-agnostic path.
+            # select_action returns shape (batch_size, action_dim)
+            pred_actions = unwrapped_policy.select_action(batch)
+
+            # Handle different action shapes:
+            # - gt_actions may be (B, action_dim) or (B, horizon, action_dim)
+            # - pred_actions is typically (B, action_dim) for a single step
+            # If the ground truth has a temporal dimension, compare against the first timestep
+            gt_first = gt_actions[:, 0, :] if gt_actions.dim() == 3 else gt_actions
+
+            # Ensure shapes match for comparison
+            if pred_actions.shape != gt_first.shape:
+                # Truncate to the minimum size if needed
+                min_dim = min(pred_actions.shape[-1], gt_first.shape[-1])
+                pred_actions = pred_actions[..., :min_dim]
+                gt_first = gt_first[..., :min_dim]
+
+            # Compute L1 and L2 losses
+            l1_loss = torch.nn.functional.l1_loss(pred_actions, gt_first)
+            l2_loss = torch.nn.functional.mse_loss(pred_actions, gt_first)
+
+            l1_acc += l1_loss.item()
+            l2_acc += l2_loss.item()
+            count += 1
+
+    policy.train()
+
+    if count == 0:
+        return {}
+
+    return {
+        "loss": l1_acc / count,  # Use L1 as the primary loss for early stopping compatibility
+        "l1_loss": l1_acc / count,
+        "l2_loss": l2_acc / count,
+    }
+
+
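A toy sketch of the shape alignment above, runnable on its own (the shapes assume an ACT-style policy trained on 100-step action chunks with a 6-dim action space; all names are local to this sketch):

    import torch

    gt_actions = torch.randn(1, 100, 6)  # (B=1, horizon, action_dim): chunked ground truth from the dataset
    pred_actions = torch.randn(1, 6)     # select_action yields one step: (B=1, action_dim)
    gt_first = gt_actions[:, 0, :] if gt_actions.dim() == 3 else gt_actions
    l1 = torch.nn.functional.l1_loss(pred_actions, gt_first)  # scalar, averaged over action dims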
 @parser.wrap()
 def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     """
@@ -221,6 +387,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     if not is_main_process:
         dataset = make_dataset(cfg)
 
+    # Create a separate validation dataset without augmentations if we have validation episodes
+    val_dataset = None
+    if cfg.validation_fraction > 0:
+        # Temporarily disable image transforms for the validation dataset
+        original_transforms_enable = cfg.dataset.image_transforms.enable
+        cfg.dataset.image_transforms.enable = False
+        val_dataset = make_dataset(cfg)
+        # Restore the original setting
+        cfg.dataset.image_transforms.enable = original_transforms_enable
+
     # Create environment used for evaluating checkpoints during training on simulation data.
     # On real-world data, no need to create an environment as evaluations are done outside train.py,
     # using the eval.py instead, with gym_dora environment and dora-rs.
@@ -337,30 +513,85 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
     # create dataloader for offline training
-    if hasattr(cfg.policy, "drop_n_last_frames"):
-        shuffle = False
-        sampler = EpisodeAwareSampler(
+    # Split episodes into train/val if validation_fraction > 0
+    val_episode_indices = []
+    train_episode_indices = None  # None means use all episodes (default behavior)
+
+    if cfg.validation_fraction > 0:
+        num_episodes = dataset.num_episodes
+        num_val_episodes = int(num_episodes * cfg.validation_fraction)
+        if num_val_episodes == 0:
+            logging.warning(
+                "Validation fraction is too small to yield any episodes. Using 0 validation episodes."
+            )
+        else:
+            all_indices = list(range(num_episodes))
+            if cfg.early_stopping.shuffle_episodes:
+                import random
+
+                random.Random(cfg.seed).shuffle(all_indices)
+
+            val_episode_indices = all_indices[:num_val_episodes]
+            train_episode_indices = all_indices[num_val_episodes:]
+
+            if is_main_process:
+                logging.info(
+                    f"Training on {len(train_episode_indices)} episodes, "
+                    f"validating on {len(val_episode_indices)} episodes "
+                    f"(shuffled={cfg.early_stopping.shuffle_episodes})"
+                )
+
+    # Determine if we need to use EpisodeAwareSampler.
+    # Use a sampler when a validation split is active OR the policy requires drop_n_last_frames.
+    use_sampler = train_episode_indices is not None or hasattr(cfg.policy, "drop_n_last_frames")
+
+    if use_sampler:
+        # When using a validation split, use train_episode_indices.
+        # Otherwise, use dataset.episodes (respects dataset config) or None (all episodes).
+        episode_indices = train_episode_indices if train_episode_indices is not None else dataset.episodes
+        train_sampler = EpisodeAwareSampler(
             dataset.meta.episodes["dataset_from_index"],
             dataset.meta.episodes["dataset_to_index"],
-            episode_indices_to_use=dataset.episodes,
-            drop_n_last_frames=cfg.policy.drop_n_last_frames,
+            episode_indices_to_use=episode_indices,
+            drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
             shuffle=True,
         )
+        shuffle = False
     else:
+        train_sampler = None
         shuffle = True
-        sampler = None
 
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=cfg.num_workers,
         batch_size=cfg.batch_size,
         shuffle=shuffle and not cfg.dataset.streaming,
-        sampler=sampler,
+        sampler=train_sampler,
         pin_memory=device.type == "cuda",
         drop_last=False,
         prefetch_factor=2 if cfg.num_workers > 0 else None,
     )
 
+    # Val dataloader (uses the separate dataset without augmentations)
+    val_dataloader = None
+    if val_episode_indices and val_dataset is not None:
+        val_sampler = EpisodeAwareSampler(
+            val_dataset.meta.episodes["dataset_from_index"],
+            val_dataset.meta.episodes["dataset_to_index"],
+            episode_indices_to_use=val_episode_indices,
+            drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
+            shuffle=False,
+        )
+        val_dataloader = torch.utils.data.DataLoader(
+            val_dataset,
+            num_workers=cfg.num_workers,
+            batch_size=1,  # Must be 1 for select_action inference compatibility
+            sampler=val_sampler,
+            pin_memory=device.type == "cuda",
+            drop_last=False,
+            prefetch_factor=2 if cfg.num_workers > 0 else None,
+        )
+
     # Prepare everything with accelerator
     accelerator.wait_for_everyone()
     policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
@@ -389,11 +620,26 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         accelerator=accelerator,
     )
 
+    # Initialize the early stopping tracker
+    early_stopping_tracker = None
+    if cfg.early_stopping.enable:
+        early_stopping_tracker = EarlyStoppingTracker(
+            patience_steps=cfg.early_stopping.patience_steps,
+            min_delta=cfg.early_stopping.min_delta,
+            higher_is_better=cfg.early_stopping.higher_is_better,
+        )
+        if is_main_process:
+            logging.info(
+                f"Early stopping enabled: patience={cfg.early_stopping.patience_steps}, "
+                f"monitor={cfg.early_stopping.monitor}, higher_is_better={cfg.early_stopping.higher_is_better}"
+            )
+
     if is_main_process:
         logging.info(
             f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
         )
 
+    early_stop_triggered = False
     for _ in range(step, cfg.steps):
         start_time = time.perf_counter()
         batch = next(dl_iter)
@@ -458,10 +704,26 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
 
         accelerator.wait_for_everyone()
 
-        if cfg.env and is_eval_step:
-            if is_main_process:
-                step_id = get_step_identifier(step, cfg.steps)
-                logging.info(f"Eval policy at step {step}")
+        if is_eval_step:
+            eval_metric_value = None
+
+            if val_dataloader:
+                if is_main_process:
+                    logging.info(f"Validating on dataset at step {step}")
+                val_metrics = validate_dataset_loss(policy, val_dataloader, preprocessor, accelerator)
+                if is_main_process:
+                    logging.info(f"Validation metrics: {val_metrics}")
+                    if wandb_logger:
+                        wandb_logger.log_dict({f"val/{k}": v for k, v in val_metrics.items()}, step)
+
+                # Track validation loss for early stopping
+                if cfg.early_stopping.monitor == "val_loss" and "loss" in val_metrics:
+                    eval_metric_value = val_metrics["loss"]
+
+            if cfg.env:
+                if is_main_process:
+                    step_id = get_step_identifier(step, cfg.steps)
+                    logging.info(f"Eval policy at step {step}")
                 with torch.no_grad(), accelerator.autocast():
                     eval_info = eval_policy_all(
                         envs=eval_env,  # dict[suite][task_id] -> vec_env
@@ -483,6 +745,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
                 for suite, suite_info in eval_info.items():
                     logging.info("Suite %s aggregated: %s", suite, suite_info)
 
+                # Track success rate for early stopping
+                if cfg.early_stopping.monitor == "eval_success":
+                    eval_metric_value = aggregated.get("pc_success", 0.0)
+
                 # meters/tracker
                 eval_metrics = {
                     "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
@@ -505,13 +771,42 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
                     wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                     wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
 
+            # Early stopping check
+            if early_stopping_tracker is not None and eval_metric_value is not None:
+                should_stop = early_stopping_tracker.update(eval_metric_value, step)
+                status = early_stopping_tracker.get_status()
+                if is_main_process:
+                    logging.info(
+                        f"Early stopping: value={eval_metric_value:.4f}, "
+                        f"best={status['best_value']:.4f} @ step {status['best_step']}, "
+                        f"no improvement for {status['steps_without_improvement']} steps"
+                    )
+                if should_stop:
+                    if is_main_process:
+                        logging.info(
+                            f"Early stopping triggered at step {step}. "
+                            f"Best value: {status['best_value']:.4f} at step {status['best_step']}"
+                        )
+                    early_stop_triggered = True
+
+        accelerator.wait_for_everyone()
+
+        # Clean up old checkpoints if configured
+        if cfg.save_checkpoint and is_saving_step and cfg.keep_last_n_checkpoints > 0 and is_main_process:
+            cleanup_old_checkpoints(cfg.output_dir, cfg.keep_last_n_checkpoints, step)
+
+        # Break out of the training loop if early stopping was triggered
+        if early_stop_triggered:
+            break
+
     if eval_env:
         close_envs(eval_env)
 
     if is_main_process:
-        logging.info("End of training")
+        if early_stop_triggered:
+            logging.info(f"Early stopping: training ended at step {step}")
+        else:
+            logging.info("End of training")
 
     if cfg.policy.push_to_hub:
         unwrapped_policy = accelerator.unwrap_model(policy)
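For completeness, the episode split performed in the training script reduces to the following standalone sketch (values are illustrative; seeding with `cfg.seed` keeps the split identical across runs and resumes):

    import random

    num_episodes, validation_fraction, seed = 50, 0.1, 1000
    num_val = int(num_episodes * validation_fraction)  # 5
    indices = list(range(num_episodes))
    random.Random(seed).shuffle(indices)               # deterministic, seeded shuffle
    val_episodes, train_episodes = indices[:num_val], indices[num_val:]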