
Commit 1f13bda

Improve dataset v2 (#498)
1 parent: acae4b4

File tree

9 files changed (+393 −70 lines)


examples/port_datasets/pusht_zarr.py

+246
import shutil
from pathlib import Path

import numpy as np
import torch

from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw


def create_empty_dataset(repo_id, mode):
    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": (2,),
            "names": [
                ["x", "y"],
            ],
        },
        "action": {
            "dtype": "float32",
            "shape": (2,),
            "names": [
                ["x", "y"],
            ],
        },
        "next.reward": {
            "dtype": "float32",
            "shape": (1,),
            "names": None,
        },
        "next.success": {
            "dtype": "bool",
            "shape": (1,),
            "names": None,
        },
    }

    if mode == "keypoints":
        features["observation.environment_state"] = {
            "dtype": "float32",
            "shape": (16,),
            "names": [
                "keypoints",
            ],
        }
    else:
        features["observation.image"] = {
            "dtype": mode,
            "shape": (3, 96, 96),
            "names": [
                "channel",
                "height",
                "width",
            ],
        }

    dataset = LeRobotDataset.create(
        repo_id=repo_id,
        fps=10,
        robot_type="2d pointer",
        features=features,
        image_writer_threads=4,
    )
    return dataset


def load_raw_dataset(zarr_path, load_images=True):
    try:
        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
            ReplayBuffer as DiffusionPolicyReplayBuffer,
        )
    except ModuleNotFoundError as e:
        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
        raise e

    zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)

    env_state = zarr_data["state"][:]
    agent_pos = env_state[:, :2]
    block_pos = env_state[:, 2:4]
    block_angle = env_state[:, 4]

    action = zarr_data["action"][:]

    image = None
    if load_images:
        # b h w c
        image = zarr_data["img"]

    episode_data_index = {
        "from": np.array([0] + zarr_data.meta["episode_ends"][:-1].tolist()),
        "to": zarr_data.meta["episode_ends"],
    }

    return image, agent_pos, block_pos, block_angle, action, episode_data_index


def calculate_coverage(block_pos, block_angle):
    try:
        import pymunk
        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
    except ModuleNotFoundError as e:
        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
        raise e

    num_frames = len(block_pos)

    coverage = np.zeros((num_frames,))
    # 8 keypoints with 2 coords each
    keypoints = np.zeros((num_frames, 16))

    # Set x, y, theta (in radians)
    goal_pos_angle = np.array([256, 256, np.pi / 4])
    goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)

    for i in range(num_frames):
        space = pymunk.Space()
        space.gravity = 0, 0
        space.damping = 0

        # Add walls.
        walls = [
            PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
            PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
            PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
            PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
        ]
        space.add(*walls)

        block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
        goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
        block_geom = pymunk_to_shapely(block_body, block_body.shapes)
        intersection_area = goal_geom.intersection(block_geom).area
        goal_area = goal_geom.area
        coverage[i] = intersection_area / goal_area
        keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())

    return coverage, keypoints


def calculate_success(coverage, success_threshold):
    return coverage > success_threshold


def calculate_reward(coverage, success_threshold):
    return np.clip(coverage / success_threshold, 0, 1)


def populate_dataset(dataset, episode_data_index, episodes, image, state, env_state, action, reward, success):
    if episodes is None:
        episodes = range(len(episode_data_index["from"]))

    for ep_idx in episodes:
        from_idx = episode_data_index["from"][ep_idx]
        to_idx = episode_data_index["to"][ep_idx]
        num_frames = to_idx - from_idx

        for frame_idx in range(num_frames):
            i = from_idx + frame_idx

            frame = {
                "action": torch.from_numpy(action[i]),
                # "next.reward" and "next.success" describe the outcome of the action taken
                # at this frame, so shift them by +1; the last frame of the episode has no
                # successor and reuses its own values.
                "next.reward": reward[i + (frame_idx < num_frames - 1)],
                "next.success": success[i + (frame_idx < num_frames - 1)],
            }

            frame["observation.state"] = torch.from_numpy(state[i])

            if env_state is not None:
                frame["observation.environment_state"] = torch.from_numpy(env_state[i])

            if image is not None:
                frame["observation.image"] = torch.from_numpy(image[i])

            dataset.add_frame(frame)

        dataset.save_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.")

    return dataset


def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
    if mode not in ["video", "image", "keypoints"]:
        raise ValueError(f"Unsupported mode: {mode}. Expected 'video', 'image' or 'keypoints'.")

    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    raw_dir = Path(raw_dir)
    if not raw_dir.exists():
        download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")

    image, agent_pos, block_pos, block_angle, action, episode_data_index = load_raw_dataset(
        zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr"
    )

    # Calculate success and reward based on the overlapping area
    # of the T-object and the T-area.
    coverage, keypoints = calculate_coverage(block_pos, block_angle)
    success = calculate_success(coverage, success_threshold=0.95)
    reward = calculate_reward(coverage, success_threshold=0.95)

    dataset = create_empty_dataset(repo_id, mode)
    dataset = populate_dataset(
        dataset,
        episode_data_index,
        episodes,
        image=None if mode == "keypoints" else image,
        state=agent_pos,
        env_state=keypoints if mode == "keypoints" else None,
        action=action,
        reward=reward,
        success=success,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()


if __name__ == "__main__":
    # To try this script, modify the repo id with your own HuggingFace user (e.g. cadene/pusht)
    repo_id = "lerobot/pusht"

    episodes = None
    # Uncomment if you want to try with a subset (episodes 0 and 1)
    # episodes = [0, 1]

    modes = ["video", "image", "keypoints"]
    # Uncomment if you want to try with a specific mode
    # modes = ["video"]
    # modes = ["image"]
    # modes = ["keypoints"]

    for mode in modes:
        # Suffix the base repo id for non-default modes so each variant gets its own
        # repo (rather than accumulating suffixes across loop iterations).
        mode_repo_id = f"{repo_id}_{mode}" if mode in ["image", "keypoints"] else repo_id

        # Download and load the raw dataset, create a LeRobotDataset, populate it, push to hub.
        port_pusht("data/lerobot-raw/pusht_raw", repo_id=mode_repo_id, mode=mode, episodes=episodes)

        # Uncomment if you want to load the local dataset and explore it
        # dataset = LeRobotDataset(repo_id=mode_repo_id, local_files_only=True)
        # breakpoint()
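
As a quick sanity check after porting, the resulting dataset can be opened locally and inspected. A minimal sketch, assuming the "video" run above has completed and written to LEROBOT_HOME; the feature keys match those declared in create_empty_dataset:

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Load the locally ported dataset without contacting the hub.
dataset = LeRobotDataset(repo_id="lerobot/pusht", local_files_only=True)
print(dataset.meta.total_episodes, "episodes,", len(dataset), "frames")

# Each item is a dict of tensors keyed by the declared features.
frame = dataset[0]
print(frame["observation.state"])  # float32 tensor of shape (2,): agent x, y
print(frame["action"])             # float32 tensor of shape (2,)
print(frame["next.reward"])        # coverage-based reward, clipped to [0, 1]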

lerobot/common/datasets/lerobot_dataset.py

+30-7
@@ -280,6 +280,8 @@ def create(
         obj.repo_id = repo_id
         obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
 
+        obj.root.mkdir(parents=True, exist_ok=False)
+
         if robot is not None:
             features = get_features_from_robot(robot, use_videos)
             robot_type = robot.robot_type
@@ -293,6 +295,7 @@ def create(
                 "Dataset features must either come from a Robot or explicitly passed upon creation."
             )
         else:
+            # TODO(aliberts, rcadene): implement sanity check for features
             features = {**features, **DEFAULT_FEATURES}
 
         obj.tasks, obj.stats, obj.episodes = {}, {}, []
@@ -424,11 +427,10 @@ def __init__(
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.delta_indices = None
         self.local_files_only = local_files_only
-        self.consolidated = True
 
         # Unused attributes
         self.image_writer = None
-        self.episode_buffer = {}
+        self.episode_buffer = None
 
         self.root.mkdir(exist_ok=True, parents=True)
 
@@ -451,12 +453,16 @@ def __init__(
             check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
             self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
 
+        # Available stats implies all videos have been encoded and dataset is iterable
+        self.consolidated = self.meta.stats is not None
+
     def push_to_hub(
         self,
         tags: list | None = None,
         text: str | None = None,
-        license: str | None = "mit",
+        license: str | None = "apache-2.0",
         push_videos: bool = True,
+        private: bool = False,
     ) -> None:
         if not self.consolidated:
             raise RuntimeError(
468474
if not push_videos:
469475
ignore_patterns.append("videos/")
470476

471-
create_repo(self.repo_id, repo_type="dataset", exist_ok=True)
477+
create_repo(
478+
repo_id=self.repo_id,
479+
private=private,
480+
repo_type="dataset",
481+
exist_ok=True,
482+
)
483+
472484
upload_folder(
473485
repo_id=self.repo_id,
474486
folder_path=self.root,
@@ -658,7 +670,7 @@ def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
658670
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
659671
return {
660672
"size": 0,
661-
**{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
673+
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
662674
}
663675

664676
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
@@ -681,8 +693,14 @@ def add_frame(self, frame: dict) -> None:
         temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
         then needs to be called.
         """
+        # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
+        # check the dtype and shape matches, etc.
+
+        if self.episode_buffer is None:
+            self.episode_buffer = self._create_episode_buffer()
+
         frame_index = self.episode_buffer["size"]
-        timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
+        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
         self.episode_buffer["frame_index"].append(frame_index)
         self.episode_buffer["timestamp"].append(timestamp)
 
723741
# TODO(aliberts): Add option to use existing episode_index
724742
raise NotImplementedError()
725743

744+
if episode_length == 0:
745+
raise ValueError(
746+
"You must add one or several frames with `add_frame` before calling `add_episode`."
747+
)
748+
726749
task_index = self.meta.get_task_index(task)
727750

728751
if not set(episode_buffer.keys()) == set(self.features):
@@ -781,7 +804,7 @@ def clear_episode_buffer(self) -> None:
781804
# Reset the buffer
782805
self.episode_buffer = self._create_episode_buffer()
783806

784-
def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
807+
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
785808
if isinstance(self.image_writer, AsyncImageWriter):
786809
logging.warning(
787810
"You are starting a new AsyncImageWriter that is replacing an already exising one in the dataset."
