from PIL import Image as PILImage

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
- from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.push_dataset_to_hub.utils import (
    concatenate_episodes,
    get_default_encoding,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

- with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml") as f:
-     _openx_list = yaml.safe_load(f)
-
- OPENX_DATASET_CONFIGS = _openx_list["OPENX_DATASET_CONFIGS"]
-
np.set_printoptions(precision=2)

-
def tf_to_torch(data):
    return torch.from_numpy(data.numpy())

-
def tf_img_convert(img):
    if img.dtype == tf.string:
        img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
    elif img.dtype != tf.uint8:
        raise ValueError(f"Unsupported image dtype: found with dtype {img.dtype}")
    return img.numpy()

-
def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
    """
    In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
@@ -108,7 +99,6 @@ def load_from_raw(
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
-     openx_dataset_name: str | None = None,
):
    """
    Args:
@@ -136,16 +126,17 @@ def load_from_raw(
    # we will apply the standardization transform if the dataset_name is provided
    # if the dataset name is not provided and the goal is to convert any rlds formatted dataset
    # search for 'image' keys in the observations
-     if openx_dataset_name is not None:
-         print(" - applying standardization transform for dataset: ", openx_dataset_name)
-         assert openx_dataset_name in OPENX_STANDARDIZATION_TRANSFORMS
-         transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name]
-         dataset = dataset.map(transform_fn)
-
-         image_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["image_obs_keys"]
-     else:
-         obs_keys = dataset_info.features["steps"]["observation"].keys()
-         image_keys = [key for key in obs_keys if "image" in key]
+     image_keys = []
+     state_keys = []
+     observation_info = dataset_info.features["steps"]["observation"]
+     for key in observation_info:
+         # check whether the key holds an image or a vector observation
+         if len(observation_info[key].shape) == 3:
+             # only keep uint8 images; this drops depth images
+             if observation_info[key].dtype == tf.uint8:
+                 image_keys.append(key)
+         else:
+             state_keys.append(key)

    lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
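For context only, not part of this commit: a minimal, self-contained sketch of the shape/dtype heuristic added above. The observation spec below is invented and uses tf.TensorSpec as a stand-in for the feature objects that dataset_info.features["steps"]["observation"] actually exposes.

import tensorflow as tf

# invented spec: one RGB camera, one depth camera, one proprioceptive state vector
observation_info = {
    "image": tf.TensorSpec(shape=(224, 224, 3), dtype=tf.uint8),
    "depth": tf.TensorSpec(shape=(224, 224, 1), dtype=tf.uint16),
    "state": tf.TensorSpec(shape=(8,), dtype=tf.float32),
}

image_keys, state_keys = [], []
for key, spec in observation_info.items():
    if len(spec.shape) == 3:
        # rank-3 uint8 tensors are treated as camera streams
        if spec.dtype == tf.uint8:
            image_keys.append(key)
    else:
        # everything else is treated as part of the vector state
        state_keys.append(key)

# image_keys == ["image"]   ("depth" is dropped because it is not uint8)
# state_keys == ["state"]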
@@ -177,6 +168,8 @@ def load_from_raw(
        # convert episodes index to sorted list
        episodes = sorted(episodes)

    for ep_idx in tqdm.tqdm(range(starting_ep_idx, ds_length)):
        episode = next(it)
@@ -193,49 +186,30 @@ def load_from_raw(

        num_frames = episode["action"].shape[0]

-         ###########################################################
-         # Handle the episodic data
-
-         # last step of demonstration is considered done
-         done = torch.zeros(num_frames, dtype=torch.bool)
-         done[-1] = True
        ep_dict = {}
-         langs = []  # TODO: might be located in "observation"
+         for key in state_keys:
+             ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])

-         image_array_dict = {key: [] for key in image_keys}
-
-         # We will create the state observation tensor by stacking the state
-         # obs keys defined in the openx/configs.py
-         if openx_dataset_name is not None:
-             state_obs_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["state_obs_keys"]
-             # stack the state observations, if is None, pad with zeros
-             states = []
-             for key in state_obs_keys:
-                 if key in episode["observation"]:
-                     states.append(tf_to_torch(episode["observation"][key]))
-                 else:
-                     states.append(torch.zeros(num_frames, 1))  # pad with zeros
-             states = torch.cat(states, dim=1)
-             # assert states.shape == (num_frames, 8), f"states shape: {states.shape}"
-         else:
-             states = tf_to_torch(episode["observation"]["state"])
-
-         actions = tf_to_torch(episode["action"])
-         rewards = tf_to_torch(episode["reward"]).float()
+         ep_dict["action"] = tf_to_torch(episode["action"])
+         ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
+         ep_dict["next.done"] = tf_to_torch(episode["is_last"])
+         ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
+         ep_dict["is_first"] = tf_to_torch(episode["is_first"])
+         ep_dict["discount"] = tf_to_torch(episode["discount"])

        # If lang_key is present, convert the entire tensor at once
        if lang_key is not None:
-             langs = [str(x) for x in episode[lang_key]]
+             ep_dict["language_instruction"] = [str(x) for x in episode[lang_key]]
+
+         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
+         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
+         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)

+         image_array_dict = {key: [] for key in image_keys}
        for im_key in image_keys:
            imgs = episode["observation"][im_key]
            image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]

-         # simple assertions
-         for item in [states, actions, rewards, done]:
-             assert len(item) == num_frames
-
-         ###########################################################

        # loop through all cameras
        for im_key in image_keys:
@@ -262,17 +236,6 @@ def load_from_raw(
            else:
                ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

-         if lang_key is not None:
-             ep_dict["language_instruction"] = langs
-
-         ep_dict["observation.state"] = states
-         ep_dict["action"] = actions
-         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
-         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
-         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
-         ep_dict["next.reward"] = rewards
-         ep_dict["next.done"] = done
-
        path_ep_dict = tmp_ep_dicts_dir.joinpath(
            "ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
        )
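As a reading aid, not part of the change: the per-step RLDS fields consumed above map onto LeRobot keys roughly as follows; timestamp, episode_index and frame_index are derived from the frame position and fps rather than read from the episode.

# illustrative summary; key names taken from the hunks above
RLDS_TO_LEROBOT_KEYS = {
    "observation/<state_key>": "observation.<state_key>",          # one entry per vector observation
    "observation/<image_key>": "observation.images.<image_key>",   # one entry per camera
    "action": "action",
    "reward": "next.reward",
    "is_last": "next.done",
    "is_terminal": "is_terminal",
    "is_first": "is_first",
    "discount": "discount",
    "language_instruction": "language_instruction",                # only if the dataset provides it
}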
@@ -289,31 +252,29 @@ def load_from_raw(

def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}
+
+     for key in data_dict:
+         # check if vector state obs
+         if key.startswith("observation.") and "observation.images." not in key:
+             features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
+         # check if image obs
+         elif "observation.images." in key:
+             if video:
+                 features[key] = VideoFrame()
+             else:
+                 features[key] = Image()

-     keys = [key for key in data_dict if "observation.images." in key]
-     for key in keys:
-         if video:
-             features[key] = VideoFrame()
-         else:
-             features[key] = Image()
-
-     features["observation.state"] = Sequence(
-         length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
-     )
-     if "observation.velocity" in data_dict:
-         features["observation.velocity"] = Sequence(
-             length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
-         )
-     if "observation.effort" in data_dict:
-         features["observation.effort"] = Sequence(
-             length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
-         )
    if "language_instruction" in data_dict:
        features["language_instruction"] = Value(dtype="string", id=None)

    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
+
+     features["is_terminal"] = Value(dtype="bool", id=None)
+     features["is_first"] = Value(dtype="bool", id=None)
+     features["discount"] = Value(dtype="float32", id=None)
+
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
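Not shown in this hunk: how a features dict like the one built above is usually turned into a Hugging Face dataset. A generic sketch with toy data, assuming the standard datasets API (Dataset.from_dict plus Features); the column set here is invented.

from datasets import Dataset, Features, Sequence, Value

toy_data = {
    "action": [[0.0, 1.0], [0.5, -0.5]],
    "episode_index": [0, 0],
    "frame_index": [0, 1],
    "timestamp": [0.0, 0.1],
}
toy_features = Features(
    {
        "action": Sequence(length=2, feature=Value(dtype="float32", id=None)),
        "episode_index": Value(dtype="int64", id=None),
        "frame_index": Value(dtype="int64", id=None),
        "timestamp": Value(dtype="float32", id=None),
    }
)
hf_dataset = Dataset.from_dict(toy_data, features=toy_features)
print(hf_dataset)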
@@ -333,19 +294,9 @@ def from_raw_to_lerobot_format(
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
-     openx_dataset_name: str | None = None,
):
-     """This is a test impl for rlds conversion"""
-     if openx_dataset_name is None:
-         # set a default rlds frame rate if the dataset is not from openx
-         fps = 30
-     elif "fps" not in OPENX_DATASET_CONFIGS[openx_dataset_name]:
-         raise ValueError(
-             "fps for this dataset is not specified in openx/configs.py yet," "means it is not yet tested"
-         )
-     fps = OPENX_DATASET_CONFIGS[openx_dataset_name]["fps"]
-
-     data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding, openx_dataset_name)
+
+     data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
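Hypothetical usage after this change; the full signature and return values of from_raw_to_lerobot_format are not visible in the diff, so the parameters below (in particular fps, which is no longer looked up in openx/configs.yaml) and the 3-tuple return are assumptions based on the surrounding code.

from pathlib import Path

# paths are made up for illustration
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
    raw_dir=Path("data/raw/my_rlds_dataset"),
    videos_dir=Path("data/videos/my_rlds_dataset"),
    fps=30,
    video=True,
    episodes=None,
    encoding=None,
)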