55
55
56
56
# Limit numpy float printing to 2 decimals for readable debug/log output.
np.set_printoptions(precision=2)
57
57
58
+
58
59
def tf_to_torch(data):
    """Convert a TensorFlow tensor to a ``torch.Tensor``.

    Goes through NumPy: ``data.numpy()`` materializes the TF tensor on the
    host, and ``torch.from_numpy`` wraps that array without copying it.
    """
    host_array = data.numpy()
    return torch.from_numpy(host_array)
60
61
62
+
61
63
def tf_img_convert(img):
    """Return a TF image tensor as a uint8 NumPy array.

    Accepts either an already-decoded ``tf.uint8`` image or a ``tf.string``
    tensor holding encoded image bytes (decoded here with animations
    disabled so the result is a single static frame). Any other dtype is
    rejected.

    Raises:
        ValueError: if ``img`` is neither ``tf.uint8`` nor ``tf.string``.
    """
    if img.dtype == tf.uint8:
        return img.numpy()
    if img.dtype == tf.string:
        decoded = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
        return decoded.numpy()
    raise ValueError(f"Unsupported image dtype: found with dtype {img.dtype}")
67
69
70
+
68
71
def _broadcast_metadata_rlds (i : tf .Tensor , traj : dict ) -> dict :
69
72
"""
70
73
In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
@@ -130,7 +133,7 @@ def load_from_raw(
130
133
# search for 'image' keys in the observations
131
134
image_keys = []
132
135
state_keys = []
133
- observation_info = dataset_info .features [' steps' ][ ' observation' ]
136
+ observation_info = dataset_info .features [" steps" ][ " observation" ]
134
137
for key in observation_info :
135
138
# check whether the key is for an image or a vector observation
136
139
if len (observation_info [key ].shape ) == 3 :
@@ -254,7 +257,7 @@ def load_from_raw(
254
257
255
258
def to_hf_dataset (data_dict , video ) -> Dataset :
256
259
features = {}
257
-
260
+
258
261
for key in data_dict :
259
262
# check if vector state obs
260
263
if key .startswith ("observation." ) and "observation.images." not in key :
@@ -272,7 +275,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
272
275
features ["action" ] = Sequence (
273
276
length = data_dict ["action" ].shape [1 ], feature = Value (dtype = "float32" , id = None )
274
277
)
275
-
278
+
276
279
features ["is_terminal" ] = Value (dtype = "bool" , id = None )
277
280
features ["is_first" ] = Value (dtype = "bool" , id = None )
278
281
features ["discount" ] = Value (dtype = "float32" , id = None )
@@ -297,7 +300,6 @@ def from_raw_to_lerobot_format(
297
300
episodes : list [int ] | None = None ,
298
301
encoding : dict | None = None ,
299
302
):
300
-
301
303
data_dict = load_from_raw (raw_dir , videos_dir , fps , video , episodes , encoding )
302
304
hf_dataset = to_hf_dataset (data_dict , video )
303
305
episode_data_index = calculate_episode_data_index (hf_dataset )
0 commit comments