diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 6bad1a4c22..b3b9b302bd 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -168,7 +168,7 @@ def make_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats") ) - elif isinstance(policy_cfg.type, ACTConfig): + elif isinstance(policy_cfg, ACTConfig): from lerobot.policies.act.processor_act import make_act_pre_post_processors processors = make_act_pre_post_processors( diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index dfdc80d14b..78a3ad7972 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -35,32 +35,30 @@ class DeviceProcessor(ProcessorStep): device: str = "cpu" float_dtype: str | None = None - _device: torch.device | None = None + + DTYPE_MAPPING = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "bfloat16": torch.bfloat16, + "half": torch.float16, + "float": torch.float32, + "double": torch.float64, + } def __post_init__(self): - self._device = get_safe_torch_device(self.device) - self.device = self._device.type + self._device: torch.device = get_safe_torch_device(self.device) + self.device = self._device.type # cuda might have changed to cuda:1 self.non_blocking = "cuda" in str(self.device) # Validate and convert float_dtype string to torch dtype if self.float_dtype is not None: - dtype_mapping = { - "float16": torch.float16, - "float32": torch.float32, - "float64": torch.float64, - "bfloat16": torch.bfloat16, - "half": torch.float16, - "float": torch.float32, - "double": torch.float64, - } - - if self.float_dtype not in dtype_mapping: - available_dtypes = list(dtype_mapping.keys()) + if self.float_dtype not in self.DTYPE_MAPPING: raise ValueError( - f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}" + f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}" ) - self._target_float_dtype = dtype_mapping[self.float_dtype] + self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype] else: self._target_float_dtype = None @@ -93,51 +91,35 @@ def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor: return tensor def __call__(self, transition: EnvTransition) -> EnvTransition: - # Create a copy of the transition new_transition = transition.copy() - # Process observation tensors - observation = transition.get(TransitionKey.OBSERVATION) - if observation is not None: - new_observation = { - k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v - for k, v in observation.items() - } - new_transition[TransitionKey.OBSERVATION] = new_observation - - # Process action tensor - action = transition.get(TransitionKey.ACTION) - if action is not None and isinstance(action, torch.Tensor): - new_transition[TransitionKey.ACTION] = self._process_tensor(action) - - # Process reward tensor - reward = transition.get(TransitionKey.REWARD) - if reward is not None and isinstance(reward, torch.Tensor): - new_transition[TransitionKey.REWARD] = self._process_tensor(reward) - - # Process done tensor - done = transition.get(TransitionKey.DONE) - if done is not None and isinstance(done, torch.Tensor): - new_transition[TransitionKey.DONE] = self._process_tensor(done) - - # Process truncated tensor - truncated = transition.get(TransitionKey.TRUNCATED) - if truncated is not None and isinstance(truncated, torch.Tensor): - new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated) - - # Process complementary data tensors - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data is not None: - new_complementary_data = {} - - # Process all items in complementary_data - for key, value in complementary_data.items(): - if isinstance(value, torch.Tensor): - new_complementary_data[key] = self._process_tensor(value) - else: - new_complementary_data[key] = value - - new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data + simple_tensor_keys = [ + TransitionKey.ACTION, + TransitionKey.REWARD, + TransitionKey.DONE, + TransitionKey.TRUNCATED, + ] + + dict_tensor_keys = [ + TransitionKey.OBSERVATION, + TransitionKey.COMPLEMENTARY_DATA, + ] + + # Process simple tensors + for key in simple_tensor_keys: + value = transition.get(key) + if isinstance(value, torch.Tensor): + new_transition[key] = self._process_tensor(value) + + # Process dictionary-like tensors + for key in dict_tensor_keys: + data_dict = transition.get(key) + if data_dict is not None: + new_data_dict = { + k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v + for k, v in data_dict.items() + } + new_transition[key] = new_data_dict return new_transition diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index a3b119356f..8ffe490d6c 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -837,6 +837,7 @@ def __repr__(self) -> str: def __post_init__(self): for i, step in enumerate(self.steps): if not callable(step): + # TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated raise TypeError( f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition" )