diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 4a8d177c5d..f2823e3e44 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -157,9 +157,16 @@ def _apply_transform( if self.device and tensor.device != self.device: tensor = tensor.to(self.device) + # For Accelerate compatibility: move stats to match input tensor device + input_device = tensor.device stats = self._tensor_stats[key] tensor = tensor.to(dtype=torch.float32) + # Move stats to input device if needed + stats_device = next(iter(stats.values())).device + if stats_device != input_device: + stats = _convert_stats_to_tensors({key: self.stats[key]}, device=input_device)[key] + if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats: mean, std = stats["mean"], stats["std"] # Avoid division by zero by adding a small epsilon.