9
9
10
10
11
11
def create_empty_dataset (repo_id , mode ):
12
- features = {}
12
+ features = {
13
+ "observation.state" : {
14
+ "dtype" : "float32" ,
15
+ "shape" : (2 ,),
16
+ "names" : [
17
+ ["x" , "y" ],
18
+ ],
19
+ },
20
+ "action" : {
21
+ "dtype" : "float32" ,
22
+ "shape" : (2 ,),
23
+ "names" : [
24
+ ["x" , "y" ],
25
+ ],
26
+ },
27
+ "next.reward" : {
28
+ "dtype" : "float32" ,
29
+ "shape" : (1 ,),
30
+ "names" : None ,
31
+ },
32
+ "next.success" : {
33
+ "dtype" : "bool" ,
34
+ "shape" : (1 ,),
35
+ "names" : None ,
36
+ },
37
+ }
13
38
14
39
if mode == "keypoints" :
15
- state_dim = 16
40
+ features ["observation.environment_state" ] = {
41
+ "dtype" : "float32" ,
42
+ "shape" : (16 ,),
43
+ "names" : [
44
+ "keypoints" ,
45
+ ],
46
+ }
16
47
else :
17
- state_dim = 2
18
48
features ["observation.image" ] = {
19
49
"dtype" : mode ,
20
50
"shape" : (3 , 96 , 96 ),
@@ -25,35 +55,6 @@ def create_empty_dataset(repo_id, mode):
25
55
],
26
56
}
27
57
28
- features .update (
29
- {
30
- "observation.state" : {
31
- "dtype" : "float32" ,
32
- "shape" : (state_dim ,),
33
- "names" : [
34
- ["x" , "y" ],
35
- ],
36
- },
37
- "action" : {
38
- "dtype" : "float32" ,
39
- "shape" : (2 ,),
40
- "names" : [
41
- ["x" , "y" ],
42
- ],
43
- },
44
- "next.reward" : {
45
- "dtype" : "float32" ,
46
- "shape" : (1 ,),
47
- "names" : None ,
48
- },
49
- "next.success" : {
50
- "dtype" : "bool" ,
51
- "shape" : (1 ,),
52
- "names" : None ,
53
- },
54
- }
55
- )
56
-
57
58
dataset = LeRobotDataset .create (
58
59
repo_id = repo_id ,
59
60
fps = 10 ,
@@ -146,7 +147,7 @@ def calculate_reward(coverage, success_threshold):
146
147
return np .clip (coverage / success_threshold , 0 , 1 )
147
148
148
149
149
- def populate_dataset (dataset , episode_data_index , episodes , image , state , action , reward , success ):
150
+ def populate_dataset (dataset , episode_data_index , episodes , image , state , env_state , action , reward , success ):
150
151
if episodes is None :
151
152
episodes = range (len (episode_data_index ["from" ]))
152
153
@@ -160,20 +161,22 @@ def populate_dataset(dataset, episode_data_index, episodes, image, state, action
160
161
161
162
frame = {
162
163
"action" : torch .from_numpy (action [i ]),
163
- "timestamp" : frame_idx / dataset .fps ,
164
164
# Shift reward and success by +1 until the last item of the episode
165
165
"next.reward" : reward [i + (frame_idx < num_frames - 1 )],
166
166
"next.success" : success [i + (frame_idx < num_frames - 1 )],
167
167
}
168
168
169
169
frame ["observation.state" ] = torch .from_numpy (state [i ])
170
+
171
+ if env_state is not None :
172
+ frame ["observation.environment_state" ] = torch .from_numpy (env_state [i ])
173
+
170
174
if image is not None :
171
175
frame ["observation.image" ] = torch .from_numpy (image [i ])
172
176
173
- # TODO(rcadene): add_frame_to_buffer, add_episode_from_buffer
174
177
dataset .add_frame (frame )
175
178
176
- dataset .add_episode (task = "Push the T-shaped blue block onto the T-shaped green target surface." )
179
+ dataset .save_episode (task = "Push the T-shaped blue block onto the T-shaped green target surface." )
177
180
178
181
return dataset
179
182
@@ -205,7 +208,8 @@ def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
205
208
episode_data_index ,
206
209
episodes ,
207
210
image = None if mode == "keypoints" else image ,
208
- state = keypoints if mode == "keypoints" else agent_pos ,
211
+ state = agent_pos ,
212
+ env_state = keypoints if mode == "keypoints" else None ,
209
213
action = action ,
210
214
reward = reward ,
211
215
success = success ,
@@ -217,17 +221,26 @@ def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
217
221
218
222
219
223
if __name__ == "__main__" :
220
- episodes = [0 , 1 ]
221
- # episodes = None
222
-
223
- # for mode in ["video"]:
224
- for mode in ["image" ]:
225
- # for mode in ["keypoints"]:
226
- # for mode in ["video", "image", "keypoints"]:
227
- repo_id = "cadene/pusht_v2"
224
+ # To try this script, replace the repo id with one under your own Hugging Face user (e.g. cadene/pusht)
225
+ repo_id = "lerobot/pusht"
226
+
227
+ episodes = None
228
+ # Uncomment if you want to try with a subset (episode 0 and 1)
229
+ # episodes = [0, 1]
230
+
231
+ modes = ["video" , "image" , "keypoints" ]
232
+ # Uncomment if you want to try with a specific mode
233
+ # modes = ["video"]
234
+ # modes = ["image"]
235
+ # modes = ["keypoints"]
236
+
237
+ for mode in ["video" , "image" , "keypoints" ]:
228
238
if mode in ["image" , "keypoints" ]:
229
239
repo_id += f"_{ mode } "
240
+
241
+ # download and load raw dataset, create LeRobotDataset, populate it, push to hub
230
242
port_pusht ("data/lerobot-raw/pusht_raw" , repo_id = repo_id , mode = mode , episodes = episodes )
231
243
232
- # dataset = LeRobotDataset(repo_id="cadene/pusht_v2", local_files_only=True)
233
- # dataset_old = LeRobotDataset(repo_id="lerobot/pusht")
244
+ # Uncomment if you want to load the local dataset and explore it
245
+ # dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
246
+ # breakpoint()
0 commit comments