1515import base64
1616import json
1717import logging
18- from typing import Any
18+ from typing import Any , Dict , Optional , Tuple
1919
2020import cv2
2121import numpy as np
2222import torch
2323import zmq
24- from typing import Tuple , Dict , Any , Optional
2524
2625from lerobot .common .constants import OBS_IMAGES , OBS_STATE
2726from lerobot .common .errors import DeviceAlreadyConnectedError , DeviceNotConnectedError
@@ -46,7 +45,7 @@ def __init__(self, config: LeKiwiClientConfig):
4645
4746 self .teleop_keys = config .teleop_keys
4847
49- self .polling_timeot_ms = config .polling_timeout_ms
48+ self .polling_timeout_ms = config .polling_timeout_ms
5049 self .connect_timeout_s = config .connect_timeout_s
5150
5251 self .zmq_context = None
@@ -70,28 +69,46 @@ def __init__(self, config: LeKiwiClientConfig):
7069 self .logs = {}
7170
7271 @property
73- def state_feature (self ) -> dict :
74- return {
75- "dtype" : "float32" ,
76- "shape" : (9 ,),
77- "names" : {
78- "motors" : [
79- "arm_shoulder_pan" ,
80- "arm_shoulder_lift" ,
81- "arm_elbow_flex" ,
82- "arm_wrist_flex" ,
83- "arm_wrist_roll" ,
84- "arm_gripper" ,
85- "base_left_wheel" ,
86- "base_right_wheel" ,
87- "base_back_wheel" ,
88- ]
89- },
def state_feature_client(self) -> dict:
    """Describe the state entries exposed on the client side.

    Returns:
        A new dict mapping each entry name to its feature spec. The six
        ``arm_*`` entries come first, followed by ``x_cmd``, ``y_cmd`` and
        ``theta_cmd``. Every spec declares shape ``(1,)``, ``info`` of
        ``None`` and dtype ``"float32"``.
    """
    entry_names = (
        "arm_shoulder_pan",
        "arm_shoulder_lift",
        "arm_elbow_flex",
        "arm_wrist_flex",
        "arm_wrist_roll",
        "arm_gripper",
        "x_cmd",
        "y_cmd",
        "theta_cmd",
    )
    # Build a separate spec dict per entry so mutating one never affects another.
    return {name: {"shape": (1,), "info": None, "dtype": "float32"} for name in entry_names}
85+
@property
def state_feature_host(self) -> dict:
    """Describe the state entries exposed on the host side.

    Returns:
        A new dict mapping each entry name to its feature spec. The six
        ``arm_*`` entries come first, followed by the three ``base_*_wheel``
        entries. Every spec declares shape ``(1,)``, ``info`` of ``None``
        and dtype ``"float32"``.
    """
    entry_names = (
        "arm_shoulder_pan",
        "arm_shoulder_lift",
        "arm_elbow_flex",
        "arm_wrist_flex",
        "arm_wrist_roll",
        "arm_gripper",
        "base_left_wheel",
        "base_right_wheel",
        "base_back_wheel",
    )
    # Build a separate spec dict per entry so mutating one never affects another.
    return {name: {"shape": (1,), "info": None, "dtype": "float32"} for name in entry_names}
100+
@property
def state_feature(self) -> dict:
    """Unsupported on this client; every access raises.

    Raises:
        NotImplementedError: always. The message directs callers to
            ``state_feature_client`` or ``state_feature_host``.
    """
    message = "state_feature is not implemented for LeKiwiClient. Use state_feature_client or state_feature_host instead."
    raise NotImplementedError(message)
91108
@property
def action_feature(self) -> dict:
    """Return the action feature spec, which is whatever ``state_feature_host`` yields."""
    host_features = self.state_feature_host
    return host_features
95112
96113 @property
97114 def camera_features (self ) -> dict [str , dict ]:
@@ -100,10 +117,12 @@ def camera_features(self) -> dict[str, dict]:
100117 "shape" : (480 , 640 , 3 ),
101118 "names" : ["height" , "width" , "channels" ],
102119 "info" : None ,
120+ "dtype" : "image" ,
103121 },
104122 f"{ OBS_IMAGES } .wrist" : {
105123 "shape" : (480 , 640 , 3 ),
106124 "names" : ["height" , "width" , "channels" ],
125+ "dtype" : "image" ,
107126 "info" : None ,
108127 },
109128 }
@@ -261,16 +280,15 @@ def _wheel_raw_to_body(
261280 velocity_vector = m_inv .dot (wheel_linear_speeds )
262281 x_cmd , y_cmd , theta_rad = velocity_vector
263282 theta_cmd = theta_rad * (180.0 / np .pi )
264- return {"x_cmd" : x_cmd , "y_cmd" : y_cmd , "theta_cmd" : theta_cmd }
265-
283+ return {f"{ OBS_STATE } .x_cmd" : x_cmd , f"{ OBS_STATE } .y_cmd" : y_cmd , f"{ OBS_STATE } .theta_cmd" : theta_cmd }
266284
267285 def _poll_and_get_latest_message (self ) -> Optional [str ]:
268286 """Polls the ZMQ socket for a limited time and returns the latest message string."""
269287 poller = zmq .Poller ()
270288 poller .register (self .zmq_observation_socket , zmq .POLLIN )
271-
289+
272290 try :
273- socks = dict (poller .poll (self .polling_timeot_ms ))
291+ socks = dict (poller .poll (self .polling_timeout_ms ))
274292 except zmq .ZMQError as e :
275293 logging .error (f"ZMQ polling error: { e } " )
276294 return None
@@ -291,7 +309,7 @@ def _poll_and_get_latest_message(self) -> Optional[str]:
291309 logging .warning ("Poller indicated data, but failed to retrieve message." )
292310
293311 return last_msg
294-
312+
295313 def _parse_observation_json (self , obs_string : str ) -> Optional [Dict [str , Any ]]:
296314 """Parses the JSON observation string."""
297315 try :
@@ -300,7 +318,6 @@ def _parse_observation_json(self, obs_string: str) -> Optional[Dict[str, Any]]:
300318 logging .error (f"Error decoding JSON observation: { e } " )
301319 return None
302320
303-
304321 def _decode_image_from_b64 (self , image_b64 : str ) -> Optional [np .ndarray ]:
305322 """Decodes a base64 encoded image string to an OpenCV image."""
306323 if not image_b64 :
@@ -315,10 +332,12 @@ def _decode_image_from_b64(self, image_b64: str) -> Optional[np.ndarray]:
315332 except (TypeError , ValueError ) as e :
316333 logging .error (f"Error decoding base64 image data: { e } " )
317334 return None
318-
319- def _process_observation_data (self , observation : Dict [str , Any ]) -> Tuple [Dict [str , np .ndarray ], Dict [str , Any ], Dict [str , Any ]]:
335+
336+ def _process_observation_data (
337+ self , observation : Dict [str , Any ]
338+ ) -> Tuple [Dict [str , np .ndarray ], Dict [str , Any ], Dict [str , Any ]]:
320339 """Extracts frames, speed, and arm state from the parsed observation."""
321-
340+
322341 # Separate image and state data
323342 image_observation = {k : v for k , v in observation .items () if k .startswith (OBS_IMAGES )}
324343 state_observation = {k : v for k , v in observation .items () if k .startswith (OBS_STATE )}
@@ -331,16 +350,11 @@ def _process_observation_data(self, observation: Dict[str, Any]) -> Tuple[Dict[s
331350 current_frames [cam_name ] = frame
332351
333352 # Extract state components
334- current_speed = {
335- k : v for k , v in state_observation .items () if k .startswith (f"{ OBS_STATE } .base" )
336- }
337- current_arm_state = {
338- k : v for k , v in state_observation .items () if k .startswith (f"{ OBS_STATE } .arm" )
339- }
353+ current_speed = {k : v for k , v in state_observation .items () if k .startswith (f"{ OBS_STATE } .base" )}
354+ current_arm_state = {k : v for k , v in state_observation .items () if k .startswith (f"{ OBS_STATE } .arm" )}
340355
341356 return current_frames , current_speed , current_arm_state
342357
343-
344358 def _get_data (self ) -> Tuple [Dict [str , np .ndarray ], Dict [str , Any ], Dict [str , Any ]]:
345359 """
346360 Polls the video socket for the latest observation data.
@@ -349,7 +363,7 @@ def _get_data(self) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, An
349363 If successful, updates and returns the new frames, speed, and arm state.
350364 If no new data arrives or decoding fails, returns the last known values.
351365 """
352-
366+
353367 # 1. Get the latest message string from the socket
354368 latest_message_str = self ._poll_and_get_latest_message ()
355369
@@ -386,27 +400,23 @@ def get_observation(self) -> dict[str, Any]:
386400 if not self ._is_connected :
387401 raise DeviceNotConnectedError ("LeKiwiClient is not connected. You need to run `robot.connect()`." )
388402
389- # TODO(Steven): remove hard-coded cam names & dims
390- # This is needed at init for when there's no comms
391- obs_dict = {
392- f"{ OBS_IMAGES } .wrist" : np .zeros (shape = (480 , 640 , 3 )),
393- f"{ OBS_IMAGES } .front" : np .zeros (shape = (640 , 480 , 3 )),
394- }
395-
396403 frames , present_speed , remote_arm_state_tensor = self ._get_data ()
397404 body_state = self ._wheel_raw_to_body (present_speed )
398405 body_state_mm = {k : v * 1000.0 for k , v in body_state .items ()} # Convert x,y to mm/s
399406
407+ obs_dict = {}
400408 obs_dict .update (remote_arm_state_tensor )
401409 obs_dict .update (body_state_mm )
402410
411+ # TODO(Steven): Remove this when it is possible to record a non-numpy array value
412+ obs_dict = {k : np .array ([v ], dtype = np .float32 ) for k , v in obs_dict .items ()}
413+
403414 # Loop over each configured camera
404415 for cam_name , frame in frames .items ():
405416 if frame is None :
406- # TODO(Steven): Daemon doesn't know camera dimensions (hard-coded for now), consider at least getting them from state features
407417 logging .warning ("Frame is None" )
408- frame = np .zeros ((480 , 640 , 3 ), dtype = np .uint8 )
409- obs_dict [f" { OBS_IMAGES } . { cam_name } " ] = torch .from_numpy (frame )
418+ frame = np .zeros ((640 , 480 , 3 ), dtype = np .uint8 )
419+ obs_dict [cam_name ] = torch .from_numpy (frame )
410420
411421 return obs_dict
412422
@@ -459,22 +469,26 @@ def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
459469 )
460470
461471 goal_pos = {}
462- motors_name = self .state_feature .get ("names" ).get ("motors" )
463472
464- common_keys = [key for key in action if key in (motor .replace ("arm_" , "" ) for motor in motors_name )]
473+ common_keys = [
474+ key
475+ for key in action
476+ if key in (motor .replace ("arm_" , "" ) for motor , _ in self .state_feature_host .items ())
477+ ]
465478
466479 arm_actions = {"arm_" + arm_motor : action [arm_motor ] for arm_motor in common_keys }
467480 goal_pos = arm_actions
468481
469- if len (action ) > 6 :
470- keyboard_keys = np .array (list (set (action .keys ()) - set (common_keys )))
471- wheel_actions = {
472- "base_" + k : v for k , v in self ._from_keyboard_to_wheel_action (keyboard_keys ).items ()
473- }
474- goal_pos = {** arm_actions , ** wheel_actions }
482+ keyboard_keys = np .array (list (set (action .keys ()) - set (common_keys )))
483+ wheel_actions = {
484+ "base_" + k : v for k , v in self ._from_keyboard_to_wheel_action (keyboard_keys ).items ()
485+ }
486+ goal_pos = {** arm_actions , ** wheel_actions }
475487
476488 self .zmq_cmd_socket .send_string (json .dumps (goal_pos )) # action is in motor space
477489
490+ # TODO(Steven): Remove the np conversion when it is possible to record a non-numpy array value
491+ goal_pos = {"action." + k : np .array ([v ], dtype = np .float32 ) for k , v in goal_pos .items ()}
478492 return goal_pos
479493
480494 def disconnect (self ):
0 commit comments