|
13 | 13 | import cv2
|
14 | 14 | import torch
|
15 | 15 | import tqdm
|
| 16 | +from deepdiff import DeepDiff |
16 | 17 | from termcolor import colored
|
17 | 18 |
|
18 | 19 | from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
@@ -333,16 +334,19 @@ def sanity_check_dataset_name(repo_id, policy):
|
333 | 334 | )
|
334 | 335 |
|
335 | 336 |
|
336 |
| -def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos): |
| 337 | +def sanity_check_dataset_robot_compatibility( |
| 338 | + dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool |
| 339 | +) -> None: |
337 | 340 | fields = [
|
338 |
| - ("robot_type", dataset.meta.info["robot_type"], robot.robot_type), |
339 |
| - ("fps", dataset.meta.info["fps"], fps), |
| 341 | + ("robot_type", dataset.meta.robot_type, robot.robot_type), |
| 342 | + ("fps", dataset.fps, fps), |
340 | 343 | ("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
341 | 344 | ]
|
342 | 345 |
|
343 | 346 | mismatches = []
|
344 | 347 | for field, dataset_value, present_value in fields:
|
345 |
| - if dataset_value != present_value: |
| 348 | + diff = DeepDiff(dataset_value, present_value) |
| 349 | + if diff: |
346 | 350 | mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
347 | 351 |
|
348 | 352 | if mismatches:
|
|
0 commit comments