Skip to content

Commit b1e9c01

Browse files
remove openx_dataset_name
1 parent b78832b commit b1e9c01

File tree

2 files changed

+14
-22
lines changed

2 files changed

+14
-22
lines changed

lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,19 @@
5555

5656
np.set_printoptions(precision=2)
5757

58+
5859
def tf_to_torch(data):
5960
return torch.from_numpy(data.numpy())
6061

62+
6163
def tf_img_convert(img):
6264
if img.dtype == tf.string:
6365
img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
6466
elif img.dtype != tf.uint8:
6567
raise ValueError(f"Unsupported image dtype: found with dtype {img.dtype}")
6668
return img.numpy()
6769

70+
6871
def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
6972
"""
7073
In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
@@ -130,7 +133,7 @@ def load_from_raw(
130133
# search for 'image' keys in the observations
131134
image_keys = []
132135
state_keys = []
133-
observation_info = dataset_info.features['steps']['observation']
136+
observation_info = dataset_info.features["steps"]["observation"]
134137
for key in observation_info:
135138
# check whether the key is for an image or a vector observation
136139
if len(observation_info[key].shape) == 3:
@@ -254,7 +257,7 @@ def load_from_raw(
254257

255258
def to_hf_dataset(data_dict, video) -> Dataset:
256259
features = {}
257-
260+
258261
for key in data_dict:
259262
# check if vector state obs
260263
if key.startswith("observation.") and "observation.images." not in key:
@@ -272,7 +275,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
272275
features["action"] = Sequence(
273276
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
274277
)
275-
278+
276279
features["is_terminal"] = Value(dtype="bool", id=None)
277280
features["is_first"] = Value(dtype="bool", id=None)
278281
features["discount"] = Value(dtype="float32", id=None)
@@ -297,7 +300,6 @@ def from_raw_to_lerobot_format(
297300
episodes: list[int] | None = None,
298301
encoding: dict | None = None,
299302
):
300-
301303
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
302304
hf_dataset = to_hf_dataset(data_dict, video)
303305
episode_data_index = calculate_episode_data_index(hf_dataset)

lerobot/scripts/push_dataset_to_hub.py

+8-18
Original file line numberDiff line numberDiff line change
@@ -200,24 +200,14 @@ def push_dataset_to_hub(
200200
# convert dataset from original raw format to LeRobot format
201201
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
202202

203-
fmt_kwgs = {
204-
"raw_dir": raw_dir,
205-
"videos_dir": videos_dir,
206-
"fps": fps,
207-
"video": video,
208-
"episodes": episodes,
209-
"encoding": encoding,
210-
}
211-
212-
if "openx_rlds." in raw_format:
213-
# Support for official OXE dataset name inside `raw_format`.
214-
# For instance, `raw_format="oxe_rlds"` uses the default formating (TODO what does that mean?),
215-
# and `raw_format="oxe_rlds.bridge_orig"` uses the brdige_orig formating
216-
_, openx_dataset_name = raw_format.split(".")
217-
print(f"Converting dataset [{openx_dataset_name}] from 'openx_rlds' to LeRobot format.")
218-
fmt_kwgs["openx_dataset_name"] = openx_dataset_name
219-
220-
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(**fmt_kwgs)
203+
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
204+
raw_dir,
205+
videos_dir,
206+
fps,
207+
video,
208+
episodes,
209+
encoding,
210+
)
221211

222212
lerobot_dataset = LeRobotDataset.from_preloaded(
223213
repo_id=repo_id,

0 commit comments

Comments
 (0)