Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
!tests/data
!tests/artifacts
htmlcov/
.tox/
.nox/
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
!tests/data
!tests/artifacts
htmlcov/
.tox/
.nox/
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

exclude: ^(tests/data)
exclude: "tests/artifacts/.*\\.safetensors$"
default_language_version:
python: python3.10
repos:
Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ sudo apt-get install git-lfs
git lfs install
```

Pull artifacts if they're not in [tests/data](tests/data)
Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
```bash
git lfs pull
```
Expand Down
10 changes: 5 additions & 5 deletions lerobot/common/robot_devices/cameras/intelrealsense.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
connected to the computer.
"""
if mock:
import tests.mock_pyrealsense2 as rs
import tests.cameras.mock_pyrealsense2 as rs
else:
import pyrealsense2 as rs

Expand Down Expand Up @@ -100,7 +100,7 @@ def save_images_from_cameras(
serial_numbers = [cam["serial_number"] for cam in camera_infos]

if mock:
import tests.mock_cv2 as cv2
import tests.cameras.mock_cv2 as cv2
else:
import cv2

Expand Down Expand Up @@ -253,7 +253,7 @@ def __init__(
self.logs = {}

if self.mock:
import tests.mock_cv2 as cv2
import tests.cameras.mock_cv2 as cv2
else:
import cv2

Expand Down Expand Up @@ -287,7 +287,7 @@ def connect(self):
)

if self.mock:
import tests.mock_pyrealsense2 as rs
import tests.cameras.mock_pyrealsense2 as rs
else:
import pyrealsense2 as rs

Expand Down Expand Up @@ -375,7 +375,7 @@ def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndar
)

if self.mock:
import tests.mock_cv2 as cv2
import tests.cameras.mock_cv2 as cv2
else:
import cv2

Expand Down
8 changes: 4 additions & 4 deletions lerobot/common/robot_devices/cameras/opencv.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _find_cameras(
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
) -> list[int | str]:
if mock:
import tests.mock_cv2 as cv2
import tests.cameras.mock_cv2 as cv2
else:
import cv2

Expand Down Expand Up @@ -269,7 +269,7 @@ def __init__(self, config: OpenCVCameraConfig):
self.logs = {}

if self.mock:
import tests.mock_cv2 as cv2
import tests.cameras.mock_cv2 as cv2
else:
import cv2

Expand All @@ -286,7 +286,7 @@ def connect(self):
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")

if self.mock:
import tests.mock_cv2 as cv2
import tests.cameras.mock_cv2 as cv2
else:
import cv2

Expand Down Expand Up @@ -398,7 +398,7 @@ def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
# so we convert the image color from BGR to RGB.
if requested_color_mode == "rgb":
if self.mock:
import tests.mock_cv2 as cv2
import tests.cameras.mock_cv2 as cv2
else:
import cv2

Expand Down
12 changes: 6 additions & 6 deletions lerobot/common/robot_devices/motors/dynamixel.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def connect(self):
)

if self.mock:
import tests.mock_dynamixel_sdk as dxl
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl

Expand All @@ -356,7 +356,7 @@ def connect(self):

def reconnect(self):
if self.mock:
import tests.mock_dynamixel_sdk as dxl
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl

Expand Down Expand Up @@ -646,7 +646,7 @@ def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] |

def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.mock_dynamixel_sdk as dxl
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl

Expand Down Expand Up @@ -691,7 +691,7 @@ def read(self, data_name, motor_names: str | list[str] | None = None):
start_time = time.perf_counter()

if self.mock:
import tests.mock_dynamixel_sdk as dxl
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl

Expand Down Expand Up @@ -757,7 +757,7 @@ def read(self, data_name, motor_names: str | list[str] | None = None):

def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.mock_dynamixel_sdk as dxl
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl

Expand Down Expand Up @@ -793,7 +793,7 @@ def write(self, data_name, values: int | float | np.ndarray, motor_names: str |
start_time = time.perf_counter()

if self.mock:
import tests.mock_dynamixel_sdk as dxl
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl

Expand Down
12 changes: 6 additions & 6 deletions lerobot/common/robot_devices/motors/feetech.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def connect(self):
)

if self.mock:
import tests.mock_scservo_sdk as scs
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs

Expand All @@ -337,7 +337,7 @@ def connect(self):

def reconnect(self):
if self.mock:
import tests.mock_scservo_sdk as scs
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs

Expand Down Expand Up @@ -664,7 +664,7 @@ def avoid_rotation_reset(self, values, motor_names, data_name):

def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.mock_scservo_sdk as scs
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs

Expand Down Expand Up @@ -702,7 +702,7 @@ def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_

def read(self, data_name, motor_names: str | list[str] | None = None):
if self.mock:
import tests.mock_scservo_sdk as scs
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs

Expand Down Expand Up @@ -782,7 +782,7 @@ def read(self, data_name, motor_names: str | list[str] | None = None):

def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.mock_scservo_sdk as scs
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs

Expand Down Expand Up @@ -818,7 +818,7 @@ def write(self, data_name, values: int | float | np.ndarray, motor_names: str |
start_time = time.perf_counter()

if self.mock:
import tests.mock_scservo_sdk as scs
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs

Expand Down
25 changes: 1 addition & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,30 +102,7 @@ requires-poetry = ">=2.1"
[tool.ruff]
line-length = 110
target-version = "py310"
exclude = [
"tests/data",
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]
exclude = ["tests/artifacts/**/*.safetensors"]

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.

Example usage:
`python tests/scripts/save_dataset_to_safetensors.py`
`python tests/artifacts/datasets/save_dataset_to_safetensors.py`
"""

import shutil
Expand Down Expand Up @@ -88,4 +88,4 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
"lerobot/nyu_franka_play_dataset",
"lerobot/cmu_stretch",
]:
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
save_dataset_to_safetensors("tests/artifacts/datasets", repo_id=dataset)
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from lerobot.common.utils.random_utils import seeded_context

ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,5 +141,5 @@ def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: s
raise RuntimeError("No policies were provided!")
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
ds_name = ds_repo_id.split("/")[-1]
output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}"
output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/test_cameras.py → tests/cameras/test_cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_camera(request, camera_type, mock):
camera.connect()

if mock:
import tests.mock_cv2 as cv2
import tests.cameras.mock_cv2 as cv2
else:
import cv2

Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions tests/test_datasets.py → tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,12 +473,12 @@ def test_flatten_unflatten_dict():
)
@require_x86_64_kernel
def test_backward_compatibility(repo_id):
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
"""The artifacts for this test have been generated by `tests/artifacts/datasets/save_dataset_to_safetensors.py`."""

# TODO(rcadene, aliberts): remove dataset download
dataset = LeRobotDataset(repo_id, episodes=[0])

test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
test_dir = Path("tests/artifacts/datasets") / repo_id

def load_and_compare(i):
new_frame = dataset[i] # noqa: B023
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
save_all_transforms,
save_each_transform,
)
from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.utils import require_x86_64_kernel


Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
27 changes: 24 additions & 3 deletions tests/test_utils.py → tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -11,13 +13,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from datasets import Dataset
from huggingface_hub import DatasetCard

from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch


def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]


def test_with_tags():
tags = ["tag1", "tag2"]
card = create_lerobot_dataset_card(tags=tags)
assert card.data.tags == ["LeRobot", "tag1", "tag2"]


def test_calculate_episode_data_index():
Expand Down
3 changes: 1 addition & 2 deletions tests/test_envs.py → tests/envs/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
import lerobot
from lerobot.common.envs.factory import make_env, make_env_config
from lerobot.common.envs.utils import preprocess_observation

from .utils import require_env
from tests.utils import require_env

OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]

Expand Down
File renamed without changes.
Loading
Loading