Skip to content

Commit 291bd89

Browse files
fix(robots): bootstrap dataset recording again and improve watchdog behavior
1 parent 5e28321 commit 291bd89

File tree

5 files changed

+132
-80
lines changed

5 files changed

+132
-80
lines changed

examples/robots/lekiwi_client_app.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
# limitations under the License.
1414

1515
import logging
16+
import time
1617

1718
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
1819
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
19-
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
20+
from lerobot.common.robots.lekiwi.lekiwi_client import OBS_STATE, LeKiwiClient
2021
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
2122
from lerobot.common.teleoperators.so100 import SO100Leader, SO100LeaderConfig
2223

24+
NB_CYCLES_CLIENT_CONNECTION = 250
25+
2326

2427
def main():
2528
logging.info("Configuring Teleop Devices")
@@ -35,10 +38,20 @@ def main():
3538

3639
logging.info("Creating LeRobot Dataset")
3740

41+
# The observations that we get are expected to be in body frame (x,y,theta)
42+
obs_dict = {f"{OBS_STATE}." + key: value for key, value in robot.state_feature_client.items()}
43+
# The actions that we send are expected to be in wheel frame (motor encoders)
44+
act_dict = {"action." + key: value for key, value in robot.action_feature.items()}
45+
46+
features_dict = {
47+
**act_dict,
48+
**obs_dict,
49+
**robot.camera_features,
50+
}
3851
dataset = LeRobotDataset.create(
39-
repo_id="user/lekiwi",
52+
repo_id="user/lekiwi" + str(int(time.time())),
4053
fps=10,
41-
features={**robot.state_feature, **robot.camera_features},
54+
features=features_dict,
4255
)
4356

4457
logging.info("Connecting Teleop Devices")
@@ -54,7 +67,7 @@ def main():
5467

5568
logging.info("Starting LeKiwi teleoperation")
5669
i = 0
57-
while i < 1000:
70+
while i < NB_CYCLES_CLIENT_CONNECTION:
5871
arm_action = leader_arm.get_action()
5972
base_action = keyboard.get_action()
6073
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
@@ -63,19 +76,21 @@ def main():
6376
observation = robot.get_observation()
6477

6578
frame = {**action_sent, **observation}
66-
frame.update({"task": "Dummy Task Dataset"})
79+
frame.update({"task": "Dummy Example Task Dataset"})
6780

6881
logging.info("Saved a frame into the dataset")
6982
dataset.add_frame(frame)
7083
i += 1
7184

72-
dataset.save_episode()
73-
dataset.push_to_hub()
74-
7585
logging.info("Disconnecting Teleop Devices and LeKiwi Client")
7686
robot.disconnect()
7787
leader_arm.disconnect()
7888
keyboard.disconnect()
89+
90+
logging.info("Uploading dataset to the hub")
91+
dataset.save_episode()
92+
dataset.push_to_hub()
93+
7994
logging.info("Finished LeKiwi cleanly")
8095

8196

lerobot/common/robots/lekiwi/config_lekiwi.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ class LeKiwiConfig(RobotConfig):
3535
cameras: dict[str, CameraConfig] = field(
3636
default_factory=lambda: {
3737
"front": OpenCVCameraConfig(
38-
camera_index="/dev/video1", fps=30, width=640, height=480, rotation=90
38+
camera_index="/dev/video0", fps=30, width=640, height=480, rotation=None
3939
),
4040
"wrist": OpenCVCameraConfig(
41-
camera_index="/dev/video4", fps=30, width=640, height=480, rotation=180
41+
camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
4242
),
4343
}
4444
)
@@ -51,10 +51,10 @@ class LeKiwiHostConfig:
5151
port_zmq_observations: int = 5556
5252

5353
# Duration of the application
54-
connection_time_s: int = 100
54+
connection_time_s: int = 30
5555

5656
# Watchdog: stop the robot if no command is received for over 0.5 seconds.
57-
watchdog_timeout_s: int = 1
57+
watchdog_timeout_ms: int = 500
5858

5959
# If robot jitters decrease the frequency and monitor cpu load with `top` in cmd
6060
max_loop_freq_hz: int = 30

lerobot/common/robots/lekiwi/lekiwi.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Any
2020

