
Commit d9d5c19

removed need for OPENX_CONFIGS
1 parent 963738d commit d9d5c19

1 file changed: lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py (+45, −94)
@@ -40,7 +40,6 @@
 from PIL import Image as PILImage
 
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
 from lerobot.common.datasets.push_dataset_to_hub.utils import (
     concatenate_episodes,
     get_default_encoding,
@@ -52,26 +51,18 @@
 )
 from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
-with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml") as f:
-    _openx_list = yaml.safe_load(f)
-
-OPENX_DATASET_CONFIGS = _openx_list["OPENX_DATASET_CONFIGS"]
-
 np.set_printoptions(precision=2)
 
-
 def tf_to_torch(data):
     return torch.from_numpy(data.numpy())
 
-
 def tf_img_convert(img):
     if img.dtype == tf.string:
         img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
     elif img.dtype != tf.uint8:
         raise ValueError(f"Unsupported image dtype: found with dtype {img.dtype}")
     return img.numpy()
 
-
 def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
     """
     In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
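With this block gone, nothing shown here reads openx/configs.yaml at import time anymore; if no other code in the file still uses it, the module's yaml import is presumably dead as well (not visible in this diff, so an assumption).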
@@ -108,7 +99,6 @@ def load_from_raw(
     video: bool,
     episodes: list[int] | None = None,
     encoding: dict | None = None,
-    openx_dataset_name: str | None = None,
 ):
     """
     Args:
@@ -136,16 +126,17 @@ def load_from_raw(
     # we will apply the standardization transform if the dataset_name is provided
     # if the dataset name is not provided and the goal is to convert any rlds formatted dataset
     # search for 'image' keys in the observations
-    if openx_dataset_name is not None:
-        print(" - applying standardization transform for dataset: ", openx_dataset_name)
-        assert openx_dataset_name in OPENX_STANDARDIZATION_TRANSFORMS
-        transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name]
-        dataset = dataset.map(transform_fn)
-
-        image_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["image_obs_keys"]
-    else:
-        obs_keys = dataset_info.features["steps"]["observation"].keys()
-        image_keys = [key for key in obs_keys if "image" in key]
+    image_keys = []
+    state_keys = []
+    observation_info = dataset_info.features['steps']['observation']
+    for key in observation_info:
+        # check whether the key is for an image or a vector observation
+        if len(observation_info[key].shape) == 3:
+            # only adding uint8 images discards depth images
+            if observation_info[key].dtype == tf.uint8:
+                image_keys.append(key)
+        else:
+            state_keys.append(key)
 
     lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
 
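The replacement drops the per-dataset config lookup in favor of inferring keys from the TFDS feature spec: rank-3 uint8 tensors are treated as RGB images, rank-3 tensors of any other dtype (e.g. float32 depth maps) are silently skipped, and everything else is bucketed as vector state (note the surviving comment above this block still describes the removed dataset_name branch). A minimal standalone sketch of the same heuristic; the function name and its FeaturesDict-style argument are illustrative, not part of the commit:

import tensorflow as tf


def split_observation_keys(observation_info) -> tuple[list[str], list[str]]:
    """Bucket RLDS observation keys into image keys and vector-state keys."""
    image_keys, state_keys = [], []
    for key in observation_info:
        # rank-3 tensors are images; only uint8 ones are kept (drops depth maps)
        if len(observation_info[key].shape) == 3:
            if observation_info[key].dtype == tf.uint8:
                image_keys.append(key)
        else:
            state_keys.append(key)
    return image_keys, state_keys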
@@ -177,6 +168,8 @@ def load_from_raw(
     # convert episodes index to sorted list
     episodes = sorted(episodes)
 
+    breakpoint()
+
     for ep_idx in tqdm.tqdm(range(starting_ep_idx, ds_length)):
         episode = next(it)
 
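(The breakpoint() added here reads like a leftover debugging statement: as committed, every conversion run will pause in the Python debugger before iterating over episodes.)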
@@ -193,49 +186,30 @@ def load_from_raw(
 
         num_frames = episode["action"].shape[0]
 
-        ###########################################################
-        # Handle the episodic data
-
-        # last step of demonstration is considered done
-        done = torch.zeros(num_frames, dtype=torch.bool)
-        done[-1] = True
         ep_dict = {}
-        langs = [] # TODO: might be located in "observation"
+        for key in state_keys:
+            ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])
 
-        image_array_dict = {key: [] for key in image_keys}
-
-        # We will create the state observation tensor by stacking the state
-        # obs keys defined in the openx/configs.py
-        if openx_dataset_name is not None:
-            state_obs_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["state_obs_keys"]
-            # stack the state observations, if is None, pad with zeros
-            states = []
-            for key in state_obs_keys:
-                if key in episode["observation"]:
-                    states.append(tf_to_torch(episode["observation"][key]))
-                else:
-                    states.append(torch.zeros(num_frames, 1)) # pad with zeros
-            states = torch.cat(states, dim=1)
-            # assert states.shape == (num_frames, 8), f"states shape: {states.shape}"
-        else:
-            states = tf_to_torch(episode["observation"]["state"])
-
-        actions = tf_to_torch(episode["action"])
-        rewards = tf_to_torch(episode["reward"]).float()
+        ep_dict["action"] = tf_to_torch(episode["action"])
+        ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
+        ep_dict["next.done"] = tf_to_torch(episode["is_last"])
+        ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
+        ep_dict["is_first"] = tf_to_torch(episode["is_first"])
+        ep_dict["discount"] = tf_to_torch(episode["discount"])
 
         # If lang_key is present, convert the entire tensor at once
         if lang_key is not None:
-            langs = [str(x) for x in episode[lang_key]]
+            ep_dict["language_instruction"] = [str(x) for x in episode[lang_key]]
+
+        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
+        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
+        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
 
         for im_key in image_keys:
            imgs = episode["observation"][im_key]
            image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
 
-        # simple assertions
-        for item in [states, actions, rewards, done]:
-            assert len(item) == num_frames
-
-        ###########################################################
+        image_array_dict = {key: [] for key in image_keys}
 
         # loop through all cameras
         for im_key in image_keys:
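As this hunk reads, the added image_array_dict initializer lands after the loop that fills it, so the loop's first assignment would raise a NameError, and even with an earlier definition the re-initialization would discard the converted frames. A minimal sketch of the presumably intended ordering (same statements, initializer first):

# Presumed intent: create the per-camera buffers before filling them.
image_array_dict = {key: [] for key in image_keys}

for im_key in image_keys:
    imgs = episode["observation"][im_key]
    image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]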
@@ -262,17 +236,6 @@ def load_from_raw(
             else:
                 ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
 
-        if lang_key is not None:
-            ep_dict["language_instruction"] = langs
-
-        ep_dict["observation.state"] = states
-        ep_dict["action"] = actions
-        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
-        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
-        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
-        ep_dict["next.reward"] = rewards
-        ep_dict["next.done"] = done
-
         path_ep_dict = tmp_ep_dicts_dir.joinpath(
             "ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
         )
@@ -289,31 +252,29 @@
 
 def to_hf_dataset(data_dict, video) -> Dataset:
     features = {}
+
+    for key in data_dict:
+        # check if vector state obs
+        if key.startswith("observation.") and "observation.images." not in key:
+            features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
+        # check if image obs
+        elif "observation.images." in key:
+            if video:
+                features[key] = VideoFrame()
+            else:
+                features[key] = Image()
 
-    keys = [key for key in data_dict if "observation.images." in key]
-    for key in keys:
-        if video:
-            features[key] = VideoFrame()
-        else:
-            features[key] = Image()
-
-    features["observation.state"] = Sequence(
-        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
-    )
-    if "observation.velocity" in data_dict:
-        features["observation.velocity"] = Sequence(
-            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
-        )
-    if "observation.effort" in data_dict:
-        features["observation.effort"] = Sequence(
-            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
-        )
     if "language_instruction" in data_dict:
         features["language_instruction"] = Value(dtype="string", id=None)
 
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )
+
+    features["is_terminal"] = Value(dtype="bool", id=None)
+    features["is_first"] = Value(dtype="bool", id=None)
+    features["discount"] = Value(dtype="float32", id=None)
+
     features["episode_index"] = Value(dtype="int64", id=None)
     features["frame_index"] = Value(dtype="int64", id=None)
     features["timestamp"] = Value(dtype="float32", id=None)
@@ -333,19 +294,9 @@ def from_raw_to_lerobot_format(
     video: bool = True,
     episodes: list[int] | None = None,
     encoding: dict | None = None,
-    openx_dataset_name: str | None = None,
 ):
-    """This is a test impl for rlds conversion"""
-    if openx_dataset_name is None:
-        # set a default rlds frame rate if the dataset is not from openx
-        fps = 30
-    elif "fps" not in OPENX_DATASET_CONFIGS[openx_dataset_name]:
-        raise ValueError(
-            "fps for this dataset is not specified in openx/configs.py yet," "means it is not yet tested"
-        )
-    fps = OPENX_DATASET_CONFIGS[openx_dataset_name]["fps"]
-
-    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding, openx_dataset_name)
+
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
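With the config lookup gone, fps is no longer read from openx/configs.py and must arrive through the function's existing fps parameter. A hypothetical invocation, assuming the function returns the (hf_dataset, episode_data_index, info) triple it assembles here; paths and the fps value are illustrative:

from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format

# fps must now be supplied by the caller instead of being looked up per dataset;
# 30 was the old fallback for non-OpenX RLDS data.
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
    raw_dir=Path("data/raw/my_rlds_dataset"),  # hypothetical RLDS input directory
    videos_dir=Path("data/videos"),            # hypothetical output directory for encoded videos
    fps=30,
    video=True,
)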
