diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index e835b1de69b..354df086bee 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -245,7 +245,13 @@ def handle_merge(cfg: EditDatasetConfig) -> None: raise ValueError("repo_id must be specified as the output repository for merged dataset") logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge") - datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids] + # When root is provided, we need to construct the full path for each dataset + # by appending the repo_id to the root. When root is None, LeRobotDataset + # automatically uses HF_LEROBOT_HOME / repo_id. + datasets = [ + LeRobotDataset(repo_id, root=Path(cfg.root) / repo_id if cfg.root else None) + for repo_id in cfg.operation.repo_ids + ] output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 3a4516fc82d..5b0750ca4c2 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -29,7 +29,7 @@ remove_feature, split_dataset, ) -from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos +from lerobot.scripts.lerobot_edit_dataset import EditDatasetConfig, MergeConfig, convert_dataset_to_videos, handle_merge @pytest.fixture @@ -850,6 +850,221 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f assert "std" in merged.meta.stats[feature] +def test_handle_merge_with_custom_root(tmp_path, empty_lerobot_dataset_factory): + """Test handle_merge() with custom root path (bug fix for custom root paths).""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": 
(224, 224, 3), "names": None}, + } + + # Create custom root directory + custom_root = tmp_path / "custom_datasets" + custom_root.mkdir() + + # Create two datasets in custom root + dataset1 = empty_lerobot_dataset_factory( + root=custom_root / "dataset1", + features=features, + ) + for ep_idx in range(2): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset1.add_frame(frame) + dataset1.save_episode() + dataset1.finalize() + + dataset2 = empty_lerobot_dataset_factory( + root=custom_root / "dataset2", + features=features, + ) + for ep_idx in range(3): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + dataset2.finalize() + + # Test handle_merge with custom root + cfg = EditDatasetConfig( + repo_id="merged_dataset", + root=custom_root, + operation=MergeConfig( + type="merge", + repo_ids=["dataset1", "dataset2"], + ), + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(custom_root / "merged_dataset") + handle_merge(cfg) + + # Verify merged dataset was created in correct location + merged_path = custom_root / "merged_dataset" + assert merged_path.exists() + assert (merged_path / "meta" / "info.json").exists() + + # Load and verify merged dataset + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + merged = 
LeRobotDataset("merged_dataset", root=merged_path) + assert merged.meta.total_episodes == 5 # 2 + 3 + + +def test_handle_merge_without_custom_root(tmp_path, empty_lerobot_dataset_factory): + """Test handle_merge() with datasets directly under root (tmp_path stands in for HF_LEROBOT_HOME).""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + } + + # Create datasets in tmp_path (simulating HF_LEROBOT_HOME) + dataset1 = empty_lerobot_dataset_factory( + root=tmp_path / "dataset1", + features=features, + ) + for ep_idx in range(2): + for _ in range(5): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "task": f"task_{ep_idx}", + } + dataset1.add_frame(frame) + dataset1.save_episode() + dataset1.finalize() + + dataset2 = empty_lerobot_dataset_factory( + root=tmp_path / "dataset2", + features=features, + ) + for ep_idx in range(1): + for _ in range(5): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "task": f"task_{ep_idx}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + dataset2.finalize() + + # root=None would resolve to the real HF_LEROBOT_HOME, so that code path is not exercised here. + # An explicit root=tmp_path is passed so handle_merge() finds the datasets created above. + cfg = EditDatasetConfig( + repo_id="merged_dataset", + root=tmp_path, + operation=MergeConfig( + type="merge", + repo_ids=["dataset1", "dataset2"], + ), + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + handle_merge(cfg) + + # Verify merged dataset + merged_path = tmp_path / "merged_dataset" + assert 
merged_path.exists() + + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + merged = LeRobotDataset("merged_dataset", root=merged_path) + assert merged.meta.total_episodes == 3 # 2 + 1 + + +def test_handle_merge_custom_root_preserves_metadata(tmp_path, empty_lerobot_dataset_factory): + """Test that handle_merge with custom root preserves dataset metadata.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + } + + custom_root = tmp_path / "datasets" + custom_root.mkdir() + + # Create dataset with specific FPS + dataset1 = empty_lerobot_dataset_factory( + root=custom_root / "dataset1", + features=features, + fps=30, + ) + for ep_idx in range(1): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "task": f"task_{ep_idx}", + } + dataset1.add_frame(frame) + dataset1.save_episode() + dataset1.finalize() + + dataset2 = empty_lerobot_dataset_factory( + root=custom_root / "dataset2", + features=features, + fps=30, + ) + for ep_idx in range(2): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "task": f"task_{ep_idx}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + dataset2.finalize() + + # Merge with custom root + cfg = EditDatasetConfig( + repo_id="merged_dataset", + root=custom_root, + operation=MergeConfig( + type="merge", + repo_ids=["dataset1", "dataset2"], + ), + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(custom_root / "merged_dataset") + handle_merge(cfg) + + # Verify metadata preserved + from 
lerobot.datasets.lerobot_dataset import LeRobotDataset + + merged = LeRobotDataset("merged_dataset", root=custom_root / "merged_dataset") + assert merged.meta.fps == 30 + assert merged.meta.total_episodes == 3 # 1 + 2 + assert merged.meta.total_frames == 30 # 10 + 20 + # Check that user-defined features are present (in addition to default features) + assert "action" in merged.meta.features + assert "observation.state" in merged.meta.features + + def test_add_features_preserves_existing_stats(sample_dataset, tmp_path): """Test that adding a feature preserves existing stats.""" num_frames = sample_dataset.meta.total_frames