2121
from lerobot.common.cameras.utils import make_cameras_from_configs
22-
from lerobot.common.constants import OBS_IMAGES
22+
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
2323
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
2424
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
2525
from lerobot.common.motors.feetech import (
@@ -71,11 +71,18 @@ def __init__(self, config: LeKiwiConfig):
7171

7272
@property
7373
def state_feature(self) -> dict:
74-
return {
75-
"dtype": "float32",
76-
"shape": (len(self.bus),),
77-
"names": {"motors": list(self.bus.motors)},
74+
state_ft = {
75+
"arm_shoulder_pan": {"dtype": "float32"},
76+
"arm_shoulder_lift": {"dtype": "float32"},
77+
"arm_elbow_flex": {"dtype": "float32"},
78+
"arm_wrist_flex": {"dtype": "float32"},
79+
"arm_wrist_roll": {"dtype": "float32"},
80+
"arm_gripper": {"dtype": "float32"},
81+
"base_left_wheel": {"dtype": "float32"},
82+
"base_right_wheel": {"dtype": "float32"},
83+
"base_back_wheel": {"dtype": "float32"},
7884
}
85+
return state_ft
7986

8087
@property
8188
def action_feature(self) -> dict:
@@ -187,6 +194,7 @@ def get_observation(self) -> dict[str, Any]:
187194
arm_pos = self.bus.sync_read("Present_Position", self.arm_motors)
188195
base_vel = self.bus.sync_read("Present_Velocity", self.base_motors)
189196
obs_dict = {**arm_pos, **base_vel}
197+
obs_dict = {f"{OBS_STATE}." + key: value for key, value in obs_dict.items()}
190198
dt_ms = (time.perf_counter() - start) * 1e3
191199
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
192200

lerobot/common/robots/lekiwi/lekiwi_client.py

Lines changed: 70 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
import base64
1616
import json
1717
import logging
18-
from typing import Any
18+
from typing import Any, Dict, Optional, Tuple
1919

2020
import cv2
2121
import numpy as np
2222
import torch
2323
import zmq
24-
from typing import Tuple, Dict, Any, Optional
2524

2625
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
2726
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
@@ -46,7 +45,7 @@ def __init__(self, config: LeKiwiClientConfig):
4645

4746
self.teleop_keys = config.teleop_keys
4847

49-
self.polling_timeot_ms = config.polling_timeout_ms
48+
self.polling_timeout_ms = config.polling_timeout_ms
5049
self.connect_timeout_s = config.connect_timeout_s
5150

5251
self.zmq_context = None
@@ -70,28 +69,46 @@ def __init__(self, config: LeKiwiClientConfig):
7069
self.logs = {}
7170

7271
@property
73-
def state_feature(self) -> dict:
74-
return {
75-
"dtype": "float32",
76-
"shape": (9,),
77-
"names": {
78-
"motors": [
79-
"arm_shoulder_pan",
80-
"arm_shoulder_lift",
81-
"arm_elbow_flex",
82-
"arm_wrist_flex",
83-
"arm_wrist_roll",
84-
"arm_gripper",
85-
"base_left_wheel",
86-
"base_right_wheel",
87-
"base_back_wheel",
88-
]
89-
},
72+
def state_feature_client(self) -> dict:
73+
state_ft = {
74+
"arm_shoulder_pan": {"shape": (1,), "info": None, "dtype": "float32"},
75+
"arm_shoulder_lift": {"shape": (1,), "info": None, "dtype": "float32"},
76+
"arm_elbow_flex": {"shape": (1,), "info": None, "dtype": "float32"},
77+
"arm_wrist_flex": {"shape": (1,), "info": None, "dtype": "float32"},
78+
"arm_wrist_roll": {"shape": (1,), "info": None, "dtype": "float32"},
79+
"arm_gripper": {"shape": (1,), "info": None, "dtype": "float32"},
80+
"x_cmd": {"shape": (1,), "info": None, "dtype": "float32"},
81+
"y_cmd": {"shape": (1,), "info": None, "dtype": "float32"},
82+
"theta_cmd": {"shape": (1,), "info": None, "dtype": "float32"},
83+
}
84+
return state_ft
85+
86+
@property
87+
def state_feature_host(self) -> dict:
88+
state_ft = {
89+
"arm_shoulder_pan": {"shape": (1,), "info": None, "dtype": "float32"},
90+
"arm_shoulder_lift": {"shape": (1,), "info": None, "dtype": "float32"},
91+
"arm_elbow_flex": {"shape": (1,), "info": None, "dtype": "float32"},
92+
"arm_wrist_flex": {"shape": (1,), "info": None, "dtype": "float32"},
93+
"arm_wrist_roll": {"shape": (1,), "info": None, "dtype": "float32"},
94+
"arm_gripper": {"shape": (1,), "info": None, "dtype": "float32"},
95+
"base_left_wheel": {"shape": (1,), "info": None, "dtype": "float32"},
96+
"base_right_wheel": {"shape": (1,), "info": None, "dtype": "float32"},
97+
"base_back_wheel": {"shape": (1,), "info": None, "dtype": "float32"},
9098
}
99+
return state_ft
100+
101+
@property
102+
def state_feature(self) -> dict:
103+
raise (
104+
NotImplementedError(
105+
"state_feature is not implemented for LeKiwiClient. Use state_feature_client or state_feature_host instead."
106+
)
107+
)
91108

92109
@property
93110
def action_feature(self) -> dict:
94-
return self.state_feature
111+
return self.state_feature_host
95112

96113
@property
97114
def camera_features(self) -> dict[str, dict]:
@@ -100,10 +117,12 @@ def camera_features(self) -> dict[str, dict]:
100117
"shape": (480, 640, 3),
101118
"names": ["height", "width", "channels"],
102119
"info": None,
120+
"dtype": "image",
103121
},
104122
f"{OBS_IMAGES}.wrist": {
105123
"shape": (480, 640, 3),
106124
"names": ["height", "width", "channels"],
125+
"dtype": "image",
107126
"info": None,
108127
},
109128
}
@@ -261,16 +280,15 @@ def _wheel_raw_to_body(
261280
velocity_vector = m_inv.dot(wheel_linear_speeds)
262281
x_cmd, y_cmd, theta_rad = velocity_vector
263282
theta_cmd = theta_rad * (180.0 / np.pi)
264-
return {"x_cmd": x_cmd, "y_cmd": y_cmd, "theta_cmd": theta_cmd}
265-
283+
return {f"{OBS_STATE}.x_cmd": x_cmd, f"{OBS_STATE}.y_cmd": y_cmd, f"{OBS_STATE}.theta_cmd": theta_cmd}
266284

267285
def _poll_and_get_latest_message(self) -> Optional[str]:
268286
"""Polls the ZMQ socket for a limited time and returns the latest message string."""
269287
poller = zmq.Poller()
270288
poller.register(self.zmq_observation_socket, zmq.POLLIN)
271-
289+
272290
try:
273-
socks = dict(poller.poll(self.polling_timeot_ms))
291+
socks = dict(poller.poll(self.polling_timeout_ms))
274292
except zmq.ZMQError as e:
275293
logging.error(f"ZMQ polling error: {e}")
276294
return None
@@ -291,7 +309,7 @@ def _poll_and_get_latest_message(self) -> Optional[str]:
291309
logging.warning("Poller indicated data, but failed to retrieve message.")
292310

293311
return last_msg
294-
312+
295313
def _parse_observation_json(self, obs_string: str) -> Optional[Dict[str, Any]]:
296314
"""Parses the JSON observation string."""
297315
try:
@@ -300,7 +318,6 @@ def _parse_observation_json(self, obs_string: str) -> Optional[Dict[str, Any]]:
300318
logging.error(f"Error decoding JSON observation: {e}")
301319
return None
302320

303-
304321
def _decode_image_from_b64(self, image_b64: str) -> Optional[np.ndarray]:
305322
"""Decodes a base64 encoded image string to an OpenCV image."""
306323
if not image_b64:
@@ -315,10 +332,12 @@ def _decode_image_from_b64(self, image_b64: str) -> Optional[np.ndarray]:
315332
except (TypeError, ValueError) as e:
316333
logging.error(f"Error decoding base64 image data: {e}")
317334
return None
318-
319-
def _process_observation_data(self, observation: Dict[str, Any]) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, Any]]:
335+
336+
def _process_observation_data(
337+
self, observation: Dict[str, Any]
338+
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, Any]]:
320339
"""Extracts frames, speed, and arm state from the parsed observation."""
321-
340+
322341
# Separate image and state data
323342
image_observation = {k: v for k, v in observation.items() if k.startswith(OBS_IMAGES)}
324343
state_observation = {k: v for k, v in observation.items() if k.startswith(OBS_STATE)}
@@ -331,16 +350,11 @@ def _process_observation_data(self, observation: Dict[str, Any]) -> Tuple[Dict[s
331350
current_frames[cam_name] = frame
332351

333352
# Extract state components
334-
current_speed = {
335-
k: v for k, v in state_observation.items() if k.startswith(f"{OBS_STATE}.base")
336-
}
337-
current_arm_state = {
338-
k: v for k, v in state_observation.items() if k.startswith(f"{OBS_STATE}.arm")
339-
}
353+
current_speed = {k: v for k, v in state_observation.items() if k.startswith(f"{OBS_STATE}.base")}
354+
current_arm_state = {k: v for k, v in state_observation.items() if k.startswith(f"{OBS_STATE}.arm")}
340355

341356
return current_frames, current_speed, current_arm_state
342357

343-
344358
def _get_data(self) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, Any]]:
345359
"""
346360
Polls the video socket for the latest observation data.
@@ -349,7 +363,7 @@ def _get_data(self) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, An
349363
If successful, updates and returns the new frames, speed, and arm state.
350364
If no new data arrives or decoding fails, returns the last known values.
351365
"""
352-
366+
353367
# 1. Get the latest message string from the socket
354368
latest_message_str = self._poll_and_get_latest_message()
355369

