diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 7f481b9ca7c..58ed64420ca 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -27,7 +27,8 @@ class DatasetConfig: # "dataset_index" into the returned item. The index mapping is made according to the order in which the # datasets are provided. repo_id: str - # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. + # Root directory for a concrete local dataset tree (e.g. 'dataset/path'). If None, local datasets are + # looked up under $HF_LEROBOT_HOME/repo_id and Hub downloads use a revision-safe cache under $HF_LEROBOT_HOME/hub. root: str | None = None episodes: list[int] | None = None image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index a43ba07b472..ef2c8675f9b 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -48,7 +48,7 @@ update_chunk_file_indices, ) from lerobot.datasets.video_utils import get_video_info -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE CODEBASE_VERSION = "v3.0" @@ -77,8 +77,12 @@ def __init__( Args: repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``). - root: Local directory for the dataset. Defaults to - ``$HF_LEROBOT_HOME/{repo_id}``. + root: Local directory for the dataset. When provided, Hub downloads + are materialized directly into this directory. When omitted, + existing local datasets are still looked up under + ``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a + revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. revision: Git revision (branch, tag, or commit hash). Defaults to the current codebase version. force_cache_sync: If ``True``, re-download metadata from the Hub @@ -88,7 +92,8 @@ def __init__( """ self.repo_id = repo_id self.revision = revision if revision else CODEBASE_VERSION - self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + self._requested_root = Path(root) if root is not None else None + self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id self._pq_writer = None self.latest_episode = None self._metadata_buffer: list[dict] = [] @@ -96,17 +101,23 @@ def __init__( self._finalized = False try: - if force_cache_sync: + if force_cache_sync or ( + self._requested_root is None and self._has_legacy_hub_download_metadata(self.root) + ): raise FileNotFoundError self._load_metadata() except (FileNotFoundError, NotADirectoryError): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) - (self.root / "meta").mkdir(exist_ok=True, parents=True) self._pull_from_repo(allow_patterns="meta/") self._load_metadata() + @staticmethod + def _has_legacy_hub_download_metadata(root: Path) -> bool: + """Return True when ``root`` looks like a legacy Hub ``local_dir`` mirror.""" + return (root / ".cache" / "huggingface" / "download").exists() + def _flush_metadata_buffer(self) -> None: """Write all buffered episode metadata to parquet file.""" if not hasattr(self, "_metadata_buffer") or len(self._metadata_buffer) == 0: @@ -178,14 +189,29 @@ def _pull_from_repo( allow_patterns: list[str] | str | None = None, ignore_patterns: list[str] | str | None = None, ) -> None: + if self._requested_root is None: + self.root = Path( + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + cache_dir=HF_LEROBOT_HUB_CACHE, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + ) + return + + self._requested_root.mkdir(exist_ok=True, parents=True) snapshot_download( self.repo_id, repo_type="dataset", revision=self.revision, - local_dir=self.root, + local_dir=self._requested_root, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, ) + self.root = self._requested_root @property def url_root(self) -> str: diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index cba0c1cbabc..a9b8f73c336 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -37,7 +37,7 @@ get_safe_default_codec, resolve_vcodec, ) -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE logger = logging.getLogger(__name__) @@ -144,10 +144,11 @@ def __init__( Args: repo_id (str): This is the repo id that will be used to fetch the dataset. - root (Path | None, optional): Local directory where the dataset will be downloaded and - stored. If set, all dataset files will be stored directly under this path. If not set, the - dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the - HF_LEROBOT_HOME environment variable). + root (Path | None, optional): Local directory where the dataset will be read from or downloaded + into. If set, all dataset files are materialized directly under this path. If not set, + existing local datasets are still looked up under ``$HF_LEROBOT_HOME/{repo_id}``, but Hub + downloads use a revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. image_transforms (Callable | None, optional): You can pass standard v2 image transforms from @@ -190,7 +191,8 @@ def __init__( """ super().__init__() self.repo_id = repo_id - self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self._requested_root = Path(root) if root else None + self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.episodes = episodes @@ -201,12 +203,15 @@ def __init__( self._vcodec = resolve_vcodec(vcodec) self._encoder_threads = encoder_threads - self.root.mkdir(exist_ok=True, parents=True) + if self._requested_root is not None: + self.root.mkdir(exist_ok=True, parents=True) # Load metadata self.meta = LeRobotDatasetMetadata( - self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync ) + self.root = self.meta.root + self.revision = self.meta.revision # Create reader (hf_dataset loaded below) self.reader = DatasetReader( @@ -556,14 +561,32 @@ def _download(self, download_videos: bool = True) -> None: if self.episodes is not None: # Reader is guaranteed to exist here (created in __init__ before _download) files = self.reader.get_episodes_file_paths() - snapshot_download( - self.repo_id, - repo_type="dataset", - revision=self.revision, - local_dir=self.root, - allow_patterns=files, - ignore_patterns=ignore_patterns, - ) + + if self._requested_root is None: + self.root = Path( + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + cache_dir=HF_LEROBOT_HUB_CACHE, + allow_patterns=files, + ignore_patterns=ignore_patterns, + ) + ) + else: + self._requested_root.mkdir(exist_ok=True, parents=True) + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self._requested_root, + allow_patterns=files, + ignore_patterns=ignore_patterns, + ) + self.root = self._requested_root + + self.meta.root = self.root + self.reader._root = self.root # ── Class constructors ──────────────────────────────────────────── @@ -635,6 +658,7 @@ def create( metadata_buffer_size=metadata_buffer_size, ) obj.repo_id = obj.meta.repo_id + obj._requested_root = obj.meta.root obj.root = obj.meta.root obj.revision = None obj.tolerance_s = tolerance_s @@ -719,7 +743,8 @@ def resume( vcodec = resolve_vcodec(vcodec) obj = cls.__new__(cls) obj.repo_id = repo_id - obj.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + obj._requested_root = Path(root) if root else HF_LEROBOT_HOME / repo_id + obj.root = obj._requested_root obj.root.mkdir(exist_ok=True, parents=True) obj.revision = revision if revision else CODEBASE_VERSION obj.tolerance_s = tolerance_s @@ -733,7 +758,7 @@ def resume( # Load metadata obj.meta = LeRobotDatasetMetadata( - obj.repo_id, obj.root, obj.revision, force_cache_sync=force_cache_sync + obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync ) # Reader is lazily created on first access (write-only mode) diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 62e00558a69..1767cc79d13 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -255,7 +255,9 @@ def __init__( Args: repo_id (str): This is the repo id that will be used to fetch the dataset. - root (Path | None, optional): Local directory to use for downloading/writing files. + root (Path | None, optional): Local directory to use for local datasets. When omitted, Hub + metadata is resolved through a revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. image_transforms (Callable | None, optional): Transform to apply to image data. @@ -271,7 +273,8 @@ def __init__( """ super().__init__() self.repo_id = repo_id - self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self._requested_root = Path(root) if root else None + self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id self.streaming_from_local = root is not None self.image_transforms = image_transforms @@ -288,12 +291,15 @@ def __init__( # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None - self.root.mkdir(exist_ok=True, parents=True) + if self._requested_root is not None: + self.root.mkdir(exist_ok=True, parents=True) # Load metadata self.meta = LeRobotDatasetMetadata( - self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync ) + self.root = self.meta.root + self.revision = self.meta.revision # Check version check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index ecd54844c98..09957513a71 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -65,6 +65,7 @@ # cache dir default_cache_path = Path(HF_HOME) / "lerobot" HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser() +HF_LEROBOT_HUB_CACHE = HF_LEROBOT_HOME / "hub" # calibration dir default_calibration_path = HF_LEROBOT_HOME / "calibration" diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py index d7ce54a1594..38a04e3a3b7 100644 --- a/tests/datasets/test_lerobot_dataset.py +++ b/tests/datasets/test_lerobot_dataset.py @@ -19,9 +19,15 @@ property delegation, and the full create-record-finalize-read lifecycle. """ +from pathlib import Path +from unittest.mock import Mock + import pytest import torch +import lerobot.datasets.dataset_metadata as dataset_metadata_module +import lerobot.datasets.lerobot_dataset as lerobot_dataset_module +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.datasets.dataset_reader import DatasetReader from lerobot.datasets.dataset_writer import DatasetWriter from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -30,12 +36,69 @@ SIMPLE_FEATURES = { "state": {"dtype": "float32", "shape": (2,), "names": None}, } +SNAPSHOT_MAIN_FEATURES = { + **SIMPLE_FEATURES, + "test": {"dtype": "float32", "shape": (2,), "names": None}, +} def _make_frame(task: str = "Dummy task") -> dict: return {"task": task, "state": torch.randn(2)} +def _set_default_cache_root(monkeypatch: pytest.MonkeyPatch, cache_root: Path) -> None: + monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HOME", cache_root) + monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub") + monkeypatch.setattr(lerobot_dataset_module, "HF_LEROBOT_HOME", cache_root) + + +def _write_dataset_tree( + root: Path, + *, + motor_features: dict[str, dict], + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, +) -> None: + root.mkdir(parents=True, exist_ok=True) + info = info_factory( + total_episodes=1, + total_frames=3, + total_tasks=1, + use_videos=False, + motor_features=motor_features, + camera_features={}, + ) + tasks = tasks_factory(total_tasks=1) + episodes = episodes_factory( + features=info["features"], + fps=info["fps"], + total_episodes=1, + total_frames=3, + tasks=tasks, + ) + stats = stats_factory(features=info["features"]) + hf_dataset = hf_dataset_factory( + features=info["features"], + tasks=tasks, + episodes=episodes, + fps=info["fps"], + ) + + create_info(root, info) + create_stats(root, stats) + create_tasks(root, tasks) + create_episodes(root, episodes) + create_hf_dataset(root, hf_dataset) + + # ── Read-only mode (via __init__) ──────────────────────────────────── @@ -75,6 +138,190 @@ def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory): assert len(dataset) == dataset.num_frames +def test_metadata_without_root_uses_hub_cache_snapshot_download( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """Metadata refresh uses the dedicated Hub cache instead of a shared local_dir mirror.""" + repo_id = DUMMY_REPO_ID + cache_root = tmp_path / "lerobot_cache" + snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + _write_dataset_tree( + snapshot_root, + motor_features=SNAPSHOT_MAIN_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + snapshot_download = Mock(return_value=str(snapshot_root)) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download) + + meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main", force_cache_sync=True) + + assert meta.root == snapshot_root + assert snapshot_download.call_count == 1 + assert snapshot_download.call_args.args == (repo_id,) + assert snapshot_download.call_args.kwargs == { + "repo_type": "dataset", + "revision": "main", + "cache_dir": cache_root / "hub", + "allow_patterns": "meta/", + "ignore_patterns": None, + } + + +def test_without_root_reads_different_revisions_from_distinct_snapshot_roots( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """Different revisions resolve to different on-disk snapshot roots.""" + repo_id = DUMMY_REPO_ID + old_revision = "b59010db93eb6cc3cf06ef2f7cae1bbe62b726d9" + cache_root = tmp_path / "lerobot_cache" + main_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + old_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-old" + + _write_dataset_tree( + main_root, + motor_features=SNAPSHOT_MAIN_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + _write_dataset_tree( + old_root, + motor_features=SIMPLE_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + snapshot_roots = { + "main": main_root, + old_revision: old_root, + } + snapshot_download = Mock(side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]])) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download) + + main_dataset = LeRobotDataset(repo_id=repo_id, revision="main", download_videos=False) + old_dataset = LeRobotDataset(repo_id=repo_id, revision=old_revision, download_videos=False) + + assert main_dataset.root == main_root + assert old_dataset.root == old_root + assert "test" in main_dataset.hf_dataset.column_names + assert "test" not in old_dataset.hf_dataset.column_names + assert snapshot_download.call_count == 2 + for download_call in snapshot_download.call_args_list: + assert download_call.kwargs["cache_dir"] == cache_root / "hub" + assert download_call.kwargs["allow_patterns"] == "meta/" + assert "local_dir" not in download_call.kwargs + + +def test_metadata_without_root_ignores_legacy_local_dir_cache( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """Legacy local-dir mirrors are bypassed in favor of revision-safe snapshots.""" + repo_id = DUMMY_REPO_ID + cache_root = tmp_path / "lerobot_cache" + legacy_root = cache_root / repo_id + snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + + _write_dataset_tree( + legacy_root, + motor_features=SIMPLE_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + (legacy_root / ".cache" / "huggingface" / "download").mkdir(parents=True, exist_ok=True) + _write_dataset_tree( + snapshot_root, + motor_features=SNAPSHOT_MAIN_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + snapshot_download = Mock(return_value=str(snapshot_root)) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download) + + meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main") + + assert meta.root == snapshot_root + assert "test" in meta.features + assert snapshot_download.call_count == 1 + + # ── Write-only mode (via create()) ──────────────────────────────────