diff --git a/pyproject.toml b/pyproject.toml index 210d70b6bc..c4b1c547e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,9 +360,9 @@ ignore_errors = false module = "lerobot.cameras.*" ignore_errors = false -# [[tool.mypy.overrides]] -# module = "lerobot.motors.*" -# ignore_errors = false +[[tool.mypy.overrides]] +module = "lerobot.motors.*" +ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.robots.*" diff --git a/src/lerobot/motors/calibration_gui.py b/src/lerobot/motors/calibration_gui.py index 02bba454f2..3410cb28ad 100644 --- a/src/lerobot/motors/calibration_gui.py +++ b/src/lerobot/motors/calibration_gui.py @@ -221,7 +221,7 @@ def __init__(self, bus: MotorsBus, groups: dict[str, list[str]] | None = None): self.bus = bus self.groups = groups if groups is not None else {"all": list(bus.motors)} - self.group_names = list(groups) + self.group_names = list(self.groups) self.current_group = self.group_names[0] if not bus.is_connected: @@ -230,18 +230,20 @@ def __init__(self, bus: MotorsBus, groups: dict[str, list[str]] | None = None): self.calibration = bus.read_calibration() self.res_table = bus.model_resolution_table self.present_cache = { - m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors + m: bus.read("Present_Position", m, normalize=False) + for motors in self.groups.values() + for m in motors } pygame.init() self.font = pygame.font.Font(None, FONT_SIZE) - label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms) + label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms) self.label_pad = label_pad width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10 self.controls_bottom = 10 + SAVE_H self.base_y = self.controls_bottom + TOP_GAP - height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40 + height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40 self.screen = pygame.display.set_mode((width, height)) pygame.display.set_caption("Motors range finder") diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index c6752ee96f..3e1500b70a 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -181,10 +181,10 @@ def read_calibration(self) -> dict[str, MotorCalibration]: for motor, m in self.motors.items(): calibration[motor] = MotorCalibration( id=m.id, - drive_mode=drive_modes[motor], - homing_offset=offsets[motor], - range_min=mins[motor], - range_max=maxes[motor], + drive_mode=int(drive_modes[motor]), + homing_offset=int(offsets[motor]), + range_min=int(mins[motor]), + range_max=int(maxes[motor]), ) return calibration @@ -198,7 +198,7 @@ def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache if cache: self.calibration = calibration_dict - def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) @@ -206,7 +206,7 @@ def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None: addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) @@ -230,12 +230,12 @@ def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, return ids_values - def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]: + def _get_half_turn_homings(self, positions: dict[str, Value]) -> dict[str, Value]: """ On Dynamixel Motors: Present_Position = Actual_Position + Homing_Offset """ - half_turn_homings = {} + half_turn_homings: dict[str, Value] = {} for motor, pos in positions.items(): model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 @@ -258,6 +258,6 @@ def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> di if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + return None return {id_: data[0] for id_, data in data_list.items()} diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 7ce3388b6b..42bf3ce544 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -126,7 +126,7 @@ def __init__( self.port_handler = scs.PortHandler(self.port) # HACK: monkeypatch - self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( + self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign] self.port_handler, scs.PortHandler ) self.packet_handler = scs.PacketHandler(protocol_version) @@ -262,9 +262,9 @@ def read_calibration(self) -> dict[str, MotorCalibration]: calibration[motor] = MotorCalibration( id=m.id, drive_mode=0, - homing_offset=offsets[motor], - range_min=mins[motor], - range_max=maxes[motor], + homing_offset=int(offsets[motor]), + range_min=int(mins[motor]), + range_max=int(maxes[motor]), ) return calibration @@ -279,12 +279,12 @@ def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache if cache: self.calibration = calibration_dict - def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]: + def _get_half_turn_homings(self, positions: dict[str, Value]) -> dict[str, Value]: """ On Feetech Motors: Present_Position = Actual_Position - Homing_Offset """ - half_turn_homings = {} + half_turn_homings: dict[str, Value] = {} for motor, pos in positions.items(): model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 @@ -292,7 +292,7 @@ def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameO return half_turn_homings - def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) self.write("Lock", motor, 0, num_retry=num_retry) @@ -303,7 +303,7 @@ def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None: addr, length = get_address(self.model_ctrl_table, model, "Lock") self._write(addr, length, motor, 0, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) self.write("Lock", motor, 1, num_retry=num_retry) @@ -334,7 +334,7 @@ def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: def _broadcast_ping(self) -> tuple[dict[int, int], int]: import scservo_sdk as scs - data_list = {} + data_list: dict[int, int] = {} status_length = 6 @@ -414,7 +414,7 @@ def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> di if not self._is_comm_success(comm): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + return None ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} if ids_errors: diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index c04f718b63..5d81122fd9 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -28,6 +28,7 @@ from enum import Enum from functools import cached_property from pprint import pformat +from collections.abc import Mapping, Sequence from typing import Protocol, TypeAlias import serial @@ -179,15 +180,16 @@ class Motor: class PortHandler(Protocol): - def __init__(self, port_name): - self.is_open: bool - self.baudrate: int - self.packet_start_time: float - self.packet_timeout: float - self.tx_time_per_byte: float - self.is_using: bool - self.port_name: str - self.ser: serial.Serial + is_open: bool + baudrate: int + packet_start_time: float + packet_timeout: float + tx_time_per_byte: float + is_using: bool + port_name: str + ser: serial.Serial + + def __init__(self, port_name: str) -> None: ... def openPort(self): ... def closePort(self): ... @@ -240,19 +242,20 @@ def regWriteTxOnly(self, port, id, address, length, data): ... def regWriteTxRx(self, port, id, address, length, data): ... def syncReadTx(self, port, start_address, data_length, param, param_length): ... def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ... + def broadcastPing(self, port): ... class GroupSyncRead(Protocol): - def __init__(self, port, ph, start_address, data_length): - self.port: str - self.ph: PortHandler - self.start_address: int - self.data_length: int - self.last_result: bool - self.is_param_changed: bool - self.param: list - self.data_dict: dict - + port: str + ph: PortHandler + start_address: int + data_length: int + last_result: bool + is_param_changed: bool + param: list + data_dict: dict + + def __init__(self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int) -> None: ... def makeParam(self): ... def addParam(self, id): ... def removeParam(self, id): ... @@ -265,15 +268,15 @@ def getData(self, id, address, data_length): ... class GroupSyncWrite(Protocol): - def __init__(self, port, ph, start_address, data_length): - self.port: str - self.ph: PortHandler - self.start_address: int - self.data_length: int - self.is_param_changed: bool - self.param: list - self.data_dict: dict - + port: str + ph: PortHandler + start_address: int + data_length: int + is_param_changed: bool + param: list + data_dict: dict + + def __init__(self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int) -> None: ... def makeParam(self): ... def addParam(self, id, data): ... def removeParam(self, id): ... @@ -400,7 +403,7 @@ def _get_motor_id(self, motor: NameOrID) -> int: else: raise TypeError(f"'{motor}' should be int, str.") - def _get_motor_model(self, motor: NameOrID) -> int: + def _get_motor_model(self, motor: NameOrID) -> str: if isinstance(motor, str): return self.motors[motor].model elif isinstance(motor, int): @@ -408,17 +411,19 @@ def _get_motor_model(self, motor: NameOrID) -> int: else: raise TypeError(f"'{motor}' should be int, str.") - def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]: if motors is None: return list(self.motors) elif isinstance(motors, str): return [motors] - elif isinstance(motors, list): - return motors.copy() + elif isinstance(motors, int): + return [self._id_to_name(motors)] + elif isinstance(motors, Sequence): + return [m if isinstance(m, str) else self._id_to_name(m) for m in motors] else: raise TypeError(motors) - def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: + def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]: if isinstance(values, (int | float)): return dict.fromkeys(self.ids, values) elif isinstance(values, dict): @@ -640,11 +645,12 @@ def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None: pass @abc.abstractmethod - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: """Enable torque on selected motors. Args: - motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`. + motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`. + Defaults to `None`. num_retry (int, optional): Number of additional retry attempts on communication failure. Defaults to 0. """ @@ -728,24 +734,19 @@ def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache """ pass - def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None: + def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None: """Restore factory calibration for the selected motors. Homing offset is set to ``0`` and min/max position limits are set to the full usable range. The in-memory :pyattr:`calibration` is cleared. Args: - motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default) + motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default) resets every motor. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - for motor in motors: + for motor in motor_names: model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 self.write("Homing_Offset", motor, 0, normalize=False) @@ -754,7 +755,7 @@ def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> self.calibration = {} - def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]: + def set_half_turn_homings(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> dict[str, Value]: """Centre each motor range around its current position. The function computes and writes a homing offset such that the present position becomes exactly one @@ -764,17 +765,12 @@ def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`). Returns: - dict[NameOrID, Value]: Mapping *motor → written homing offset*. + dict[str, Value]: Mapping *motor name → written homing offset*. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - self.reset_calibration(motors) - actual_positions = self.sync_read("Present_Position", motors, normalize=False) + self.reset_calibration(motor_names) + actual_positions = self.sync_read("Present_Position", motor_names, normalize=False) homing_offsets = self._get_half_turn_homings(actual_positions) for motor, offset in homing_offsets.items(): self.write("Homing_Offset", motor, offset) @@ -782,12 +778,12 @@ def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) return homing_offsets @abc.abstractmethod - def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]: + def _get_half_turn_homings(self, positions: dict[str, Value]) -> dict[str, Value]: pass def record_ranges_of_motion( - self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True - ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True + ) -> tuple[dict[str, Value], dict[str, Value]]: """Interactively record the min/max encoder values of each motor. Move the joints by hand (with torque disabled) while the method streams live positions. Press @@ -799,30 +795,25 @@ def record_ranges_of_motion( display_values (bool, optional): When `True` (default) a live table is printed to the console. Returns: - tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the + tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the extreme values observed for each motor. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - start_positions = self.sync_read("Present_Position", motors, normalize=False) + start_positions = self.sync_read("Present_Position", motor_names, normalize=False) mins = start_positions.copy() maxes = start_positions.copy() user_pressed_enter = False while not user_pressed_enter: - positions = self.sync_read("Present_Position", motors, normalize=False) + positions = self.sync_read("Present_Position", motor_names, normalize=False) mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()} maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()} if display_values: print("\n-------------------------------------------") print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}") - for motor in motors: + for motor in motor_names: print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}") if enter_pressed(): @@ -830,15 +821,15 @@ def record_ranges_of_motion( if display_values and not user_pressed_enter: # Move cursor up to overwrite the previous output - move_cursor_up(len(motors) + 3) + move_cursor_up(len(motor_names) + 3) - same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]] + same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]] if same_min_max: raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}") return mins, maxes - def _normalize(self, ids_values: dict[int, int]) -> dict[int, float]: + def _normalize(self, ids_values: Mapping[int, Value]) -> dict[int, float]: if not self.calibration: raise RuntimeError(f"{self} has no calibration registered.") @@ -955,12 +946,12 @@ def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) else: - return + return None if self._is_error(error): if raise_on_error: raise RuntimeError(self.packet_handler.getRxPacketError(error)) else: - return + return None return model_number @@ -1007,12 +998,13 @@ def read( err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) - id_value = self._decode_sign(data_name, {id_: value}) + decoded = self._decode_sign(data_name, {id_: value}) if normalize and data_name in self.normalized_data: - id_value = self._normalize(id_value) + normalized = self._normalize(decoded) + return normalized[id_] - return id_value[id_] + return decoded[id_] def _read( self, @@ -1023,7 +1015,7 @@ def _read( num_retry: int = 0, raise_on_error: bool = True, err_msg: str = "", - ) -> tuple[int, int]: + ) -> tuple[int, int, int]: if length == 1: read_fn = self.packet_handler.read1ByteTxRx elif length == 2: @@ -1073,13 +1065,16 @@ def write( model = self.motors[motor].model addr, length = get_address(self.model_ctrl_table, model, data_name) + int_value: int if normalize and data_name in self.normalized_data: - value = self._unnormalize({id_: value})[id_] + int_value = self._unnormalize({id_: value})[id_] + else: + int_value = int(value) - value = self._encode_sign(data_name, {id_: value})[id_] + int_value = self._encode_sign(data_name, {id_: int_value})[id_] - err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries." + self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) def _write( self, @@ -1113,7 +1108,7 @@ def _write( def sync_read( self, data_name: str, - motors: str | list[str] | None = None, + motors: NameOrID | Sequence[NameOrID] | None = None, *, normalize: bool = True, num_retry: int = 0, @@ -1122,7 +1117,7 @@ def sync_read( Args: data_name (str): Register name. - motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor. + motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor. normalize (bool, optional): Normalisation flag. Defaults to `True`. num_retry (int, optional): Retry attempts. Defaults to `0`. @@ -1143,16 +1138,17 @@ def sync_read( addr, length = get_address(self.model_ctrl_table, model, data_name) err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." - ids_values, _ = self._sync_read( + raw_ids_values, _ = self._sync_read( addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg ) - ids_values = self._decode_sign(data_name, ids_values) + decoded = self._decode_sign(data_name, raw_ids_values) if normalize and data_name in self.normalized_data: - ids_values = self._normalize(ids_values) + normalized = self._normalize(decoded) + return {self._id_to_name(id_): value for id_, value in normalized.items()} - return {self._id_to_name(id_): value for id_, value in ids_values.items()} + return {self._id_to_name(id_): value for id_, value in decoded.items()} def _sync_read( self, @@ -1224,21 +1220,24 @@ def sync_write( num_retry (int, optional): Retry attempts. Defaults to `0`. """ - ids_values = self._get_ids_values_dict(values) - models = [self._id_to_model(id_) for id_ in ids_values] + raw_ids_values = self._get_ids_values_dict(values) + models = [self._id_to_model(id_) for id_ in raw_ids_values] if self._has_different_ctrl_tables: assert_same_address(self.model_ctrl_table, models, data_name) model = next(iter(models)) addr, length = get_address(self.model_ctrl_table, model, data_name) + int_ids_values: dict[int, int] if normalize and data_name in self.normalized_data: - ids_values = self._unnormalize(ids_values) + int_ids_values = self._unnormalize(raw_ids_values) + else: + int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()} - ids_values = self._encode_sign(data_name, ids_values) + int_ids_values = self._encode_sign(data_name, int_ids_values) - err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." - self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries." + self._sync_write(addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) def _sync_write( self,