Skip to content

Commit 957659a

Browse files
committed
Add pusht_zarr.py
1 parent 8b94972 commit 957659a

File tree

1 file changed

+232
-0
lines changed

1 file changed

+232
-0
lines changed

examples/port_datasets/pusht_zarr.py

+232
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import shutil
2+
from pathlib import Path
3+
4+
import numpy as np
5+
import torch
6+
7+
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
8+
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
9+
10+
11+
def create_empty_dataset(repo_id, mode):
12+
features = {}
13+
14+
if mode == "keypoints":
15+
state_dim = 16
16+
else:
17+
state_dim = 2
18+
features["observation.image"] = {
19+
"dtype": mode,
20+
"shape": (3, 96, 96),
21+
"names": [
22+
"channel",
23+
"height",
24+
"width",
25+
],
26+
}
27+
28+
features.update(
29+
{
30+
"observation.state": {
31+
"dtype": "float32",
32+
"shape": (state_dim,),
33+
"names": [
34+
["x", "y"],
35+
],
36+
},
37+
"action": {
38+
"dtype": "float32",
39+
"shape": (2,),
40+
"names": [
41+
["x", "y"],
42+
],
43+
},
44+
"next.reward": {
45+
"dtype": "float32",
46+
"shape": (1,),
47+
"names": None,
48+
},
49+
"next.success": {
50+
"dtype": "bool",
51+
"shape": (1,),
52+
"names": None,
53+
},
54+
}
55+
)
56+
57+
dataset = LeRobotDataset.create(
58+
repo_id=repo_id,
59+
fps=10,
60+
robot_type="2d pointer",
61+
features=features,
62+
)
63+
return dataset
64+
65+
66+
def load_raw_dataset(zarr_path, load_images=True):
67+
try:
68+
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
69+
ReplayBuffer as DiffusionPolicyReplayBuffer,
70+
)
71+
except ModuleNotFoundError as e:
72+
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
73+
raise e
74+
75+
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
76+
77+
env_state = zarr_data["state"][:]
78+
agent_pos = env_state[:, :2]
79+
block_pos = env_state[:, 2:4]
80+
block_angle = env_state[:, 4]
81+
82+
action = zarr_data["action"][:]
83+
84+
image = None
85+
if load_images:
86+
# b h w c
87+
image = zarr_data["img"]
88+
89+
episode_data_index = {
90+
"from": np.array([0] + zarr_data.meta["episode_ends"][:-1].tolist()),
91+
"to": zarr_data.meta["episode_ends"],
92+
}
93+
94+
return image, agent_pos, block_pos, block_angle, action, episode_data_index
95+
96+
97+
def calculate_coverage(block_pos, block_angle):
98+
try:
99+
import pymunk
100+
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
101+
except ModuleNotFoundError as e:
102+
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
103+
raise e
104+
105+
num_frames = len(block_pos)
106+
107+
coverage = np.zeros((num_frames,))
108+
# 8 keypoints with 2 coords each
109+
keypoints = np.zeros((num_frames, 16))
110+
111+
# Set x, y, theta (in radians)
112+
goal_pos_angle = np.array([256, 256, np.pi / 4])
113+
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
114+
115+
for i in range(num_frames):
116+
space = pymunk.Space()
117+
space.gravity = 0, 0
118+
space.damping = 0
119+
120+
# Add walls.
121+
walls = [
122+
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
123+
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
124+
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
125+
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
126+
]
127+
space.add(*walls)
128+
129+
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
130+
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
131+
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
132+
intersection_area = goal_geom.intersection(block_geom).area
133+
goal_area = goal_geom.area
134+
coverage[i] = intersection_area / goal_area
135+
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
136+
137+
return coverage, keypoints
138+
139+
140+
def calculate_success(coverage, success_threshold):
141+
return coverage > success_threshold
142+
143+
144+
def calculate_reward(coverage, success_threshold):
145+
return np.clip(coverage / success_threshold, 0, 1)
146+
147+
148+
def populate_dataset(dataset, episode_data_index, episodes, image, state, action, reward, success):
149+
if episodes is None:
150+
episodes = range(len(episode_data_index["from"]))
151+
152+
for ep_idx in episodes:
153+
from_idx = episode_data_index["from"][ep_idx]
154+
to_idx = episode_data_index["to"][ep_idx]
155+
num_frames = to_idx - from_idx
156+
157+
for frame_idx in range(num_frames):
158+
i = from_idx + frame_idx
159+
160+
frame = {
161+
"action": torch.from_numpy(action[i]),
162+
"timestamp": frame_idx / dataset.fps,
163+
# Shift reward and success by +1 until the last item of the episode
164+
"next.reward": reward[i + (frame_idx < num_frames - 1)],
165+
"next.success": success[i + (frame_idx < num_frames - 1)],
166+
}
167+
168+
frame["observation.state"] = torch.from_numpy(state[i])
169+
if image is not None:
170+
frame["observation.image"] = torch.from_numpy(image[i])
171+
172+
# TODO(rcadene): add_frame_to_buffer, add_episode_from_buffer
173+
dataset.add_frame(frame)
174+
175+
dataset.add_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.")
176+
177+
return dataset
178+
179+
180+
def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
181+
if mode not in ["video", "image", "keypoints"]:
182+
raise ValueError(mode)
183+
184+
if (LEROBOT_HOME / repo_id).exists():
185+
shutil.rmtree(LEROBOT_HOME / repo_id)
186+
187+
raw_dir = Path(raw_dir)
188+
if not raw_dir.exists():
189+
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
190+
191+
image, agent_pos, block_pos, block_angle, action, episode_data_index = load_raw_dataset(
192+
zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr"
193+
)
194+
195+
# Calculate success and reward based on the overlapping area
196+
# of the T-object and the T-area.
197+
coverage, keypoints = calculate_coverage(block_pos, block_angle)
198+
success = calculate_success(coverage, success_threshold=0.95)
199+
reward = calculate_reward(coverage, success_threshold=0.95)
200+
201+
dataset = create_empty_dataset(repo_id, mode)
202+
dataset = populate_dataset(
203+
dataset,
204+
episode_data_index,
205+
episodes,
206+
image=None if mode == "keypoints" else image,
207+
state=keypoints if mode == "keypoints" else agent_pos,
208+
action=action,
209+
reward=reward,
210+
success=success,
211+
)
212+
dataset.consolidate()
213+
214+
if push_to_hub:
215+
dataset.push_to_hub()
216+
217+
218+
if __name__ == "__main__":
219+
episodes = [0, 1]
220+
# episodes = None
221+
222+
# for mode in ["video"]:
223+
for mode in ["image"]:
224+
# for mode in ["keypoints"]:
225+
# for mode in ["video", "image", "keypoints"]:
226+
repo_id = "cadene/pusht_v2"
227+
if mode in ["image", "keypoints"]:
228+
repo_id += f"_{mode}"
229+
port_pusht("data/lerobot-raw/pusht_raw", repo_id=repo_id, mode=mode, episodes=episodes)
230+
231+
# dataset = LeRobotDataset(repo_id="cadene/pusht_v2", local_files_only=True)
232+
# dataset_old = LeRobotDataset(repo_id="lerobot/pusht")

0 commit comments

Comments
 (0)