Skip to content

Commit f300054

Browse files
committed
simplify the test logic, and improve comments
1 parent dd39610 commit f300054

File tree

3 files changed

+10
-15
lines changed

3 files changed

+10
-15
lines changed

tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from lerobot.common.utils.random_utils import seeded_context
2929

3030
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
31-
# TODO: Fix lerobot/aloha_mobile_shrimp dataset
3231
DATASET_REPO_ID = "lerobot/aloha_static_cups_open"
3332

3433

@@ -43,16 +42,16 @@ def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path
4342

4443

4544
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
46-
transforms = [
45+
transforms = {
4746
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
4847
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
4948
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
5049
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
5150
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
52-
]
51+
}
5352

5453
frames = {"original_frame": original_frame}
55-
for tf_type, tf_name, min_max_values in transforms:
54+
for tf_type, tf_name, min_max_values in transforms.items():
5655
for min_max in min_max_values:
5756
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
5857
tf = make_transform_from_config(tf_cfg)

tests/datasets/test_image_transforms.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,15 +254,13 @@ def test_backward_compatibility_single_transforms(
254254

255255

256256
@require_x86_64_kernel
257+
@pytest.mark.skipif(
258+
version.parse(torch.__version__) < version.parse("2.7.0"),
259+
reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior",
260+
)
257261
def test_backward_compatibility_default_config(img_tensor, default_transforms):
258-
# This test depends on the behavior of torch.multinomial, which changed in PyTorch 2.7.0.
259-
# The test artifacts (default_transforms.safetensors) were generated using PyTorch >= 2.7.0.
260-
# For more details, see:
261-
# - PyTorch issue: https://github.com/pytorch/pytorch/issues/154031
262-
# - LeRobot PR: https://github.com/huggingface/lerobot/pull/1127
263-
# If running with PyTorch < 2.7.0, this test is expected to fail.
264-
if version.parse(torch.__version__) < version.parse("2.7.0"):
265-
pytest.skip(f"Skipping test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
262+
# NOTE: PyTorch versions have different randomness, it might break this test.
263+
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
266264

267265
cfg = ImageTransformsConfig(enable=True)
268266
default_tf = ImageTransforms(cfg)

tests/policies/test_policies.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,9 +415,7 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
415415
https://github.com/huggingface/lerobot/pull/1127.
416416
417417
"""
418-
# Skip the test for act policy if PyTorch version is below 2.7
419-
# The MultiheadSelfAttention randomness changed between PyTorch 2.6 and 2.7
420-
# The artifacts were generated with PyTorch 2.7, so the test will break with act policy if PyTorch is below 2.7
418+
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
421419
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
422420
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
423421

0 commit comments

Comments
 (0)