Skip to content

Commit e1e7edb

Browse files
committed
fix sanity_check_dataset_robot_compatibility
1 parent 7ba6318 commit e1e7edb

File tree

2 files changed

+11
-24
lines changed

2 files changed

+11
-24
lines changed

lerobot/common/datasets/lerobot_dataset.py

+8 -15
Original file line number | Diff line number | Diff line change
@@ -250,10 +250,6 @@ def write_video_info(self) -> None:
250250
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
251251
been encoded the same way. Also, this means it assumes the first episode exists.
252252
"""
253-
# TODO(rcadene): What should we do here?
254-
if "videos" not in self.info:
255-
self.info["videos"] = {}
256-
257253
for key in self.video_keys:
258254
if not self.features[key].get("info", None):
259255
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
@@ -278,8 +274,6 @@ def create(
278274
robot_type: str | None = None,
279275
features: dict | None = None,
280276
use_videos: bool = True,
281-
# tags: list[str] | None = None,
282-
# license_type: str | None = None,
283277
) -> "LeRobotDatasetMetadata":
284278
"""Creates metadata for a LeRobotDataset."""
285279
obj = cls.__new__(cls)
@@ -301,14 +295,11 @@ def create(
301295
"Dataset features must either come from a Robot or explicitly passed upon creation."
302296
)
303297
else:
298+
# TODO(aliberts, rcadene): implement sanity check for features
304299
features = {**features, **DEFAULT_FEATURES}
305300

306-
# TODO(rcadene): implement sanity check for features
307-
308301
obj.tasks, obj.stats, obj.episodes = {}, {}, []
309302
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
310-
# obj.tags = tags
311-
# obj.license_type = license_type
312303
if len(obj.video_keys) > 0 and not use_videos:
313304
raise ValueError()
314305
write_json(obj.info, obj.root / INFO_PATH)
@@ -439,6 +430,7 @@ def __init__(
439430

440431
# Unused attributes
441432
self.image_writer = None
433+
self.episode_buffer = None
442434

443435
self.root.mkdir(exist_ok=True, parents=True)
444436

@@ -464,9 +456,6 @@ def __init__(
464456
# Available stats implies all videos have been encoded and dataset is iterable
465457
self.consolidated = self.meta.stats is not None
466458

467-
# Create an empty buffer to extend the dataset if required
468-
self.episode_buffer = self._create_episode_buffer()
469-
470459
def push_to_hub(
471460
self,
472461
tags: list | None = None,
@@ -704,9 +693,12 @@ def add_frame(self, frame: dict) -> None:
704693
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
705694
then needs to be called.
706695
"""
707-
# TODO(rcadene): Add sanity check for the input, check it's numpy or torch,
696+
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
708697
# check the dtype and shape matches, etc.
709698

699+
if self.episode_buffer is None:
700+
self.episode_buffer = self._create_episode_buffer()
701+
710702
frame_index = self.episode_buffer["size"]
711703
timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
712704
self.episode_buffer["frame_index"].append(frame_index)
@@ -930,7 +922,8 @@ def create(
930922
obj.tolerance_s = tolerance_s
931923
obj.image_writer = None
932924

933-
obj.start_image_writer(image_writer_processes, image_writer_threads)
925+
if image_writer_processes or image_writer_threads:
926+
obj.start_image_writer(image_writer_processes, image_writer_threads)
934927

935928
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
936929
obj.episode_buffer = obj._create_episode_buffer()

lerobot/common/robot_devices/control_utils.py

+3 -9
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
from lerobot.common.datasets.image_writer import safe_stop_image_writer
1919
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
20+
from lerobot.common.datasets.utils import get_features_from_robot
2021
from lerobot.common.policies.factory import make_policy
2122
from lerobot.common.robot_devices.robots.utils import Robot
2223
from lerobot.common.robot_devices.utils import busy_wait
@@ -333,17 +334,10 @@ def sanity_check_dataset_name(repo_id, policy):
333334

334335

335336
def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos):
336-
# TODO(rcadene): fix that before merging
337-
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos) # noqa
338-
339337
fields = [
340-
("robot_type", dataset.meta.info["robot_type"], robot_type),
338+
("robot_type", dataset.meta.info["robot_type"], robot.robot_type),
341339
("fps", dataset.meta.info["fps"], fps),
342-
("keys", dataset.meta.info["keys"], keys),
343-
("image_keys", dataset.meta.info["image_keys"], image_keys),
344-
("video_keys", dataset.meta.info["video_keys"], video_keys),
345-
("shapes", dataset.meta.info["shapes"], shapes),
346-
("names", dataset.meta.info["names"], names),
340+
("features", dataset.features, get_features_from_robot(robot, use_videos)),
347341
]
348342

349343
mismatches = []

0 commit comments

Comments
 (0)