Skip to content

Commit 04bbf71

Browse files
committed
Fix pusht keypoints + add BackwardCompatibilityError for dataset
1 parent e1e7edb commit 04bbf71

File tree

5 files changed

+95
-63
lines changed

5 files changed

+95
-63
lines changed

examples/port_datasets/pusht_zarr.py

+60-47
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,42 @@
99

1010

1111
def create_empty_dataset(repo_id, mode):
12-
features = {}
12+
features = {
13+
"observation.state": {
14+
"dtype": "float32",
15+
"shape": (2,),
16+
"names": [
17+
["x", "y"],
18+
],
19+
},
20+
"action": {
21+
"dtype": "float32",
22+
"shape": (2,),
23+
"names": [
24+
["x", "y"],
25+
],
26+
},
27+
"next.reward": {
28+
"dtype": "float32",
29+
"shape": (1,),
30+
"names": None,
31+
},
32+
"next.success": {
33+
"dtype": "bool",
34+
"shape": (1,),
35+
"names": None,
36+
},
37+
}
1338

1439
if mode == "keypoints":
15-
state_dim = 16
40+
features["observation.environment_state"] = {
41+
"dtype": "float32",
42+
"shape": (16,),
43+
"names": [
44+
"keypoints",
45+
],
46+
}
1647
else:
17-
state_dim = 2
1848
features["observation.image"] = {
1949
"dtype": mode,
2050
"shape": (3, 96, 96),
@@ -25,35 +55,6 @@ def create_empty_dataset(repo_id, mode):
2555
],
2656
}
2757

