Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/lerobot/scripts/lerobot_edit_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,13 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
raise ValueError("repo_id must be specified as the output repository for merged dataset")

logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids]
# When root is provided, we need to construct the full path for each dataset
# by appending the repo_id to the root. When root is None, LeRobotDataset
# automatically uses HF_LEROBOT_HOME / repo_id.
datasets = [
LeRobotDataset(repo_id, root=Path(cfg.root) / repo_id if cfg.root else None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As defined in LeRobotDataset, datasets will be stored under root/repo_id, so we need to standardize the dataset location in this way.

related to #2316

Copy link
Author

@riochuong riochuong Jan 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean that, for merge to work correctly on a custom local folder, the only way is to use root=None and move the data to the default HF_LEROBOT_HOME (not too bad, but you need to remember to set this in the venv)?

for repo_id in cfg.operation.repo_ids
]

output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id

Expand Down
217 changes: 216 additions & 1 deletion tests/datasets/test_dataset_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
remove_feature,
split_dataset,
)
from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos
from lerobot.scripts.lerobot_edit_dataset import EditDatasetConfig, MergeConfig, convert_dataset_to_videos, handle_merge


@pytest.fixture
Expand Down Expand Up @@ -850,6 +850,221 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f
assert "std" in merged.meta.stats[feature]


def test_handle_merge_with_custom_root(tmp_path, empty_lerobot_dataset_factory):
    """Test handle_merge() with custom root path (bug fix for custom root paths)."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    # Datasets live under a user-chosen directory instead of HF_LEROBOT_HOME.
    custom_root = tmp_path / "custom_datasets"
    custom_root.mkdir()

    def record_episodes(dataset, num_episodes):
        # Fill `dataset` with `num_episodes` episodes of 10 random frames each,
        # then finalize it so it can be read back.
        for episode in range(num_episodes):
            for _ in range(10):
                dataset.add_frame(
                    {
                        "action": np.random.randn(6).astype(np.float32),
                        "observation.state": np.random.randn(4).astype(np.float32),
                        "observation.images.top": np.random.randint(
                            0, 255, size=(224, 224, 3), dtype=np.uint8
                        ),
                        "task": f"task_{episode % 2}",
                    }
                )
            dataset.save_episode()
        dataset.finalize()

    # Two source datasets with 2 and 3 episodes respectively.
    record_episodes(empty_lerobot_dataset_factory(root=custom_root / "dataset1", features=features), 2)
    record_episodes(empty_lerobot_dataset_factory(root=custom_root / "dataset2", features=features), 3)

    cfg = EditDatasetConfig(
        repo_id="merged_dataset",
        root=custom_root,
        operation=MergeConfig(type="merge", repo_ids=["dataset1", "dataset2"]),
    )

    # Hub access is stubbed out: version lookup and snapshot download never
    # leave the local filesystem.
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(custom_root / "merged_dataset")
        handle_merge(cfg)

    # The merged dataset must be created inside the custom root.
    merged_path = custom_root / "merged_dataset"
    assert merged_path.exists()
    assert (merged_path / "meta" / "info.json").exists()

    # Load it back and confirm the episode counts were combined.
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    merged = LeRobotDataset("merged_dataset", root=merged_path)
    assert merged.meta.total_episodes == 5  # 2 + 3


def test_handle_merge_without_custom_root(tmp_path, empty_lerobot_dataset_factory):
    """Test handle_merge() without custom root (uses default HF_LEROBOT_HOME behavior)."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
    }

    def record_episodes(dataset, num_episodes):
        # Populate `dataset` with `num_episodes` episodes of 5 random frames each.
        for episode in range(num_episodes):
            for _ in range(5):
                dataset.add_frame(
                    {
                        "action": np.random.randn(6).astype(np.float32),
                        "observation.state": np.random.randn(4).astype(np.float32),
                        "task": f"task_{episode}",
                    }
                )
            dataset.save_episode()
        dataset.finalize()

    # Create the source datasets directly in tmp_path (simulating HF_LEROBOT_HOME).
    record_episodes(empty_lerobot_dataset_factory(root=tmp_path / "dataset1", features=features), 2)
    record_episodes(empty_lerobot_dataset_factory(root=tmp_path / "dataset2", features=features), 1)

    # An explicit root is passed so the merge reads the datasets created above.
    cfg = EditDatasetConfig(
        repo_id="merged_dataset",
        root=tmp_path,
        operation=MergeConfig(type="merge", repo_ids=["dataset1", "dataset2"]),
    )

    # Stub out hub access so the test stays fully offline.
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
        handle_merge(cfg)

    merged_path = tmp_path / "merged_dataset"
    assert merged_path.exists()

    # Reload the result and verify the episodes from both sources were merged.
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    merged = LeRobotDataset("merged_dataset", root=merged_path)
    assert merged.meta.total_episodes == 3  # 2 + 1


def test_handle_merge_custom_root_preserves_metadata(tmp_path, empty_lerobot_dataset_factory):
    """Test that handle_merge with custom root preserves dataset metadata."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
    }

    custom_root = tmp_path / "datasets"
    custom_root.mkdir()

    def record_episodes(dataset, num_episodes):
        # Populate `dataset` with `num_episodes` episodes of 10 random frames each.
        for episode in range(num_episodes):
            for _ in range(10):
                dataset.add_frame(
                    {
                        "action": np.random.randn(6).astype(np.float32),
                        "observation.state": np.random.randn(4).astype(np.float32),
                        "task": f"task_{episode}",
                    }
                )
            dataset.save_episode()
        dataset.finalize()

    # Both source datasets share a specific FPS that must survive the merge.
    record_episodes(
        empty_lerobot_dataset_factory(root=custom_root / "dataset1", features=features, fps=30), 1
    )
    record_episodes(
        empty_lerobot_dataset_factory(root=custom_root / "dataset2", features=features, fps=30), 2
    )

    cfg = EditDatasetConfig(
        repo_id="merged_dataset",
        root=custom_root,
        operation=MergeConfig(type="merge", repo_ids=["dataset1", "dataset2"]),
    )

    # Stub out hub access so the merge runs entirely against local files.
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(custom_root / "merged_dataset")
        handle_merge(cfg)

    # Reload the merged dataset and check FPS, totals, and feature schema.
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    merged = LeRobotDataset("merged_dataset", root=custom_root / "merged_dataset")
    assert merged.meta.fps == 30
    assert merged.meta.total_episodes == 3  # 1 + 2
    assert merged.meta.total_frames == 30  # 10 + 20
    # Check that user-defined features are present (in addition to default features)
    assert "action" in merged.meta.features
    assert "observation.state" in merged.meta.features


def test_add_features_preserves_existing_stats(sample_dataset, tmp_path):
"""Test that adding a feature preserves existing stats."""
num_frames = sample_dataset.meta.total_frames
Expand Down