Skip to content

Commit f43e5d0

Browse files
committed
Fix tests
1 parent 9ee8711 commit f43e5d0

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

lerobot/common/datasets/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
136136

137137

138138
def load_info(local_dir: Path) -> dict:
139-
return load_json(local_dir / INFO_PATH)
139+
info = load_json(local_dir / INFO_PATH)
140+
for ft in info["features"].values():
141+
ft["shape"] = tuple(ft["shape"])
142+
return info
140143

141144

142145
def load_stats(local_dir: Path) -> dict:

lerobot/common/robot_devices/control_utils.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import cv2
1414
import torch
1515
import tqdm
16+
from deepdiff import DeepDiff
1617
from termcolor import colored
1718

1819
from lerobot.common.datasets.image_writer import safe_stop_image_writer
@@ -333,16 +334,19 @@ def sanity_check_dataset_name(repo_id, policy):
333334
)
334335

335336

336-
def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos):
337+
def sanity_check_dataset_robot_compatibility(
338+
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
339+
) -> None:
337340
fields = [
338-
("robot_type", dataset.meta.info["robot_type"], robot.robot_type),
339-
("fps", dataset.meta.info["fps"], fps),
341+
("robot_type", dataset.meta.robot_type, robot.robot_type),
342+
("fps", dataset.fps, fps),
340343
("features", dataset.features, get_features_from_robot(robot, use_videos)),
341344
]
342345

343346
mismatches = []
344347
for field, dataset_value, present_value in fields:
345-
if dataset_value != present_value:
348+
diff = DeepDiff(dataset_value, present_value)
349+
if diff:
346350
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
347351

348352
if mismatches:

0 commit comments

Comments
 (0)