28-
features.update(
29-
{
30-
"observation.state": {
31-
"dtype": "float32",
32-
"shape": (state_dim,),
33-
"names": [
34-
["x", "y"],
35-
],
36-
},
37-
"action": {
38-
"dtype": "float32",
39-
"shape": (2,),
40-
"names": [
41-
["x", "y"],
42-
],
43-
},
44-
"next.reward": {
45-
"dtype": "float32",
46-
"shape": (1,),
47-
"names": None,
48-
},
49-
"next.success": {
50-
"dtype": "bool",
51-
"shape": (1,),
52-
"names": None,
53-
},
54-
}
55-
)
56-
5758
dataset = LeRobotDataset.create(
5859
repo_id=repo_id,
5960
fps=10,
@@ -146,7 +147,7 @@ def calculate_reward(coverage, success_threshold):
146147
return np.clip(coverage / success_threshold, 0, 1)
147148

148149

149-
def populate_dataset(dataset, episode_data_index, episodes, image, state, action, reward, success):
150+
def populate_dataset(dataset, episode_data_index, episodes, image, state, env_state, action, reward, success):
150151
if episodes is None:
151152
episodes = range(len(episode_data_index["from"]))
152153

@@ -160,20 +161,22 @@ def populate_dataset(dataset, episode_data_index, episodes, image, state, action
160161

161162
frame = {
162163
"action": torch.from_numpy(action[i]),
163-
"timestamp": frame_idx / dataset.fps,
164164
# Shift reward and success by +1 until the last item of the episode
165165
"next.reward": reward[i + (frame_idx < num_frames - 1)],
166166
"next.success": success[i + (frame_idx < num_frames - 1)],
167167
}
168168

169169
frame["observation.state"] = torch.from_numpy(state[i])
170+
171+
if env_state is not None:
172+
frame["observation.environment_state"] = torch.from_numpy(env_state[i])
173+
170174
if image is not None:
171175
frame["observation.image"] = torch.from_numpy(image[i])
172176

173-
# TODO(rcadene): add_frame_to_buffer, add_episode_from_buffer
174177
dataset.add_frame(frame)
175178

176-
dataset.add_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.")
179+
dataset.save_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.")
177180

178181
return dataset
179182

@@ -205,7 +208,8 @@ def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
205208
episode_data_index,
206209
episodes,
207210
image=None if mode == "keypoints" else image,
208-
state=keypoints if mode == "keypoints" else agent_pos,
211+
state=agent_pos,
212+
env_state=keypoints if mode == "keypoints" else None,
209213
action=action,
210214
reward=reward,
211215
success=success,
@@ -217,17 +221,26 @@ def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
217221

218222

219223
if __name__ == "__main__":
220-
episodes = [0, 1]
221-
# episodes = None
222-
223-
# for mode in ["video"]:
224-
for mode in ["image"]:
225-
# for mode in ["keypoints"]:
226-
# for mode in ["video", "image", "keypoints"]:
227-
repo_id = "cadene/pusht_v2"
224+
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
225+
repo_id = "lerobot/pusht"
226+
227+
episodes = None
228+
# Uncomment if you want to try with a subset (episode 0 and 1)
229+
# episodes = [0, 1]
230+
231+
modes = ["video", "image", "keypoints"]
232+
# Uncomment if you want to try with a specific mode
233+
# modes = ["video"]
234+
# modes = ["image"]
235+
# modes = ["keypoints"]
236+
237+
for mode in ["video", "image", "keypoints"]:
228238
if mode in ["image", "keypoints"]:
229239
repo_id += f"_{mode}"
240+
241+
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
230242
port_pusht("data/lerobot-raw/pusht_raw", repo_id=repo_id, mode=mode, episodes=episodes)
231243

232-
# dataset = LeRobotDataset(repo_id="cadene/pusht_v2", local_files_only=True)
233-
# dataset_old = LeRobotDataset(repo_id="lerobot/pusht")
244+
# Uncomment if you want to loal the local dataset and explore it
245+
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
246+
# breakpoint()

lerobot/common/datasets/lerobot_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ def add_frame(self, frame: dict) -> None:
700700
self.episode_buffer = self._create_episode_buffer()
701701

702702
frame_index = self.episode_buffer["size"]
703-
timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
703+
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
704704
self.episode_buffer["frame_index"].append(frame_index)
705705
self.episode_buffer["timestamp"].append(timestamp)
706706

lerobot/common/datasets/utils.py

+32-13
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
import json
17+
import textwrap
1718
import warnings
1819
from itertools import accumulate
1920
from pathlib import Path
@@ -188,17 +189,37 @@ def _get_major_minor(version: str) -> tuple[int]:
188189
return int(split[0]), int(split[1])
189190

190191

192+
class BackwardCompatibilityError(Exception):
193+
def __init__(self, repo_id, version):
194+
message = textwrap.dedent(f"""
195+
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
196+
197+
We introduced a new format since v2.0 which is not backward compatible with v1.x.
198+
Please, use our conversion script. Modify the following command with your own task description:
199+
```
200+
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
201+
--repo-id {repo_id} \\
202+
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
203+
```
204+
205+
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
206+
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
207+
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
208+
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
209+
210+
If you encounter a problem, contact LeRobot maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb')
211+
or open an issue on GitHub.
212+
""")
213+
super().__init__(message)
214+
215+
191216
def check_version_compatibility(
192217
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
193218
) -> None:
194219
current_major, _ = _get_major_minor(current_version)
195220
major_to_check, _ = _get_major_minor(version_to_check)
196221
if major_to_check < current_major and enforce_breaking_major:
197-
raise ValueError(
198-
f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
199-
format with v2.0 that is not backward compatible. Please use our conversion script
200-
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
201-
)
222+
raise BackwardCompatibilityError(repo_id, version_to_check)
202223
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
203224
warnings.warn(
204225
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
@@ -209,18 +230,16 @@ def check_version_compatibility(
209230
)
210231

211232

212-
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
213-
num_version = float(version.strip("v"))
214-
if num_version < 2 and enforce_v2:
215-
raise ValueError(
216-
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
217-
format with v2.0 that is not backward compatible. Please use our conversion script
218-
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
219-
)
233+
def get_hub_safe_version(repo_id: str, version: str) -> str:
220234
api = HfApi()
221235
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
222236
branches = [b.name for b in dataset_info.branches]
223237
if version not in branches:
238+
num_version = float(version.strip("v"))
239+
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
240+
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
241+
raise BackwardCompatibilityError(repo_id, version)
242+
224243
warnings.warn(
225244
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
226245
codebase. The following versions are available: {branches}.

lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def convert_dataset(
441441
arxiv: str | None = None,
442442
test_branch: str | None = None,
443443
):
444-
v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
444+
v1 = get_hub_safe_version(repo_id, V16)
445445
v1x_dir = local_dir / V16 / repo_id
446446
v20_dir = local_dir / V20 / repo_id
447447
v1x_dir.mkdir(parents=True, exist_ok=True)

tests/fixtures/dataset_factories.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def _create_lerobot_dataset_metadata(
325325
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
326326
) as mock_snapshot_download_patch,
327327
):
328-
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
328+
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
329329
mock_snapshot_download_patch.side_effect = mock_snapshot_download
330330

331331
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)

0 commit comments

Comments
 (0)