Skip to content

Commit 9ee8711

Browse files
committed
Update example 1
1 parent 6203641 commit 9ee8711

File tree

1 file changed

+67
-27
lines changed

1 file changed

+67
-27
lines changed

examples/1_load_lerobot_dataset.py

+67-27
Original file line numberDiff line numberDiff line change
@@ -14,53 +14,92 @@
1414
"""
1515

1616
# TODO(aliberts, rcadene): Update this script with the new v2 api
17-
from pathlib import Path
1817
from pprint import pprint
1918

20-
import imageio
2119
import torch
20+
from huggingface_hub import HfApi
2221

2322
import lerobot
24-
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
23+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
2524

25+
# We ported a number of existing datasets ourselves, use this to see the list:
2626
print("List of available datasets:")
2727
pprint(lerobot.available_datasets)
2828

29-
# Let's take one for this example
30-
repo_id = "lerobot/pusht"
31-
32-
# You can easily load a dataset from a Hugging Face repository
29+
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
30+
hub_api = HfApi()
31+
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
32+
pprint(repo_ids)
33+
34+
# Or simply explore them in your web browser directly at:
35+
# https://huggingface.co/datasets?other=LeRobot
36+
37+
# Let's take this one for this example
38+
repo_id = "aliberts/koch_tutorial"
39+
# We can have a look and fetch its metadata to know more about it:
40+
ds_meta = LeRobotDatasetMetadata(repo_id)
41+
42+
# By instantiating just this class, you can quickly access useful information about the content and the
43+
# structure of the dataset without downloading the actual data yet (only metadata files — which are
44+
# lightweight).
45+
print(f"Total number of episodes: {ds_meta.total_episodes}")
46+
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
47+
print(f"Frames per second used during data collection: {ds_meta.fps}")
48+
print(f"Robot type: {ds_meta.robot_type}")
49+
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
50+
51+
print("Tasks:")
52+
print(ds_meta.tasks)
53+
print("Features:")
54+
pprint(ds_meta.features)
55+
56+
# You can also get a short summary by simply printing the object:
57+
print(ds_meta)
58+
59+
# You can then load the actual dataset from the hub.
60+
# Either load any subset of episodes:
61+
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
62+
63+
# And see how many frames you have:
64+
print(f"Selected episodes: {dataset.episodes}")
65+
print(f"Number of episodes selected: {dataset.num_episodes}")
66+
print(f"Number of frames selected: {dataset.num_frames}")
67+
68+
# Or simply load the entire dataset:
3369
dataset = LeRobotDataset(repo_id)
70+
print(f"Number of episodes selected: {dataset.num_episodes}")
71+
print(f"Number of frames selected: {dataset.num_frames}")
72+
73+
# The previous metadata class is contained in the 'meta' attribute of the dataset:
74+
print(dataset.meta)
3475

3576
# LeRobotDataset actually wraps an underlying Hugging Face dataset
36-
# (see https://huggingface.co/docs/datasets/index for more information).
37-
print(dataset)
77+
# (see https://huggingface.co/docs/datasets for more information).
3878
print(dataset.hf_dataset)
3979

40-
# And provides additional utilities for robotics and compatibility with Pytorch
41-
print(f"\naverage number of frames per episode: {dataset.num_frames / dataset.num_episodes:.3f}")
42-
print(f"frames per second used during data collection: {dataset.fps=}")
43-
print(f"keys to access images from cameras: {dataset.meta.camera_keys=}\n")
44-
45-
# Access frame indexes associated to first episode
80+
# LeRobot datasets also subclass PyTorch datasets so you can do everything you know and love from working
81+
# with the latter, like iterating through the dataset.
82+
# The __getitem__ method indexes the frames of the dataset. Since our datasets are also structured by
83+
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
84+
# the frame indices associated with the first episode:
4685
episode_index = 0
4786
from_idx = dataset.episode_data_index["from"][episode_index].item()
4887
to_idx = dataset.episode_data_index["to"][episode_index].item()
4988

50-
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
51-
# with the latter, like iterating through the dataset. Here we grab all the image frames.
52-
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
89+
# Then we grab all the image frames from the first camera:
90+
camera_key = dataset.meta.camera_keys[0]
91+
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
5392

54-
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
55-
# them, we convert to uint8 in range [0,255]
56-
frames = [(frame * 255).type(torch.uint8) for frame in frames]
57-
# and to channel last (h,w,c).
58-
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
59-
60-
# Finally, we save the frames to a mp4 video for visualization.
61-
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
62-
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
93+
# The objects returned by the dataset are all torch.Tensors
94+
print(type(frames[0]))
95+
print(frames[0].shape)
6396

97+
# Since we're using PyTorch, the shape follows the PyTorch channel-first convention (c, h, w).
98+
# We can compare this shape with the information available for that feature
99+
pprint(dataset.features[camera_key])
100+
# In particular:
101+
print(dataset.features[camera_key]["shape"])
102+
# Here the shape is given as (h, w, c), which is a more universal format.
64103

65104
# For many machine learning applications we need to load the history of past observations or trajectories of
66105
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
@@ -86,6 +125,7 @@
86125
batch_size=32,
87126
shuffle=True,
88127
)
128+
89129
for batch in dataloader:
90130
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
91131
print(f"{batch['observation.state'].shape=}") # (32,8,c)

0 commit comments

Comments
 (0)