Skip to content

Commit da67242

Browse files
committed
Improvements
1 parent 7590181 commit da67242

File tree

4 files changed

+33
-7
lines changed

4 files changed

+33
-7
lines changed

lerobot/common/datasets/lerobot_dataset.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,10 @@ def write_video_info(self) -> None:
243243
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
244244
been encoded the same way. Also, this means it assumes the first episode exists.
245245
"""
246+
# TODO(rcadene): decide how to handle info that has no "videos" entry; for now, create an empty one.
247+
if "videos" not in self.info:
248+
self.info["videos"] = {}
249+
246250
for key in self.video_keys:
247251
if key not in self.info["videos"]:
248252
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
@@ -260,6 +264,8 @@ def create(
260264
robot_type: str | None = None,
261265
features: dict | None = None,
262266
use_videos: bool = True,
267+
# tags: list[str] | None = None,
268+
# license_type: str | None = None,
263269
) -> "LeRobotDatasetMetadata":
264270
"""Creates metadata for a LeRobotDataset."""
265271
obj = cls.__new__(cls)
@@ -283,8 +289,12 @@ def create(
283289
else:
284290
features = {**features, **DEFAULT_FEATURES}
285291

292+
# TODO(rcadene): implement sanity check for features
293+
286294
obj.tasks, obj.stats, obj.episodes = {}, {}, []
287295
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
296+
# obj.tags = tags
297+
# obj.license_type = license_type
288298
if len(obj.video_keys) > 0 and not use_videos:
289299
raise ValueError()
290300
write_json(obj.info, obj.root / INFO_PATH)
@@ -646,7 +656,7 @@ def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
646656
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
647657
return {
648658
"size": 0,
649-
**{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
659+
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
650660
}
651661

652662
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
@@ -669,6 +679,9 @@ def add_frame(self, frame: dict) -> None:
669679
temporary directory — nothing is written to disk. To save those frames, the 'add_episode()' method
670680
then needs to be called.
671681
"""
682+
# TODO(rcadene): Add sanity check for the input, check it's numpy or torch,
683+
# check that the dtype and shape match, etc.
684+
672685
frame_index = self.episode_buffer["size"]
673686
for key, ft in self.features.items():
674687
if key == "frame_index":
@@ -705,6 +718,11 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None:
705718
# TODO(aliberts): Add option to use existing episode_index
706719
raise NotImplementedError()
707720

721+
if episode_length == 0:
722+
raise ValueError(
723+
"You must add one or several frames with `add_frame` before calling `add_episode`."
724+
)
725+
708726
task_index = self.meta.get_task_index(task)
709727

710728
if not set(self.episode_buffer.keys()) == set(self.features):
@@ -719,11 +737,14 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None:
719737
self.episode_buffer[key] = np.full((episode_length,), episode_index)
720738
elif key == "task_index":
721739
self.episode_buffer[key] = np.full((episode_length,), task_index)
722-
elif ft["dtype"] in ["image", "video"]:
740+
elif ft["dtype"] == "image":
723741
continue
742+
elif ft["dtype"] == "video":
743+
del self.episode_buffer[key]
724744
elif ft["shape"][0] == 1:
725745
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
726746
elif ft["shape"][0] > 1:
747+
# TODO(rcadene): why use torch here when the other branches use numpy?
727748
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
728749
else:
729750
raise ValueError()

lerobot/common/datasets/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
435435

436436

437437
def create_lerobot_dataset_card(
438-
tags: list | None = None, text: str | None = None, info: dict | None = None
438+
tags: list | None = None, text: str | None = None, info: dict | None = None, license: str = "apache-2.0"
439439
) -> DatasetCard:
440440
card = DatasetCard(DATASET_CARD_TEMPLATE)
441441
card.data.configs = [
@@ -445,6 +445,7 @@ def create_lerobot_dataset_card(
445445
}
446446
]
447447
card.data.task_categories = ["robotics"]
448+
card.data.license = license
448449
card.data.tags = ["LeRobot"]
449450
if tags is not None:
450451
card.data.tags += tags

lerobot/common/robot_devices/control_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from lerobot.common.datasets.image_writer import safe_stop_image_writer
1919
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
20-
from lerobot.common.datasets.utils import _get_info_from_robot
2120
from lerobot.common.policies.factory import make_policy
2221
from lerobot.common.robot_devices.robots.utils import Robot
2322
from lerobot.common.robot_devices.utils import busy_wait
@@ -334,7 +333,8 @@ def sanity_check_dataset_name(repo_id, policy):
334333

335334

336335
def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos):
337-
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
336+
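# TODO(rcadene): fix before merging: `_get_info_from_robot` is no longer imported in this module, so the call below is an undefined name (silenced with noqa)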
337+
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos) # noqa
338338

339339
fields = [
340340
("robot_type", dataset.meta.info["robot_type"], robot_type),

lerobot/scripts/push_dataset_to_hub.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
117117

118118

119119
def push_dataset_card_to_hub(
120-
repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
120+
repo_id: str,
121+
revision: str | None,
122+
tags: list | None = None,
123+
text: str | None = None,
124+
license: str = "apache-2.0",
121125
):
122126
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
123-
card = create_lerobot_dataset_card(tags=tags, text=text)
127+
card = create_lerobot_dataset_card(tags=tags, text=text, license=license)
124128
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
125129

126130

0 commit comments

Comments
 (0)