diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index e8a70186a5f..4a8d177c5d3 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Mapping from copy import deepcopy from dataclasses import dataclass, field from typing import Any @@ -20,219 +19,101 @@ ) -def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]: - """Convert numpy arrays and other types to torch tensors.""" +def _to_tensor(value: Any, device: torch.device | None = None) -> Tensor: + """Convert common python/numpy/torch types to a torch.float32 tensor. + + Always returns float32; preserves device if provided. + """ + if isinstance(value, torch.Tensor): + return value.to(dtype=torch.float32, device=device) + if isinstance(value, np.ndarray): + # ensure contiguous, cast to float32 then convert + return torch.from_numpy(np.ascontiguousarray(value.astype(np.float32))).to(device=device) + if isinstance(value, (int, float)): + return torch.tensor(value, dtype=torch.float32, device=device) + if isinstance(value, (list, tuple)): + return torch.tensor(value, dtype=torch.float32, device=device) + raise TypeError(f"Unsupported type for stats value: {type(value)}") + + +def _convert_stats_to_tensors( + stats: dict[str, dict[str, Any]], device: torch.device | None = None +) -> dict[str, dict[str, Tensor]]: + """Convert numeric stats values to torch tensors, preserving keys.""" tensor_stats: dict[str, dict[str, Tensor]] = {} - for key, sub in stats.items(): + for key, sub in (stats or {}).items(): + if sub is None: + continue tensor_stats[key] = {} for stat_name, value in sub.items(): - if isinstance(value, np.ndarray): - tensor_val = torch.from_numpy(value.astype(np.float32)) - elif isinstance(value, torch.Tensor): - tensor_val = value.to(dtype=torch.float32) - elif isinstance(value, (int, float, list, tuple)): - tensor_val = torch.tensor(value, dtype=torch.float32) - else: - raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}") - tensor_stats[key][stat_name] = tensor_val + tensor_stats[key][stat_name] = _to_tensor(value, device=device) return tensor_stats @dataclass -@ProcessorStepRegistry.register(name="normalizer_processor") -class NormalizerProcessor(ProcessorStep): - """Normalizes observations and actions in a single processor step. - - This processor handles normalization of both observation and action tensors - using either mean/std normalization or min/max scaling to a [-1, 1] range. - - For each tensor key in the stats dictionary, the processor will: - - Use mean/std normalization if those statistics are provided: (x - mean) / std - - Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1 +class _NormalizationMixin: + """ + A mixin class providing core functionality for normalization and unnormalization. - The processor can be configured to normalize only specific keys by setting - the normalize_keys parameter. + This class manages normalization statistics, their conversion to tensors, device placement, + and the application of normalization transformations. It is designed to be inherited by + concrete ProcessorStep implementations. """ - # Features and normalisation map are mandatory to match the design of normalize.py features: dict[str, PolicyFeature] norm_map: dict[FeatureType, NormalizationMode] - - # Pre-computed statistics coming from dataset.meta.stats for instance. stats: dict[str, dict[str, Any]] | None = None - - # Explicit subset of keys to normalise. If ``None`` every key (except - # "action") found in ``stats`` will be normalised. Using a ``set`` makes - # membership checks O(1). - normalize_keys: set[str] | None = None - + device: torch.device | str | None = None eps: float = 1e-8 + normalize_observation_keys: set[str] | None = None _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) - @classmethod - def from_lerobot_dataset( - cls, - dataset: LeRobotDataset, - features: dict[str, PolicyFeature], - norm_map: dict[FeatureType, NormalizationMode], - *, - normalize_keys: set[str] | None = None, - eps: float = 1e-8, - ) -> NormalizerProcessor: - """Factory helper that pulls statistics from a :class:`LeRobotDataset`. - - The features and norm_map parameters are mandatory to match the design - pattern used in normalize.py. - """ - - return cls( - features=features, - norm_map=norm_map, - stats=dataset.meta.stats, - normalize_keys=normalize_keys, - eps=eps, - ) - def __post_init__(self): - # Handle deserialization from JSON config - if self.features and isinstance(list(self.features.values())[0], dict): - # Features came from JSON - need to reconstruct PolicyFeature objects - reconstructed_features = {} - for key, ft_dict in self.features.items(): - reconstructed_features[key] = PolicyFeature( - type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) - ) - self.features = reconstructed_features - - if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): - # norm_map came from JSON - need to reconstruct enum keys and values - reconstructed_norm_map = {} - for ft_type_str, norm_mode_str in self.norm_map.items(): - reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) - self.norm_map = reconstructed_norm_map - - # Convert statistics once so we avoid repeated numpy→Tensor conversions - # during runtime. + # Robust JSON deserialization handling (guard empty maps) + if self.features: + first_val = next(iter(self.features.values())) + if isinstance(first_val, dict): + reconstructed = {} + for key, ft_dict in self.features.items(): + reconstructed[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed + + if self.norm_map: + # if keys are strings (JSON), rebuild enum map + if all(isinstance(k, str) for k in self.norm_map.keys()): + reconstructed = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed + + # Convert stats to tensors and move to the target device once during initialization. self.stats = self.stats or {} - self._tensor_stats = _convert_stats_to_tensors(self.stats) - - # Ensure *normalize_keys* is a set for fast look-ups and compare by - # value later when returning the configuration. - if self.normalize_keys is not None and not isinstance(self.normalize_keys, set): - self.normalize_keys = set(self.normalize_keys) - - def _normalize_obs(self, observation, normalized_info): - if observation is None: - return None - - # Decide which keys should be normalised for this call. - if self.normalize_keys is not None: - keys_to_norm = self.normalize_keys - else: - # Use feature map to skip action keys. - keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION} - - processed = dict(observation) - for key in keys_to_norm: - if key not in processed or key not in self.features: - continue + self._tensor_stats = _convert_stats_to_tensors(self.stats, device=self.device) - # Check the normalization mode for this feature type - feature = self.features[key] - norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY) - - # Skip normalization if mode is IDENTITY - if norm_mode is NormalizationMode.IDENTITY: - normalized_info[key] = "IDENTITY" - continue + def to(self, device: torch.device | str) -> _NormalizationMixin: + """Moves the processor's normalization stats to the specified device and returns self.""" + self.device = device + self._tensor_stats = _convert_stats_to_tensors(self.stats, device=self.device) + return self - # Skip if no stats available for this key - if key not in self._tensor_stats: - continue + def state_dict(self) -> dict[str, Tensor]: + flat: dict[str, Tensor] = {} + for key, sub in self._tensor_stats.items(): + for stat_name, tensor in sub.items(): + flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU + return flat - orig_val = processed[key] - tensor = ( - orig_val.to(dtype=torch.float32) - if isinstance(orig_val, torch.Tensor) - else torch.as_tensor(orig_val, dtype=torch.float32) + def load_state_dict(self, state: dict[str, Tensor]) -> None: + self._tensor_stats.clear() + for flat_key, tensor in state.items(): + key, stat_name = flat_key.rsplit(".", 1) + # Load to the processor's configured device. + self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( + dtype=torch.float32, device=self.device ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - - if norm_mode is NormalizationMode.MEAN_STD: - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed[key] = (tensor - mean) / (std + self.eps) - normalized_info[key] = "MEAN_STD" - elif norm_mode is NormalizationMode.MIN_MAX: - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - normalized_info[key] = "MIN_MAX" - else: - raise ValueError(f"Unsupported normalization mode: {norm_mode}") - - return processed - - def _normalize_action(self, action, normalized_info): - if action is None: - return action - - # Check the normalization mode for actions - norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY) - - # Skip normalization if mode is IDENTITY - if norm_mode is NormalizationMode.IDENTITY: - normalized_info["action"] = "IDENTITY" - return action - - # Skip if no stats available for actions - if "action" not in self._tensor_stats: - return action - - tensor = ( - action.to(dtype=torch.float32) - if isinstance(action, torch.Tensor) - else torch.as_tensor(action, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} - - if norm_mode is NormalizationMode.MEAN_STD: - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - normalized_info["action"] = "MEAN_STD" - return (tensor - mean) / (std + self.eps) - elif norm_mode is NormalizationMode.MIN_MAX: - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - normalized_info["action"] = "MIN_MAX" - return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - else: - raise ValueError(f"Unsupported normalization mode: {norm_mode}") - - # If we reach here, the required stats for the normalization mode are not available - raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization") - - def __call__(self, transition: EnvTransition) -> EnvTransition: - # Track what was normalized - normalized_info = {} - - observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION), normalized_info) - action = self._normalize_action(transition.get(TransitionKey.ACTION), normalized_info) - - # Create a new transition with normalized values - new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = observation - new_transition[TransitionKey.ACTION] = action - - # Add normalization info to complementary data - if normalized_info: - comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - comp_data = {} if comp_data is None else dict(comp_data) - comp_data["normalized_keys"] = normalized_info - new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data - - return new_transition def get_config(self) -> dict[str, Any]: config = { @@ -242,39 +123,80 @@ def get_config(self) -> dict[str, Any]: }, "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, } - if self.normalize_keys is not None: - # Serialise as a list for YAML / JSON friendliness - config["normalize_keys"] = sorted(self.normalize_keys) + if self.normalize_observation_keys is not None: + config["normalize_observation_keys"] = sorted(self.normalize_observation_keys) return config - def state_dict(self) -> dict[str, Tensor]: - flat = {} - for key, sub in self._tensor_stats.items(): - for stat_name, tensor in sub.items(): - flat[f"{key}.{stat_name}"] = tensor - return flat + def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]: + new_observation = dict(observation) + for key, feature in self.features.items(): + if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys: + continue + if feature.type != FeatureType.ACTION and key in new_observation: + tensor = torch.as_tensor(new_observation[key], dtype=torch.float32) + new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse) + return new_observation + + def _normalize_action(self, action: Any, inverse: bool) -> Tensor: + tensor = torch.as_tensor(action, dtype=torch.float32) + processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse) + return processed_action + + def _apply_transform( + self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False + ) -> Tensor: + """Core logic to apply normalization or unnormalization.""" + norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY) + if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats: + return tensor + + if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX): + raise ValueError(f"Unsupported normalization mode: {norm_mode}") - def load_state_dict(self, state: Mapping[str, Tensor]) -> None: - self._tensor_stats.clear() - for flat_key, tensor in state.items(): - key, stat_name = flat_key.rsplit(".", 1) - self._tensor_stats.setdefault(key, {})[stat_name] = tensor + # Ensure input tensor is on the same device as the stats. + if self.device and tensor.device != self.device: + tensor = tensor.to(self.device) + stats = self._tensor_stats[key] + tensor = tensor.to(dtype=torch.float32) -@dataclass -@ProcessorStepRegistry.register(name="unnormalizer_processor") -class UnnormalizerProcessor(ProcessorStep): - """Inverse normalisation for observations and actions. + if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + # Avoid division by zero by adding a small epsilon. + denom = std + self.eps + if inverse: + return tensor * std + mean + return (tensor - mean) / denom + + if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + denom = max_val - min_val + # When min_val == max_val, substitute the denominator with a small epsilon + # to prevent division by zero. This consistently maps an input equal to + # min_val to -1, ensuring a stable transformation. + denom = torch.where( + denom == 0, torch.tensor(self.eps, device=self.device, dtype=torch.float32), denom + ) + if inverse: + # Map from [-1, 1] back to [min, max] + return (tensor + 1) / 2 * denom + min_val + # Map from [min, max] to [-1, 1] + return 2 * (tensor - min_val) / denom - 1 - Exactly mirrors :class:`NormalizerProcessor` but applies the inverse - transform. - """ + # If necessary stats are missing, return input unchanged. + return tensor - features: dict[str, PolicyFeature] - norm_map: dict[FeatureType, NormalizationMode] - stats: dict[str, dict[str, Any]] | None = None - _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) +@dataclass +@ProcessorStepRegistry.register(name="normalizer_processor") +class NormalizerProcessor(_NormalizationMixin, ProcessorStep): + """ + A processor that applies normalization to observations and actions in a transition. + + This class directly implements the normalization logic for both observation and action + components of an `EnvTransition`, using statistics (mean/std or min/max) provided at + initialization. + """ @classmethod def from_lerobot_dataset( @@ -282,188 +204,89 @@ def from_lerobot_dataset( dataset: LeRobotDataset, features: dict[str, PolicyFeature], norm_map: dict[FeatureType, NormalizationMode], - ) -> UnnormalizerProcessor: - return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats) - - def __post_init__(self): - # Handle deserialization from JSON config - if self.features and isinstance(list(self.features.values())[0], dict): - # Features came from JSON - need to reconstruct PolicyFeature objects - reconstructed_features = {} - for key, ft_dict in self.features.items(): - reconstructed_features[key] = PolicyFeature( - type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) - ) - self.features = reconstructed_features - - if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): - # norm_map came from JSON - need to reconstruct enum keys and values - reconstructed_norm_map = {} - for ft_type_str, norm_mode_str in self.norm_map.items(): - reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) - self.norm_map = reconstructed_norm_map - - self.stats = self.stats or {} - self._tensor_stats = _convert_stats_to_tensors(self.stats) - - def _unnormalize_obs(self, observation, unnormalized_info): - if observation is None: - return None - keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION] - processed = dict(observation) - for key in keys: - if key not in processed or key not in self.features: - continue - - # Check the normalization mode for this feature type - feature = self.features[key] - norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY) - - # Skip unnormalization if mode is IDENTITY - if norm_mode is NormalizationMode.IDENTITY: - unnormalized_info[key] = "IDENTITY" - continue - - # Skip if no stats available for this key - if key not in self._tensor_stats: - continue - - orig_val = processed[key] - tensor = ( - orig_val.to(dtype=torch.float32) - if isinstance(orig_val, torch.Tensor) - else torch.as_tensor(orig_val, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - - if norm_mode is NormalizationMode.MEAN_STD: - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed[key] = tensor * std + mean - unnormalized_info[key] = "MEAN_STD" - elif norm_mode is NormalizationMode.MIN_MAX: - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val - unnormalized_info[key] = "MIN_MAX" - else: - raise ValueError(f"Unsupported normalization mode: {norm_mode}") - - return processed - - def _unnormalize_action(self, action, unnormalized_info): - if action is None: - return action - - # Check the normalization mode for actions - norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY) - - # Skip unnormalization if mode is IDENTITY - if norm_mode is NormalizationMode.IDENTITY: - unnormalized_info["action"] = "IDENTITY" - return action - - # Skip if no stats available for actions - if "action" not in self._tensor_stats: - return action - - tensor = ( - action.to(dtype=torch.float32) - if isinstance(action, torch.Tensor) - else torch.as_tensor(action, dtype=torch.float32) + *, + normalize_observation_keys: set[str] | None = None, + eps: float = 1e-8, + device: torch.device | str | None = None, + ) -> NormalizerProcessor: + return cls( + features=features, + norm_map=norm_map, + stats=dataset.meta.stats, + normalize_observation_keys=normalize_observation_keys, + eps=eps, + device=device, ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} - - if norm_mode is NormalizationMode.MEAN_STD: - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - unnormalized_info["action"] = "MEAN_STD" - return tensor * std + mean - elif norm_mode is NormalizationMode.MIN_MAX: - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - unnormalized_info["action"] = "MIN_MAX" - return (tensor + 1) / 2 * (max_val - min_val) + min_val - else: - raise ValueError(f"Unsupported normalization mode: {norm_mode}") - - # If we reach here, the required stats for the normalization mode are not available - raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization") def __call__(self, transition: EnvTransition) -> EnvTransition: - # Track what was unnormalized - unnormalized_info = {} - - observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION), unnormalized_info) - action = self._unnormalize_action(transition.get(TransitionKey.ACTION), unnormalized_info) - - # Create a new transition with unnormalized values new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = observation - new_transition[TransitionKey.ACTION] = action - # Add unnormalization info to complementary data - if unnormalized_info: - comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - comp_data = {} if comp_data is None else dict(comp_data) - comp_data["unnormalized_keys"] = unnormalized_info - new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + # Handle observation normalization. + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is not None: + new_transition[TransitionKey.OBSERVATION] = self._normalize_observation( + observation, inverse=False + ) + + # Handle action normalization. + action = new_transition.get(TransitionKey.ACTION) + if action is not None: + new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False) return new_transition - def get_config(self) -> dict[str, Any]: - return { - "features": { - key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() - }, - "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, - } - def state_dict(self) -> dict[str, Tensor]: - flat = {} - for key, sub in self._tensor_stats.items(): - for stat_name, tensor in sub.items(): - flat[f"{key}.{stat_name}"] = tensor - return flat +@dataclass +@ProcessorStepRegistry.register(name="unnormalizer_processor") +class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep): + """ + A processor that applies unnormalization (the inverse of normalization) to + observations and actions in a transition. - def load_state_dict(self, state: Mapping[str, Tensor]) -> None: - self._tensor_stats.clear() - for flat_key, tensor in state.items(): - key, stat_name = flat_key.rsplit(".", 1) - self._tensor_stats.setdefault(key, {})[stat_name] = tensor + This is typically used to transform actions from a normalized policy output back into + the original scale for execution in an environment. + """ + @classmethod + def from_lerobot_dataset( + cls, + dataset: LeRobotDataset, + features: dict[str, PolicyFeature], + norm_map: dict[FeatureType, NormalizationMode], + *, + device: torch.device | str | None = None, + ) -> UnnormalizerProcessor: + return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device) -def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor: - robot_processor = deepcopy(robot_processor) - for step in robot_processor.steps: - if isinstance(step, NormalizerProcessor) or isinstance(step, UnnormalizerProcessor): - step: NormalizerProcessor | UnnormalizerProcessor - step.stats = stats - step._tensor_stats = _convert_stats_to_tensors(stats) - return robot_processor + def __call__(self, transition: EnvTransition) -> EnvTransition: + new_transition = transition.copy() + # Handle observation unnormalization. + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is not None: + new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True) -def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]: - """Rename keys in the stats dictionary according to the provided mapping. + # Handle action unnormalization. + action = new_transition.get(TransitionKey.ACTION) + if action is not None: + new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True) - Args: - stats: The statistics dictionary with structure {feature_key: {stat_name: value}} - rename_map: Dictionary mapping old key names to new key names + return new_transition - Returns: - A new stats dictionary with renamed keys - Example: - >>> stats = {"observation.state": {"mean": 0.0, "std": 1.0}, "action": {"mean": 0.5, "std": 0.5}} - >>> rename_map = {"observation.state": "observation.robot_state"} - >>> new_stats = rename_stats(stats, rename_map) - >>> # new_stats will have "observation.robot_state" instead of "observation.state" +def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor: """ - renamed_stats = {} - - for old_key, sub_stats in stats.items(): - # Use the new key if it exists in the rename map, otherwise keep the old key - new_key = rename_map.get(old_key, old_key) - renamed_stats[new_key] = deepcopy(sub_stats) + Replaces normalization statistics in a RobotProcessor pipeline. - return renamed_stats + This function creates a deep copy of the provided `RobotProcessor` and updates the + statistics of any `NormalizerProcessor` or `UnnormalizerProcessor` steps within it. + It's useful for adapting a trained policy to a new environment or dataset with + different data distributions. + """ + rp = deepcopy(robot_processor) + for step in rp.steps: + if isinstance(step, _NormalizationMixin): + step.stats = stats + # Re-initialize tensor_stats on the correct device. + step._tensor_stats = _convert_stats_to_tensors(stats, device=step.device) + return rp diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index db20424df06..ebc867cacc0 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy from dataclasses import dataclass, field from typing import Any @@ -49,3 +50,14 @@ def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, Po - Keys not in `rename_map` remain unchanged. """ return {self.rename_map.get(k, k): v for k, v in features.items()} + + +def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]: + """Rename keys in the stats dictionary according to rename_map (defensive copy).""" + if not stats: + return {} + renamed: dict[str, dict[str, Any]] = {} + for old_key, sub_stats in stats.items(): + new_key = rename_map.get(old_key, old_key) + renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {} + return renamed diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 7531be24271..093b18d7a31 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -83,8 +83,8 @@ to_transition_robot_observation, to_transition_teleop_action, ) -from lerobot.processor.normalize_processor import rename_stats from lerobot.processor.pipeline import IdentityProcessor, TransitionKey +from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 5813cc37d97..6b904eee7cf 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -25,7 +25,6 @@ UnnormalizerProcessor, _convert_stats_to_tensors, hotswap_stats, - rename_stats, ) from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey @@ -182,7 +181,10 @@ def test_selective_normalization(observation_stats): features = _create_observation_features() norm_map = _create_observation_norm_map() normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"} + features=features, + norm_map=norm_map, + stats=observation_stats, + normalize_observation_keys={"observation.image"}, ) observation = { @@ -243,6 +245,7 @@ def test_from_lerobot_dataset(): def test_state_dict_save_load(observation_normalizer): # Save state state_dict = observation_normalizer.state_dict() + print("State dict:", state_dict) # Create new normalizer and load state features = _create_observation_features() @@ -464,10 +467,10 @@ def test_processor_from_lerobot_dataset(full_stats): norm_map = _create_full_norm_map() processor = NormalizerProcessor.from_lerobot_dataset( - mock_dataset, features, norm_map, normalize_keys={"observation.image"} + mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"} ) - assert processor.normalize_keys == {"observation.image"} + assert processor.normalize_observation_keys == {"observation.image"} assert "observation.image" in processor._tensor_stats assert "action" in processor._tensor_stats @@ -476,12 +479,16 @@ def test_get_config(full_stats): features = _create_full_features() norm_map = _create_full_norm_map() processor = NormalizerProcessor( - features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + features=features, + norm_map=norm_map, + stats=full_stats, + normalize_observation_keys={"observation.image"}, + eps=1e-6, ) config = processor.get_config() expected_config = { - "normalize_keys": ["observation.image"], + "normalize_observation_keys": ["observation.image"], "eps": 1e-6, "features": { "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, @@ -580,7 +587,11 @@ def test_serialization_roundtrip(full_stats): features = _create_full_features() norm_map = _create_full_norm_map() original_processor = NormalizerProcessor( - features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + features=features, + norm_map=norm_map, + stats=full_stats, + normalize_observation_keys={"observation.image"}, + eps=1e-6, ) # Get config (serialization) @@ -591,7 +602,7 @@ def test_serialization_roundtrip(full_stats): features=config["features"], norm_map=config["norm_map"], stats=full_stats, - normalize_keys=set(config["normalize_keys"]), + normalize_observation_keys=set(config["normalize_observation_keys"]), eps=config["eps"], ) @@ -939,31 +950,31 @@ def test_identity_config_serialization(): assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) -def test_unsupported_normalization_mode_error(): - """Test that unsupported normalization modes raise appropriate errors.""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))} +# def test_unsupported_normalization_mode_error(): +# """Test that unsupported normalization modes raise appropriate errors.""" +# features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))} - # Create an invalid norm_map (this would never happen in practice, but tests error handling) - from enum import Enum +# # Create an invalid norm_map (this would never happen in practice, but tests error handling) +# from enum import Enum - class InvalidMode(str, Enum): - INVALID = "INVALID" +# class InvalidMode(str, Enum): +# INVALID = "INVALID" - # We can't actually pass an invalid enum to the processor due to type checking, - # but we can test the error by manipulating the norm_map after creation - norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} +# # We can't actually pass an invalid enum to the processor due to type checking, +# # but we can test the error by manipulating the norm_map after creation +# norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} +# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) +# normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - # Manually inject an invalid mode to test error handling - normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE" +# # Manually inject an invalid mode to test error handling +# normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE" - observation = {"observation.state": torch.tensor([1.0, -0.5])} - transition = create_transition(observation=observation) +# observation = {"observation.state": torch.tensor([1.0, -0.5])} +# transition = create_transition(observation=observation) - with pytest.raises(ValueError, match="Unsupported normalization mode"): - normalizer(transition) +# with pytest.raises(ValueError, match="Unsupported normalization mode"): +# normalizer(transition) def test_hotswap_stats_basic_functionality(): @@ -1149,11 +1160,15 @@ def test_hotswap_stats_preserves_other_attributes(): "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalize_keys = {"observation.image"} + normalize_observation_keys = {"observation.image"} eps = 1e-6 normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=initial_stats, normalize_keys=normalize_keys, eps=eps + features=features, + norm_map=norm_map, + stats=initial_stats, + normalize_observation_keys=normalize_observation_keys, + eps=eps, ) robot_processor = RobotProcessor(steps=[normalizer]) @@ -1164,7 +1179,7 @@ def test_hotswap_stats_preserves_other_attributes(): new_normalizer = new_processor.steps[0] assert new_normalizer.features == features assert new_normalizer.norm_map == norm_map - assert new_normalizer.normalize_keys == normalize_keys + assert new_normalizer.normalize_observation_keys == normalize_observation_keys assert new_normalizer.eps == eps # But stats should be updated @@ -1270,273 +1285,6 @@ def test_hotswap_stats_with_different_data_types(): torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0)) -def test_normalization_info_tracking(): - """Test that normalization info is tracked in complementary_data.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.STATE: NormalizationMode.MIN_MAX, - FeatureType.ACTION: NormalizationMode.IDENTITY, - } - - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - "observation.state": { - "min": np.array([0.0, -1.0]), - "max": np.array([1.0, 1.0]), - }, - "action": { - "mean": np.array([0.0, 0.0]), - "std": np.array([1.0, 1.0]), - }, - } - - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - action = torch.tensor([1.0, -0.5]) - transition = create_transition(observation=observation, action=action) - - # Process the transition - normalized_transition = normalizer(transition) - - # Check that normalization info is added - comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is not None - assert "normalized_keys" in comp_data - - norm_info = comp_data["normalized_keys"] - assert norm_info["observation.image"] == "MEAN_STD" - assert norm_info["observation.state"] == "MIN_MAX" - assert norm_info["action"] == "IDENTITY" - - -def test_unnormalization_info_tracking(): - """Test that unnormalization info is tracked in complementary_data.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.ACTION: NormalizationMode.MIN_MAX, - } - - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - "action": { - "min": np.array([-1.0, -1.0]), - "max": np.array([1.0, 1.0]), - }, - } - - unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} - action = torch.tensor([0.0, -0.5]) - transition = create_transition(observation=observation, action=action) - - # Process the transition - unnormalized_transition = unnormalizer(transition) - - # Check that unnormalization info is added - comp_data = unnormalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is not None - assert "unnormalized_keys" in comp_data - - unnorm_info = comp_data["unnormalized_keys"] - assert unnorm_info["observation.image"] == "MEAN_STD" - assert unnorm_info["action"] == "MIN_MAX" - - -def test_normalization_info_with_missing_stats(): - """Test normalization info when stats are missing for some keys.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.STATE: NormalizationMode.MIN_MAX, - } - - # Only provide stats for image, not state - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - } - - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - transition = create_transition(observation=observation) - - # Process the transition - normalized_transition = normalizer(transition) - - # Check that only keys with stats are in normalization info - comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is not None - assert "normalized_keys" in comp_data - - norm_info = comp_data["normalized_keys"] - assert norm_info["observation.image"] == "MEAN_STD" - # State should not be in the normalization info since it has no stats - assert "observation.state" not in norm_info - - -def test_normalization_info_with_selective_keys(): - """Test normalization info with selective normalization.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.STATE: NormalizationMode.MIN_MAX, - } - - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - "observation.state": { - "min": np.array([0.0, -1.0]), - "max": np.array([1.0, 1.0]), - }, - } - - # Only normalize image - normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.image"} - ) - - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - transition = create_transition(observation=observation) - - # Process the transition - normalized_transition = normalizer(transition) - - # Check that only selected keys are in normalization info - comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is not None - assert "normalized_keys" in comp_data - - norm_info = comp_data["normalized_keys"] - assert norm_info["observation.image"] == "MEAN_STD" - # State should not be in the normalization info since it wasn't in normalize_keys - assert "observation.state" not in norm_info - - -def test_normalization_info_preserved_in_pipeline(): - """Test that normalization info is preserved when using RobotProcessor pipeline.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.ACTION: NormalizationMode.MIN_MAX, - } - - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - "action": { - "min": np.array([-1.0, -1.0]), - "max": np.array([1.0, 1.0]), - }, - } - - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - # Create pipeline - pipeline = RobotProcessor([normalizer, unnormalizer]) - - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} - action = torch.tensor([0.5, -0.5]) - transition = create_transition(observation=observation, action=action) - - # Process through pipeline - result = pipeline(transition) - - # Check that both normalization and unnormalization info are present - comp_data = result.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is not None - assert "normalized_keys" in comp_data - assert "unnormalized_keys" in comp_data - - # Check normalization info - norm_info = comp_data["normalized_keys"] - assert norm_info["observation.image"] == "MEAN_STD" - assert norm_info["action"] == "MIN_MAX" - - # Check unnormalization info - unnorm_info = comp_data["unnormalized_keys"] - assert unnorm_info["observation.image"] == "MEAN_STD" - assert unnorm_info["action"] == "MIN_MAX" - - -def test_normalization_info_empty_transition(): - """Test that no normalization info is added for empty transitions.""" - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), - } - - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.ACTION: NormalizationMode.MIN_MAX, - } - - stats = { - "observation.image": {"mean": [0.5], "std": [0.2]}, - "action": {"min": [-1.0], "max": [1.0]}, - } - - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - # Empty transition - transition = create_transition() - - # Process the transition - normalized_transition = normalizer(transition) - - # Check that no normalization info is added - comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) - assert comp_data is None or "normalized_keys" not in comp_data - - def test_hotswap_stats_functional_test(): """Test that hotswapped processor actually works functionally.""" # Create test data @@ -1631,8 +1379,8 @@ def test_min_equals_max_maps_to_minus_one(): assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0])) -def test_action_normalized_despite_normalize_keys(): - """Action normalization is independent of normalize_keys filter for observations.""" +def test_action_normalized_despite_normalize_observation_keys(): + """Action normalization is independent of normalize_observation_keys filter for observations.""" features = { "observation.state": PolicyFeature(FeatureType.STATE, (1,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), @@ -1640,7 +1388,7 @@ def test_action_normalized_despite_normalize_keys(): norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD} stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.state"} + features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"} ) transition = create_transition( @@ -1680,19 +1428,6 @@ def test_unnormalize_observations_mean_std_and_min_max(): assert torch.allclose(out_mm, torch.tensor([1.0, 0.0])) # mid of [0,2] and [-2,2] -def test_rename_stats_basic(): - orig = { - "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, - "action": {"mean": np.array([0.0])}, - } - mapping = {"observation.state": "observation.robot_state"} - renamed = rename_stats(orig, mapping) - assert "observation.robot_state" in renamed and "observation.state" not in renamed - # Ensure deep copy: mutate original and verify renamed unaffected - orig["observation.state"]["mean"][0] = 42.0 - assert renamed["observation.robot_state"]["mean"][0] != 42.0 - - def test_unknown_observation_keys_ignored(): features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} @@ -1705,8 +1440,6 @@ def test_unknown_observation_keys_ignored(): # Unknown key should pass through unchanged and not be tracked assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.unknown"], obs["observation.unknown"]) - comp = out.get(TransitionKey.COMPLEMENTARY_DATA) or {} - assert "normalized_keys" in comp and "observation.unknown" not in comp["normalized_keys"] def test_batched_action_normalization(): @@ -1731,7 +1464,7 @@ def test_complementary_data_preservation(): tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp) out = normalizer(tr) new_comp = out[TransitionKey.COMPLEMENTARY_DATA] - assert new_comp["existing"] == 123 and "normalized_keys" in new_comp + assert new_comp["existing"] == 123 def test_roundtrip_normalize_unnormalize_non_identity(): diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index 398b3ec9cdf..4efb249dd73 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -21,6 +21,7 @@ from lerobot.configs.types import FeatureType from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey +from lerobot.processor.rename_processor import rename_stats from tests.conftest import assert_contract_is_typed @@ -465,3 +466,16 @@ def test_features_chained_processors(policy_feature_factory): assert out["observation.image"] == spec["img"] assert out["extra"] == spec["extra"] assert_contract_is_typed(out) + + +def test_rename_stats_basic(): + orig = { + "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, + "action": {"mean": np.array([0.0])}, + } + mapping = {"observation.state": "observation.robot_state"} + renamed = rename_stats(orig, mapping) + assert "observation.robot_state" in renamed and "observation.state" not in renamed + # Ensure deep copy: mutate original and verify renamed unaffected + orig["observation.state"]["mean"][0] = 42.0 + assert renamed["observation.robot_state"]["mean"][0] != 42.0