diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py
index 7f481b9ca7c..58ed64420ca 100644
--- a/src/lerobot/configs/default.py
+++ b/src/lerobot/configs/default.py
@@ -27,7 +27,8 @@ class DatasetConfig:
     # "dataset_index" into the returned item. The index mapping is made according to the order in which the
     # datasets are provided.
     repo_id: str
-    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
+    # Root directory for a concrete local dataset tree (e.g. 'dataset/path'). If None, local datasets are
+    # looked up under $HF_LEROBOT_HOME/repo_id and Hub downloads use a revision-safe cache under $HF_LEROBOT_HOME/hub.
     root: str | None = None
     episodes: list[int] | None = None
     image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py
index a43ba07b472..65dbc9c4a18 100644
--- a/src/lerobot/datasets/dataset_metadata.py
+++ b/src/lerobot/datasets/dataset_metadata.py
@@ -44,11 +44,12 @@
     check_version_compatibility,
     flatten_dict,
     get_safe_version,
+    has_legacy_hub_download_metadata,
     is_valid_version,
     update_chunk_file_indices,
 )
 from lerobot.datasets.video_utils import get_video_info
-from lerobot.utils.constants import HF_LEROBOT_HOME
+from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE

 CODEBASE_VERSION = "v3.0"

@@ -77,8 +78,12 @@ def __init__(

         Args:
             repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``).
-            root: Local directory for the dataset. Defaults to
-                ``$HF_LEROBOT_HOME/{repo_id}``.
+            root: Local directory for the dataset. When provided, Hub downloads
+                are materialized directly into this directory. When omitted,
+                existing local datasets are still looked up under
+                ``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a
+                revision-safe snapshot cache under
+                ``$HF_LEROBOT_HOME/hub``.
             revision: Git revision (branch, tag, or commit hash). Defaults to
                 the current codebase version.
             force_cache_sync: If ``True``, re-download metadata from the Hub
@@ -88,7 +93,8 @@
         """
         self.repo_id = repo_id
         self.revision = revision if revision else CODEBASE_VERSION
-        self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
+        self._requested_root = Path(root) if root is not None else None
+        self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
         self._pq_writer = None
         self.latest_episode = None
         self._metadata_buffer: list[dict] = []
@@ -96,14 +102,15 @@
         self._finalized = False

         try:
-            if force_cache_sync:
+            if force_cache_sync or (
+                self._requested_root is None and has_legacy_hub_download_metadata(self.root)
+            ):
                 raise FileNotFoundError
             self._load_metadata()
         except (FileNotFoundError, NotADirectoryError):
             if is_valid_version(self.revision):
                 self.revision = get_safe_version(self.repo_id, self.revision)

-            (self.root / "meta").mkdir(exist_ok=True, parents=True)
             self._pull_from_repo(allow_patterns="meta/")
             self._load_metadata()

@@ -178,14 +185,29 @@ def _pull_from_repo(
         allow_patterns: list[str] | str | None = None,
         ignore_patterns: list[str] | str | None = None,
     ) -> None:
+        if self._requested_root is None:
+            self.root = Path(
+                snapshot_download(
+                    self.repo_id,
+                    repo_type="dataset",
+                    revision=self.revision,
+                    cache_dir=HF_LEROBOT_HUB_CACHE,
+                    allow_patterns=allow_patterns,
+                    ignore_patterns=ignore_patterns,
+                )
+            )
+            return
+
+        self._requested_root.mkdir(exist_ok=True, parents=True)
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
             revision=self.revision,
-            local_dir=self.root,
+            local_dir=self._requested_root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
         )
+        self.root = self._requested_root

     @property
     def url_root(self) -> str:
@@ -593,7 +615,8 @@ def create(
         """
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
-        obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
+        obj._requested_root = Path(root) if root is not None else None
+        obj.root = obj._requested_root if obj._requested_root is not None else HF_LEROBOT_HOME / repo_id

         obj.root.mkdir(parents=True, exist_ok=False)
diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py
index 0233a3cf6b5..3720a50847a 100644
--- a/src/lerobot/datasets/dataset_reader.py
+++ b/src/lerobot/datasets/dataset_reader.py
@@ -68,7 +68,7 @@ def __init__(
                 visual features.
""" self._meta = meta - self._root = root + self.root = root self.episodes = episodes self._tolerance_s = tolerance_s self._video_backend = video_backend @@ -125,7 +125,7 @@ def num_episodes(self) -> int: def _load_hf_dataset(self) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" features = get_hf_features_from_features(self._meta.features) - hf_dataset = load_nested_dataset(self._root / "data", features=features, episodes=self.episodes) + hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes) hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset @@ -150,7 +150,7 @@ def _check_cached_episodes_sufficient(self) -> bool: if len(self._meta.video_keys) > 0: for ep_idx in requested_episodes: for vid_key in self._meta.video_keys: - video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key) + video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key) if not video_path.exists(): return False @@ -240,7 +240,7 @@ def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) - from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] shifted_query_ts = [from_timestamp + ts for ts in query_ts] - video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key) + video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key) frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend) item[vid_key] = frames.squeeze(0) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index cba0c1cbabc..f719222fd65 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -37,7 +37,7 @@ get_safe_default_codec, resolve_vcodec, ) -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE logger = logging.getLogger(__name__) @@ -144,10 +144,11 @@ def __init__( Args: repo_id (str): This is the repo id that will be used to fetch the dataset. - root (Path | None, optional): Local directory where the dataset will be downloaded and - stored. If set, all dataset files will be stored directly under this path. If not set, the - dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the - HF_LEROBOT_HOME environment variable). + root (Path | None, optional): Local directory where the dataset will be read from or downloaded + into. If set, all dataset files are materialized directly under this path. If not set, + existing local datasets are still looked up under ``$HF_LEROBOT_HOME/{repo_id}``, but Hub + downloads use a revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. 
             image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
@@ -190,7 +191,7 @@ def __init__(
         """
         super().__init__()
         self.repo_id = repo_id
-        self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
+        self._requested_root = Path(root) if root else None
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
@@ -201,12 +202,15 @@ def __init__(
         self._vcodec = resolve_vcodec(vcodec)
         self._encoder_threads = encoder_threads

-        self.root.mkdir(exist_ok=True, parents=True)
+        if self._requested_root is not None:
+            self._requested_root.mkdir(exist_ok=True, parents=True)

-        # Load metadata
+        # Load metadata (sets self.root once from the resolved metadata root)
         self.meta = LeRobotDatasetMetadata(
-            self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
+            self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
         )
+        self.root = self.meta.root
+        self.revision = self.meta.revision

         # Create reader (hf_dataset loaded below)
         self.reader = DatasetReader(
@@ -556,14 +560,33 @@ def _download(self, download_videos: bool = True) -> None:
         if self.episodes is not None:
             # Reader is guaranteed to exist here (created in __init__ before _download)
             files = self.reader.get_episodes_file_paths()
-        snapshot_download(
-            self.repo_id,
-            repo_type="dataset",
-            revision=self.revision,
-            local_dir=self.root,
-            allow_patterns=files,
-            ignore_patterns=ignore_patterns,
-        )
+
+        if self._requested_root is None:
+            self.meta.root = Path(
+                snapshot_download(
+                    self.repo_id,
+                    repo_type="dataset",
+                    revision=self.revision,
+                    cache_dir=HF_LEROBOT_HUB_CACHE,
+                    allow_patterns=files,
+                    ignore_patterns=ignore_patterns,
+                )
+            )
+        else:
+            self._requested_root.mkdir(exist_ok=True, parents=True)
+            snapshot_download(
+                self.repo_id,
+                repo_type="dataset",
+                revision=self.revision,
+                local_dir=self._requested_root,
+                allow_patterns=files,
+                ignore_patterns=ignore_patterns,
+            )
+            self.meta.root = self._requested_root
+
+        # Propagate resolved root from metadata (single source of truth)
+        self.root = self.meta.root
+        self.reader.root = self.meta.root

     # ── Class constructors ────────────────────────────────────────────

@@ -635,6 +658,7 @@ def create(
             metadata_buffer_size=metadata_buffer_size,
         )
         obj.repo_id = obj.meta.repo_id
+        obj._requested_root = obj.meta.root
         obj.root = obj.meta.root
         obj.revision = None
         obj.tolerance_s = tolerance_s
@@ -695,8 +719,10 @@ def resume(

         Args:
             repo_id: Repository identifier of the existing dataset.
-            root: Local directory of the dataset. Defaults to
-                ``$HF_LEROBOT_HOME/{repo_id}``.
+            root: Local directory of the dataset. Required for resume(): new
+                episodes are written in place, so resuming cannot target the
+                shared revision-safe snapshot cache that is used elsewhere
+                when ``root`` is omitted.
             tolerance_s: Timestamp synchronization tolerance in seconds.
             revision: Git revision (branch, tag, or commit hash). Defaults to
                 current codebase version tag.
@@ -716,11 +742,16 @@
         Returns:
             A :class:`LeRobotDataset` in write mode, ready to append episodes.
         """
+        if not root:
+            raise ValueError(
+                "resume() requires an explicit 'root' directory because it creates a DatasetWriter. "
+                "Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
+                "the shared cache. Please provide a local directory path."
+            )
         vcodec = resolve_vcodec(vcodec)
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
-        obj.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
-        obj.root.mkdir(exist_ok=True, parents=True)
+        obj._requested_root = Path(root)
         obj.revision = revision if revision else CODEBASE_VERSION
         obj.tolerance_s = tolerance_s
         obj.image_transforms = None
@@ -731,10 +762,14 @@
         obj._vcodec = vcodec
         obj._encoder_threads = encoder_threads

-        # Load metadata
+        if obj._requested_root is not None:
+            obj._requested_root.mkdir(exist_ok=True, parents=True)
+
+        # Load metadata from the explicit root
         obj.meta = LeRobotDatasetMetadata(
-            obj.repo_id, obj.root, obj.revision, force_cache_sync=force_cache_sync
+            obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
         )
+        obj.root = obj.meta.root

         # Reader is lazily created on first access (write-only mode)
         obj.reader = None
diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py
index 62e00558a69..1767cc79d13 100644
--- a/src/lerobot/datasets/streaming_dataset.py
+++ b/src/lerobot/datasets/streaming_dataset.py
@@ -255,7 +255,9 @@ def __init__(

         Args:
             repo_id (str): This is the repo id that will be used to fetch the dataset.
-            root (Path | None, optional): Local directory to use for downloading/writing files.
+            root (Path | None, optional): Local directory to use for local datasets. When omitted, Hub
+                metadata is resolved through a revision-safe snapshot cache under
+                ``$HF_LEROBOT_HOME/hub``.
             episodes (list[int] | None, optional): If specified, this will only load episodes specified by
                 their episode_index in this list.
             image_transforms (Callable | None, optional): Transform to apply to image data.
@@ -271,7 +273,8 @@
         """
         super().__init__()
         self.repo_id = repo_id
-        self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
+        self._requested_root = Path(root) if root else None
+        self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
         self.streaming_from_local = root is not None

         self.image_transforms = image_transforms
@@ -288,12 +291,15 @@
         # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
         self.video_decoder_cache = None

-        self.root.mkdir(exist_ok=True, parents=True)
+        if self._requested_root is not None:
+            self.root.mkdir(exist_ok=True, parents=True)

         # Load metadata
         self.meta = LeRobotDatasetMetadata(
-            self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
+            self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
         )
+        self.root = self.meta.root
+        self.revision = self.meta.revision

         # Check version
         check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py
index 2e1d360f90a..36e7934edb5 100644
--- a/src/lerobot/datasets/utils.py
+++ b/src/lerobot/datasets/utils.py
@@ -18,6 +18,7 @@
 import json
 import logging
 from collections.abc import Iterator
+from pathlib import Path
 from typing import Any

 import datasets
@@ -101,6 +102,18 @@ def __init__(self, repo_id: str, version: packaging.version.Version):
 }


+def has_legacy_hub_download_metadata(root: Path) -> bool:
+    """Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
+
+    ``snapshot_download(local_dir=...)`` stores lightweight metadata under
+    ``{local_dir}/.cache/huggingface/download/``. The presence of this
+    directory is a reliable indicator that the dataset was downloaded with
+    the old non-revision-safe ``local_dir`` mode and should be re-fetched
+    through the snapshot cache instead.
+    """
+    return (root / ".cache" / "huggingface" / "download").exists()
+
+
 def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
     if file_idx == chunks_size - 1:
         file_idx = 0
diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py
index ecd54844c98..fd10cab35b0 100644
--- a/src/lerobot/utils/constants.py
+++ b/src/lerobot/utils/constants.py
@@ -65,6 +65,10 @@
 # cache dir
 default_cache_path = Path(HF_HOME) / "lerobot"
 HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
+# LeRobot's own revision-safe Hub cache (NOT the system-wide ~/.cache/huggingface/hub/).
+# Used as the ``cache_dir`` argument to ``snapshot_download`` so that different
+# dataset revisions are stored in isolated snapshot directories.
+HF_LEROBOT_HUB_CACHE = HF_LEROBOT_HOME / "hub"

 # calibration dir
 default_calibration_path = HF_LEROBOT_HOME / "calibration"
diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py
index d7ce54a1594..a8aa47ed295 100644
--- a/tests/datasets/test_lerobot_dataset.py
+++ b/tests/datasets/test_lerobot_dataset.py
@@ -19,9 +19,15 @@
 property delegation, and the full create-record-finalize-read lifecycle.
 """

+from pathlib import Path
+from unittest.mock import Mock
+
 import pytest
 import torch

+import lerobot.datasets.dataset_metadata as dataset_metadata_module
+import lerobot.datasets.lerobot_dataset as lerobot_dataset_module
+from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
 from lerobot.datasets.dataset_reader import DatasetReader
 from lerobot.datasets.dataset_writer import DatasetWriter
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -30,12 +36,69 @@
 SIMPLE_FEATURES = {
     "state": {"dtype": "float32", "shape": (2,), "names": None},
 }
+SNAPSHOT_MAIN_FEATURES = {
+    **SIMPLE_FEATURES,
+    "test": {"dtype": "float32", "shape": (2,), "names": None},
+}


 def _make_frame(task: str = "Dummy task") -> dict:
     return {"task": task, "state": torch.randn(2)}


+def _set_default_cache_root(monkeypatch: pytest.MonkeyPatch, cache_root: Path) -> None:
+    monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HOME", cache_root)
+    monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub")
+    monkeypatch.setattr(lerobot_dataset_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub")
+
+
+def _write_dataset_tree(
+    root: Path,
+    *,
+    motor_features: dict[str, dict],
+    info_factory,
+    stats_factory,
+    tasks_factory,
+    episodes_factory,
+    hf_dataset_factory,
+    create_info,
+    create_stats,
+    create_tasks,
+    create_episodes,
+    create_hf_dataset,
+) -> None:
+    root.mkdir(parents=True, exist_ok=True)
+    info = info_factory(
+        total_episodes=1,
+        total_frames=3,
+        total_tasks=1,
+        use_videos=False,
+        motor_features=motor_features,
+        camera_features={},
+    )
+    tasks = tasks_factory(total_tasks=1)
+    episodes = episodes_factory(
+        features=info["features"],
+        fps=info["fps"],
+        total_episodes=1,
+        total_frames=3,
+        tasks=tasks,
+    )
+    stats = stats_factory(features=info["features"])
+    hf_dataset = hf_dataset_factory(
+        features=info["features"],
+        tasks=tasks,
+        episodes=episodes,
+        fps=info["fps"],
+    )
+
+    create_info(root, info)
+    create_stats(root, stats)
+    create_tasks(root, tasks)
+    create_episodes(root, episodes)
+    create_hf_dataset(root, hf_dataset)
+
+
 # ── Read-only mode (via __init__) ────────────────────────────────────


@@ -75,6 +138,261 @@ def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory):
     assert len(dataset) == dataset.num_frames


+def test_metadata_without_root_uses_hub_cache_snapshot_download(
+    tmp_path,
+    info_factory,
+    stats_factory,
+    tasks_factory,
+    episodes_factory,
+    hf_dataset_factory,
+    create_info,
+    create_stats,
+    create_tasks,
+    create_episodes,
+    create_hf_dataset,
+    monkeypatch,
+):
+    """Metadata refresh uses the dedicated Hub cache instead of a shared local_dir mirror."""
+    repo_id = DUMMY_REPO_ID
+    cache_root = tmp_path / "lerobot_cache"
+    snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
+    _write_dataset_tree(
+        snapshot_root,
+        motor_features=SNAPSHOT_MAIN_FEATURES,
+        info_factory=info_factory,
+        stats_factory=stats_factory,
+        tasks_factory=tasks_factory,
+        episodes_factory=episodes_factory,
+        hf_dataset_factory=hf_dataset_factory,
+        create_info=create_info,
+        create_stats=create_stats,
+        create_tasks=create_tasks,
+        create_episodes=create_episodes,
+        create_hf_dataset=create_hf_dataset,
+    )
+
+    _set_default_cache_root(monkeypatch, cache_root)
+    snapshot_download = Mock(return_value=str(snapshot_root))
+    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download)
+
+    meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main", force_cache_sync=True)
+
+    assert meta.root == snapshot_root
+    assert snapshot_download.call_count == 1
+    assert snapshot_download.call_args.args == (repo_id,)
+    assert snapshot_download.call_args.kwargs == {
+        "repo_type": "dataset",
+        "revision": "main",
+        "cache_dir": cache_root / "hub",
+        "allow_patterns": "meta/",
+        "ignore_patterns": None,
+    }
+
+
+def test_without_root_reads_different_revisions_from_distinct_snapshot_roots(
+    tmp_path,
+    info_factory,
+    stats_factory,
+    tasks_factory,
+    episodes_factory,
+    hf_dataset_factory,
+    create_info,
+    create_stats,
+    create_tasks,
+    create_episodes,
+    create_hf_dataset,
+    monkeypatch,
+):
+    """Different revisions resolve to different on-disk snapshot roots."""
+    repo_id = DUMMY_REPO_ID
+    old_revision = "b59010db93eb6cc3cf06ef2f7cae1bbe62b726d9"
+    cache_root = tmp_path / "lerobot_cache"
+    main_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
+    old_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-old"
+
+    _write_dataset_tree(
+        main_root,
+        motor_features=SNAPSHOT_MAIN_FEATURES,
+        info_factory=info_factory,
+        stats_factory=stats_factory,
+        tasks_factory=tasks_factory,
+        episodes_factory=episodes_factory,
+        hf_dataset_factory=hf_dataset_factory,
+        create_info=create_info,
+        create_stats=create_stats,
+        create_tasks=create_tasks,
+        create_episodes=create_episodes,
+        create_hf_dataset=create_hf_dataset,
+    )
+    _write_dataset_tree(
+        old_root,
+        motor_features=SIMPLE_FEATURES,
+        info_factory=info_factory,
+        stats_factory=stats_factory,
+        tasks_factory=tasks_factory,
+        episodes_factory=episodes_factory,
+        hf_dataset_factory=hf_dataset_factory,
+        create_info=create_info,
+        create_stats=create_stats,
+        create_tasks=create_tasks,
+        create_episodes=create_episodes,
+        create_hf_dataset=create_hf_dataset,
+    )
+
+    _set_default_cache_root(monkeypatch, cache_root)
+    snapshot_roots = {
+        "main": main_root,
+        old_revision: old_root,
+    }
+    meta_snapshot_download = Mock(
+        side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]])
+    )
+    data_snapshot_download = Mock(
+        side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]])
+    )
+    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download)
+    monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download)
+
+    main_dataset = LeRobotDataset(
+        repo_id=repo_id, revision="main", download_videos=False, force_cache_sync=True
+    )
+    old_dataset = LeRobotDataset(
+        repo_id=repo_id, revision=old_revision, download_videos=False, force_cache_sync=True
+    )
+
+    assert main_dataset.root == main_root
+    assert old_dataset.root == old_root
+    assert "test" in main_dataset.hf_dataset.column_names
+    assert "test" not in old_dataset.hf_dataset.column_names
+
+    # Metadata downloads use cache_dir, not local_dir
+    assert meta_snapshot_download.call_count == 2
+    for download_call in meta_snapshot_download.call_args_list:
+        assert download_call.kwargs["cache_dir"] == cache_root / "hub"
+        assert "local_dir" not in download_call.kwargs
+
+    # Data downloads also use cache_dir, not local_dir
+    assert data_snapshot_download.call_count == 2
+    for download_call in data_snapshot_download.call_args_list:
+        assert download_call.kwargs["cache_dir"] == cache_root / "hub"
+        assert "local_dir" not in download_call.kwargs
+
+
+def test_metadata_without_root_ignores_legacy_local_dir_cache(
+    tmp_path,
+    info_factory,
+    stats_factory,
+    tasks_factory,
+    episodes_factory,
+    hf_dataset_factory,
+    create_info,
+    create_stats,
+    create_tasks,
+    create_episodes,
+    create_hf_dataset,
+    monkeypatch,
+):
+    """Legacy local-dir mirrors are bypassed in favor of revision-safe snapshots."""
+    repo_id = DUMMY_REPO_ID
+    cache_root = tmp_path / "lerobot_cache"
+    legacy_root = cache_root / repo_id
+    snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
+
+    _write_dataset_tree(
+        legacy_root,
+        motor_features=SIMPLE_FEATURES,
+        info_factory=info_factory,
+        stats_factory=stats_factory,
+        tasks_factory=tasks_factory,
+        episodes_factory=episodes_factory,
+        hf_dataset_factory=hf_dataset_factory,
+        create_info=create_info,
+        create_stats=create_stats,
+        create_tasks=create_tasks,
+        create_episodes=create_episodes,
+        create_hf_dataset=create_hf_dataset,
+    )
+    (legacy_root / ".cache" / "huggingface" / "download").mkdir(parents=True, exist_ok=True)
+    _write_dataset_tree(
+        snapshot_root,
+        motor_features=SNAPSHOT_MAIN_FEATURES,
+        info_factory=info_factory,
+        stats_factory=stats_factory,
+        tasks_factory=tasks_factory,
+        episodes_factory=episodes_factory,
+        hf_dataset_factory=hf_dataset_factory,
+        create_info=create_info,
+        create_stats=create_stats,
+        create_tasks=create_tasks,
+        create_episodes=create_episodes,
+        create_hf_dataset=create_hf_dataset,
+    )
+
+    _set_default_cache_root(monkeypatch, cache_root)
+    snapshot_download = Mock(return_value=str(snapshot_root))
+    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download)
+
+    meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main")
+
+    assert meta.root == snapshot_root
+    assert "test" in meta.features
+    assert snapshot_download.call_count == 1
+
+
+def test_download_without_root_uses_hub_cache(
+    tmp_path,
+    info_factory,
+    stats_factory,
+    tasks_factory,
+    episodes_factory,
+    hf_dataset_factory,
+    create_info,
+    create_stats,
+    create_tasks,
+    create_episodes,
+    create_hf_dataset,
+    monkeypatch,
+):
+    """LeRobotDataset._download() uses cache_dir (not local_dir) when root is not provided."""
+    repo_id = DUMMY_REPO_ID
+    cache_root = tmp_path / "lerobot_cache"
+    snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
+
+    # Pre-populate the snapshot directory so metadata loads succeed; the data
+    # download path in _download() is still exercised below.
+    _write_dataset_tree(
+        snapshot_root,
+        motor_features=SIMPLE_FEATURES,
+        info_factory=info_factory,
+        stats_factory=stats_factory,
+        tasks_factory=tasks_factory,
+        episodes_factory=episodes_factory,
+        hf_dataset_factory=hf_dataset_factory,
+        create_info=create_info,
+        create_stats=create_stats,
+        create_tasks=create_tasks,
+        create_episodes=create_episodes,
+        create_hf_dataset=create_hf_dataset,
+    )
+
+    _set_default_cache_root(monkeypatch, cache_root)
+    meta_snapshot_download = Mock(return_value=str(snapshot_root))
+    monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download)
+
+    # Mock the data snapshot_download to return the same root (data already
+    # exists there from _write_dataset_tree).
+    data_snapshot_download = Mock(return_value=str(snapshot_root))
+    monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download)
+
+    LeRobotDataset(repo_id=repo_id, revision="main", force_cache_sync=True)
+
+    # _download() should have called snapshot_download with cache_dir
+    assert data_snapshot_download.call_count == 1
+    call_kwargs = data_snapshot_download.call_args.kwargs
+    assert call_kwargs["cache_dir"] == cache_root / "hub"
+    assert "local_dir" not in call_kwargs
+
+
 # ── Write-only mode (via create()) ──────────────────────────────────
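
A minimal usage sketch of the behavior this patch introduces, based only on the constructor arguments shown in the diff (the repo id and the "v2.1" revision below are placeholder values, not datasets this patch assumes to exist):

    from pathlib import Path

    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    # Without an explicit root, each revision resolves to its own snapshot
    # directory under $HF_LEROBOT_HOME/hub, so two revisions of the same repo
    # never overwrite each other.
    ds_main = LeRobotDataset("lerobot/aloha_sim", revision="main")
    ds_old = LeRobotDataset("lerobot/aloha_sim", revision="v2.1")
    assert ds_main.root != ds_old.root

    # With an explicit root, files are still materialized directly under that
    # path, as in the previous local_dir behavior.
    ds_local = LeRobotDataset("lerobot/aloha_sim", root=Path("data/aloha_sim"))
    assert ds_local.root == Path("data/aloha_sim")

Legacy mirrors downloaded with the old local_dir mode are detected via has_legacy_hub_download_metadata() and transparently re-fetched into the snapshot cache on the next metadata load.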