@@ -386,27 +400,23 @@ def get_observation(self) -> dict[str, Any]:
386400
if not self._is_connected:
387401
raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
388402

389-
# TODO(Steven): remove hard-coded cam names & dims
390-
# This is needed at init for when there's no comms
391-
obs_dict = {
392-
f"{OBS_IMAGES}.wrist": np.zeros(shape=(480, 640, 3)),
393-
f"{OBS_IMAGES}.front": np.zeros(shape=(640, 480, 3)),
394-
}
395-
396403
frames, present_speed, remote_arm_state_tensor = self._get_data()
397404
body_state = self._wheel_raw_to_body(present_speed)
398405
body_state_mm = {k: v * 1000.0 for k, v in body_state.items()} # Convert x,y to mm/s
399406

407+
obs_dict = {}
400408
obs_dict.update(remote_arm_state_tensor)
401409
obs_dict.update(body_state_mm)
402410

411+
# TODO(Steven): Remove this when it is possible to record a non-numpy array value
412+
obs_dict = {k: np.array([v], dtype=np.float32) for k, v in obs_dict.items()}
413+
403414
# Loop over each configured camera
404415
for cam_name, frame in frames.items():
405416
if frame is None:
406-
# TODO(Steven): Daemon doesn't know camera dimensions (hard-coded for now), consider at least getting them from state features
407417
logging.warning("Frame is None")
408-
frame = np.zeros((480, 640, 3), dtype=np.uint8)
409-
obs_dict[f"{OBS_IMAGES}.{cam_name}"] = torch.from_numpy(frame)
418+
frame = np.zeros((640, 480, 3), dtype=np.uint8)
419+
obs_dict[cam_name] = torch.from_numpy(frame)
410420

411421
return obs_dict
412422

@@ -459,22 +469,26 @@ def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
459469
)
460470

461471
goal_pos = {}
462-
motors_name = self.state_feature.get("names").get("motors")
463472

464-
common_keys = [key for key in action if key in (motor.replace("arm_", "") for motor in motors_name)]
473+
common_keys = [
474+
key
475+
for key in action
476+
if key in (motor.replace("arm_", "") for motor, _ in self.state_feature_host.items())
477+
]
465478

466479
arm_actions = {"arm_" + arm_motor: action[arm_motor] for arm_motor in common_keys}
467480
goal_pos = arm_actions
468481

469-
if len(action) > 6:
470-
keyboard_keys = np.array(list(set(action.keys()) - set(common_keys)))
471-
wheel_actions = {
472-
"base_" + k: v for k, v in self._from_keyboard_to_wheel_action(keyboard_keys).items()
473-
}
474-
goal_pos = {**arm_actions, **wheel_actions}
482+
keyboard_keys = np.array(list(set(action.keys()) - set(common_keys)))
483+
wheel_actions = {
484+
"base_" + k: v for k, v in self._from_keyboard_to_wheel_action(keyboard_keys).items()
485+
}
486+
goal_pos = {**arm_actions, **wheel_actions}
475487

476488
self.zmq_cmd_socket.send_string(json.dumps(goal_pos)) # action is in motor space
477489

490+
# TODO(Steven): Remove the np conversion when it is possible to record a non-numpy array value
491+
goal_pos = {"action." + k: np.array([v], dtype=np.float32) for k, v in goal_pos.items()}
478492
return goal_pos
479493

480494
def disconnect(self):

0 commit comments

Comments
 (0)