From 7814633faa188109c1b57910e05bc51cd21156d3 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 16 Jan 2025 17:31:57 +0100 Subject: [PATCH 1/2] Add features back to policy configs --- lerobot/common/constants.py | 6 + lerobot/common/datasets/utils.py | 33 +++- lerobot/common/envs/configs.py | 88 ++++++--- lerobot/common/envs/utils.py | 25 +++ lerobot/common/policies/act/modeling_act.py | 16 +- .../diffusion/configuration_diffusion.py | 11 +- .../policies/diffusion/modeling_diffusion.py | 28 +-- lerobot/common/policies/factory.py | 16 +- lerobot/common/policies/normalize.py | 85 +++++---- .../policies/tdmpc/configuration_tdmpc.py | 2 +- .../common/policies/tdmpc/modeling_tdmpc.py | 32 +++- .../policies/vqbet/configuration_vqbet.py | 11 +- .../common/policies/vqbet/modeling_vqbet.py | 25 ++- lerobot/common/utils/utils.py | 11 ++ lerobot/configs/policies.py | 180 +++--------------- lerobot/configs/types.py | 7 + tests/test_policies.py | 65 ++++--- 17 files changed, 328 insertions(+), 313 deletions(-) create mode 100644 lerobot/common/constants.py diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py new file mode 100644 index 00000000000..73889594863 --- /dev/null +++ b/lerobot/common/constants.py @@ -0,0 +1,6 @@ +# keys +OBS_ENV = "observation.environment_state" +OBS_ROBOT = "observation.state" +OBS_IMAGE = "observation.image" +OBS_IMAGES = "observation.images" +ACTION = "action" diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 1850c8aa463..612bac39ab7 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -35,7 +35,7 @@ from torchvision import transforms from lerobot.common.robot_devices.robots.utils import Robot -from lerobot.configs.types import DictLike +from lerobot.configs.types import DictLike, FeatureType, PolicyFeature DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk @@ -302,6 +302,37 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict: return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES} +def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: + # TODO(aliberts): Implement "type" in dataset features and simplify this + policy_features = {} + for key, ft in features.items(): + shape = ft["shape"] + if ft["dtype"] in ["image", "video"]: + type = FeatureType.VISUAL + if len(shape) != 3: + raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") + + names = ft["names"] + # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. + if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) + shape = (shape[2], shape[0], shape[1]) + elif key == "observation.environment_state": + type = FeatureType.ENV + elif key.startswith("observation"): + type = FeatureType.STATE + elif key == "action": + type = FeatureType.ACTION + else: + continue + + policy_features[key] = PolicyFeature( + type=type, + shape=shape, + ) + + return policy_features + + def create_empty_dataset_info( codebase_version: str, fps: int, diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 340682a6926..d0c53d22d67 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -3,7 +3,8 @@ import draccus -from lerobot.configs.types import FeatureType +from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT +from lerobot.configs.types import FeatureType, PolicyFeature @dataclass @@ -11,7 +12,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): n_envs: int | None = None task: str | None = None fps: int = 30 - feature_types: dict = field(default_factory=dict) + features: dict[str, PolicyFeature] = field(default_factory=dict) + features_map: dict[str, str] = field(default_factory=dict) @property def type(self) -> str: @@ -28,17 +30,28 @@ class AlohaEnv(EnvConfig): task: str = "AlohaInsertion-v0" fps: int = 50 episode_length: int = 400 - feature_types: dict = field( + obs_type: str = "pixels_agent_pos" + render_mode: str = "rgb_array" + features: dict[str, PolicyFeature] = field( default_factory=lambda: { - "agent_pos": FeatureType.STATE, - "pixels": { - "top": FeatureType.VISUAL, - }, - "action": FeatureType.ACTION, + "action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_ROBOT, + "top": f"{OBS_IMAGE}.top", + "pixels/top": f"{OBS_IMAGES}.top", } ) - obs_type: str = "pixels_agent_pos" - render_mode: str = "rgb_array" + + def __post_init__(self): + if self.obs_type == "pixels": + self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + elif self.obs_type == "pixels_agent_pos": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,)) + self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) @property def gym_kwargs(self) -> dict: @@ -55,25 +68,30 @@ class PushtEnv(EnvConfig): task: str = "PushT-v0" fps: int = 10 episode_length: int = 300 - feature_types: dict = field( - default_factory=lambda: { - "agent_pos": FeatureType.STATE, - "pixels": FeatureType.VISUAL, - "action": FeatureType.ACTION, - } - ) obs_type: str = "pixels_agent_pos" render_mode: str = "rgb_array" visualization_width: int = 384 visualization_height: int = 384 + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_ROBOT, + "environment_state": OBS_ENV, + "pixels": OBS_IMAGE, + } + ) def __post_init__(self): - if self.obs_type == "environment_state_agent_pos": - self.feature_types = { - "agent_pos": FeatureType.STATE, - "environment_state": FeatureType.ENV, - "action": FeatureType.ACTION, - } + if self.obs_type == "pixels_agent_pos": + self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3)) + elif self.obs_type == "environment_state_agent_pos": + self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,)) @property def gym_kwargs(self) -> dict: @@ -91,17 +109,27 @@ class XarmEnv(EnvConfig): task: str = "XarmLift-v0" fps: int = 15 episode_length: int = 200 - feature_types: dict = field( - default_factory=lambda: { - "agent_pos": FeatureType.STATE, - "pixels": FeatureType.VISUAL, - "action": FeatureType.ACTION, - } - ) obs_type: str = "pixels_agent_pos" render_mode: str = "rgb_array" visualization_width: int = 384 visualization_height: int = 384 + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_ROBOT, + "pixels": OBS_IMAGE, + } + ) + + def __post_init__(self): + if self.obs_type == "pixels_agent_pos": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) @property def gym_kwargs(self) -> dict: diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 9f4f7e17e9f..06abee3f896 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -18,6 +18,10 @@ import torch from torch import Tensor +from lerobot.common.envs.configs import EnvConfig +from lerobot.common.utils.utils import get_channel_first_image_shape +from lerobot.configs.types import FeatureType, PolicyFeature + def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: # TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding) @@ -36,6 +40,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten imgs = {"observation.image": observations["pixels"]} for imgkey, img in imgs.items(): + # TODO(aliberts, rcadene): use transforms.ToTensor()? img = torch.from_numpy(img) # sanity check that images are channel last @@ -61,3 +66,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten # requirement for "agent_pos" return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() return return_observations + + +def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: + # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is + # (need to also refactor preprocess_observation and externalize normalization from policies) + policy_features = {} + for key, ft in env_cfg.features.items(): + if ft.type is FeatureType.VISUAL: + if len(ft.shape) != 3: + raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})") + + shape = get_channel_first_image_shape(ft.shape, raise_if_not_channel_first=True) + feature = PolicyFeature(type=ft.type, shape=shape) + else: + feature = ft + + policy_key = env_cfg.features_map[key] + policy_features[policy_key] = feature + + return policy_features diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index da66b5aaa1f..2e7b1f9ce70 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -68,9 +68,13 @@ def __init__( config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, dataset_stats) - self.normalize_targets = Normalize(config.output_features, dataset_stats) - self.unnormalize_outputs = Unnormalize(config.output_features, dataset_stats) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) self.model = ACT(config) @@ -121,7 +125,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack( - [batch[ft.key] for ft in self.config.image_features], dim=-4 + [batch[key] for key in self.config.image_features], dim=-4 ) # If we are doing temporal ensembling, do online updates where we keep track of the number of actions @@ -151,7 +155,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack( - [batch[ft.key] for ft in self.config.image_features], dim=-4 + [batch[key] for key in self.config.image_features], dim=-4 ) batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) @@ -411,7 +415,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso """ if self.config.use_vae and self.training: assert ( - self.config.action_feature.key in batch + "action" in batch ), "actions must be provided when using the variational objective in training mode." batch_size = ( diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index fb58b3efdb2..b92c9974c93 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -208,21 +208,20 @@ def validate_features(self) -> None: raise ValueError("You must provide at least one image or the environment state among the inputs.") if self.crop_shape is not None: - for image_ft in self.image_features: + for key, image_ft in self.image_features.items(): if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]: raise ValueError( f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} " f"for `crop_shape` and {image_ft.shape} for " - f"`{image_ft.key}`." + f"`{key}`." ) # Check that all input images have the same shape. - first_image_ft = next(iter(self.image_features)) - for image_ft in self.image_features: + first_image_key, first_image_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): if image_ft.shape != first_image_ft.shape: raise ValueError( - f"`{image_ft.key}` does not match `{first_image_ft.key}`, but we " - "expect all image shapes to match." + f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match." ) @property diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index ec58c49ce9f..174bb24ede0 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -34,6 +34,7 @@ from huggingface_hub import PyTorchModelHubMixin from torch import Tensor, nn +from lerobot.common.constants import OBS_ENV, OBS_ROBOT from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.utils import ( @@ -74,9 +75,13 @@ def __init__( config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, dataset_stats) - self.normalize_targets = Normalize(config.output_features, dataset_stats) - self.unnormalize_outputs = Unnormalize(config.output_features, dataset_stats) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) # queues are populated during rollout of the policy, they contain the n latest observations and actions self._queues = None @@ -125,7 +130,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack( - [batch[ft.key] for ft in self.config.image_features], dim=-4 + [batch[key] for key in self.config.image_features], dim=-4 ) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) @@ -149,7 +154,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack( - [batch[ft.key] for ft in self.config.image_features], dim=-4 + [batch[key] for key in self.config.image_features], dim=-4 ) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) @@ -237,8 +242,8 @@ def conditional_sample( def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor: """Encode image features and concatenate them all together along with the state vector.""" - batch_size, n_obs_steps = batch[self.config.robot_state_feature.key].shape[:2] - global_cond_feats = [batch[self.config.robot_state_feature.key]] + batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2] + global_cond_feats = [batch[OBS_ROBOT]] # Extract image features. if self.config.image_features: if self.config.use_separate_rgb_encoder_per_camera: @@ -268,7 +273,7 @@ def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor: global_cond_feats.append(img_features) if self.config.env_state_feature: - global_cond_feats.append(batch[self.config.env_state_feature.key]) + global_cond_feats.append(batch[OBS_ENV]) # Concatenate features then flatten to (B, global_cond_dim). return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1) @@ -482,10 +487,9 @@ def __init__(self, config: DiffusionConfig): # height and width from `config.image_features`. # Note: we have a check in the config class to make sure all images have the same shape. - dummy_shape_h_w = ( - config.crop_shape if config.crop_shape is not None else config.image_features[0].shape[1:] - ) - dummy_shape = (1, config.image_features[0].shape[0], *dummy_shape_h_w) + images_shape = next(iter(config.image_features.values())).shape + dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:] + dummy_shape = (1, images_shape[0], *dummy_shape_h_w) feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:] self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 6bc445685b4..e89c9a72665 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -18,13 +18,16 @@ from torch import nn from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.common.datasets.utils import dataset_to_policy_features from lerobot.common.envs.configs import EnvConfig +from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.configs.policies import PretrainedConfig +from lerobot.configs.types import FeatureType def get_policy_class(name: str) -> Policy: @@ -100,11 +103,18 @@ def make_policy( kwargs = {} if ds_meta is not None: - cfg.parse_features_from_dataset(ds_meta) + features = dataset_to_policy_features(ds_meta.features) kwargs["dataset_stats"] = ds_meta.stats else: - cfg.parse_features_from_env(env, env_cfg) - + if not cfg.pretrained_path or not cfg.output_features or not cfg.input_features: + raise NotImplementedError( + "The policy must have already existing features in its config when initializing it " + "with an environment." + ) + features = env_to_policy_features(env_cfg, cfg) + + cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} kwargs["config"] = cfg if cfg.pretrained_path: diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index d9bba38401a..d8f021d9385 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -16,12 +16,12 @@ import torch from torch import Tensor, nn -from lerobot.configs.policies import PolicyFeature -from lerobot.configs.types import FeatureType, NormalizationMode +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature def create_stats_buffers( - features: list[PolicyFeature], + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], stats: dict[str, dict[str, Tensor]] | None = None, ) -> dict[str, dict[str, nn.ParameterDict]]: """ @@ -36,19 +36,20 @@ def create_stats_buffers( """ stats_buffers = {} - for ft in features: - if ft.normalization_mode is None: + for key, ft in features.items(): + norm_mode = norm_map.get(ft.type, None) + if norm_mode is None: continue - assert isinstance(ft.normalization_mode, NormalizationMode) + assert isinstance(norm_mode, NormalizationMode) shape = tuple(ft.shape) if ft.type is FeatureType.VISUAL: # sanity checks - assert len(shape) == 3, f"number of dimensions of {ft.key} != 3 ({shape=}" + assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}" c, h, w = shape - assert c < h and c < w, f"{ft.key} is not channel first ({shape=})" + assert c < h and c < w, f"{key} is not channel first ({shape=})" # override image shape to be invariant to height and width shape = (c, 1, 1) @@ -57,7 +58,7 @@ def create_stats_buffers( # we assert they are not infinity anymore. buffer = {} - if ft.normalization_mode is NormalizationMode.MEAN_STD: + if norm_mode is NormalizationMode.MEAN_STD: mean = torch.ones(shape, dtype=torch.float32) * torch.inf std = torch.ones(shape, dtype=torch.float32) * torch.inf buffer = nn.ParameterDict( @@ -66,7 +67,7 @@ def create_stats_buffers( "std": nn.Parameter(std, requires_grad=False), } ) - elif ft.normalization_mode is NormalizationMode.MIN_MAX: + elif norm_mode is NormalizationMode.MIN_MAX: min = torch.ones(shape, dtype=torch.float32) * torch.inf max = torch.ones(shape, dtype=torch.float32) * torch.inf buffer = nn.ParameterDict( @@ -81,14 +82,14 @@ def create_stats_buffers( # tensors anywhere (for example, when we use the same stats for normalization and # unnormalization). See the logic here # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. - if ft.normalization_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = stats[ft.key]["mean"].clone() - buffer["std"].data = stats[ft.key]["std"].clone() - elif ft.normalization_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = stats[ft.key]["min"].clone() - buffer["max"].data = stats[ft.key]["max"].clone() - - stats_buffers[ft.key] = buffer + if norm_mode is NormalizationMode.MEAN_STD: + buffer["mean"].data = stats[key]["mean"].clone() + buffer["std"].data = stats[key]["std"].clone() + elif norm_mode is NormalizationMode.MIN_MAX: + buffer["min"].data = stats[key]["min"].clone() + buffer["max"].data = stats[key]["max"].clone() + + stats_buffers[key] = buffer return stats_buffers @@ -104,7 +105,8 @@ class Normalize(nn.Module): def __init__( self, - features: list[PolicyFeature], + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], stats: dict[str, dict[str, Tensor]] | None = None, ): """ @@ -127,8 +129,9 @@ def __init__( """ super().__init__() self.features = features + self.norm_map = norm_map self.stats = stats - stats_buffers = create_stats_buffers(features, stats) + stats_buffers = create_stats_buffers(features, norm_map, stats) for key, buffer in stats_buffers.items(): setattr(self, "buffer_" + key.replace(".", "_"), buffer) @@ -136,29 +139,30 @@ def __init__( @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch = dict(batch) # shallow copy avoids mutating the input batch - for ft in self.features: - if ft.normalization_mode is None: + for key, ft in self.features.items(): + norm_mode = self.norm_map.get(ft.type, None) + if norm_mode is None: continue - buffer = getattr(self, "buffer_" + ft.key.replace(".", "_")) + buffer = getattr(self, "buffer_" + key.replace(".", "_")) - if ft.normalization_mode is NormalizationMode.MEAN_STD: + if norm_mode is NormalizationMode.MEAN_STD: mean = buffer["mean"] std = buffer["std"] assert not torch.isinf(mean).any(), _no_stats_error_str("mean") assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[ft.key] = (batch[ft.key] - mean) / (std + 1e-8) - elif ft.normalization_mode is NormalizationMode.MIN_MAX: + batch[key] = (batch[key] - mean) / (std + 1e-8) + elif norm_mode is NormalizationMode.MIN_MAX: min = buffer["min"] max = buffer["max"] assert not torch.isinf(min).any(), _no_stats_error_str("min") assert not torch.isinf(max).any(), _no_stats_error_str("max") # normalize to [0,1] - batch[ft.key] = (batch[ft.key] - min) / (max - min + 1e-8) + batch[key] = (batch[key] - min) / (max - min + 1e-8) # normalize to [-1, 1] - batch[ft.key] = batch[ft.key] * 2 - 1 + batch[key] = batch[key] * 2 - 1 else: - raise ValueError(ft.normalization_mode) + raise ValueError(norm_mode) return batch @@ -170,7 +174,8 @@ class Unnormalize(nn.Module): def __init__( self, - features: list[PolicyFeature], + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], stats: dict[str, dict[str, Tensor]] | None = None, ): """ @@ -193,9 +198,10 @@ def __init__( """ super().__init__() self.features = features + self.norm_map = norm_map self.stats = stats # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` - stats_buffers = create_stats_buffers(features, stats) + stats_buffers = create_stats_buffers(features, norm_map, stats) for key, buffer in stats_buffers.items(): setattr(self, "buffer_" + key.replace(".", "_"), buffer) @@ -203,22 +209,23 @@ def __init__( @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch = dict(batch) # shallow copy avoids mutating the input batch - for ft in self.features: - buffer = getattr(self, "buffer_" + ft.key.replace(".", "_")) + for key, ft in self.features.items(): + norm_mode = self.norm_map.get(ft.type, None) + buffer = getattr(self, "buffer_" + key.replace(".", "_")) - if ft.normalization_mode is NormalizationMode.MEAN_STD: + if norm_mode is NormalizationMode.MEAN_STD: mean = buffer["mean"] std = buffer["std"] assert not torch.isinf(mean).any(), _no_stats_error_str("mean") assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[ft.key] = batch[ft.key] * std + mean - elif ft.normalization_mode is NormalizationMode.MIN_MAX: + batch[key] = batch[key] * std + mean + elif norm_mode is NormalizationMode.MIN_MAX: min = buffer["min"] max = buffer["max"] assert not torch.isinf(min).any(), _no_stats_error_str("min") assert not torch.isinf(max).any(), _no_stats_error_str("max") - batch[ft.key] = (batch[ft.key] + 1) / 2 - batch[ft.key] = batch[ft.key] * (max - min) + min + batch[key] = (batch[key] + 1) / 2 + batch[key] = batch[key] * (max - min) + min else: - raise ValueError(ft.normalization_mode) + raise ValueError(norm_mode) return batch diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index e0757e13d72..0d2b046ff4f 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -201,7 +201,7 @@ def validate_features(self) -> None: ) if len(self.image_features) > 0: - image_ft = next(iter(self.image_features)) + image_ft = next(iter(self.image_features.values())) if image_ft.shape[-2] != image_ft.shape[-1]: # TODO(alexander-soare): This limitation is solely because of code in the random shift # augmentation. It should be able to be removed. diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index f7f4c7b96fd..5106b466733 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -36,6 +36,7 @@ from huggingface_hub import PyTorchModelHubMixin from torch import Tensor +from lerobot.common.constants import OBS_ENV, OBS_ROBOT from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues @@ -79,9 +80,13 @@ def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tenso config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, dataset_stats) - self.normalize_targets = Normalize(config.output_features, dataset_stats) - self.unnormalize_outputs = Unnormalize(config.output_features, dataset_stats) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) self.model = TDMPCTOLD(config) self.model_target = deepcopy(self.model) @@ -116,7 +121,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.image"] = batch[self.config.image_features[0].key] + batch["observation.image"] = batch[next(iter(self.config.image_features))] self._queues = populate_queues(self._queues, batch) @@ -312,7 +317,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.image"] = batch[self.config.image_features[0].key] + batch["observation.image"] = batch[next(iter(self.config.image_features))] batch = self.normalize_targets(batch) info = {} @@ -696,7 +701,12 @@ def __init__(self, config: TDMPCConfig): if config.image_features: self.image_enc_layers = nn.Sequential( - nn.Conv2d(config.image_features[0].shape[0], config.image_encoder_hidden_dim, 7, stride=2), + nn.Conv2d( + next(iter(config.image_features.values())).shape[0], + config.image_encoder_hidden_dim, + 7, + stride=2, + ), nn.ReLU(), nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), nn.ReLU(), @@ -705,7 +715,7 @@ def __init__(self, config: TDMPCConfig): nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), nn.ReLU(), ) - dummy_shape = (1, *config.image_features[0].shape) + dummy_shape = (1, *next(iter(config.image_features.values())).shape) out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:] self.image_enc_layers.extend( nn.Sequential( @@ -744,12 +754,14 @@ def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: # NOTE: Order of observations matters here. if self.config.image_features: feat.append( - flatten_forward_unflatten(self.image_enc_layers, obs_dict[self.config.image_features[0].key]) + flatten_forward_unflatten( + self.image_enc_layers, obs_dict[next(iter(self.config.image_features))] + ) ) if self.config.env_state_feature: - feat.append(self.env_state_enc_layers(obs_dict[self.config.env_state_feature.key])) + feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV])) if self.config.robot_state_feature: - feat.append(self.state_enc_layers(obs_dict[self.config.robot_state_feature.key])) + feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT])) return torch.stack(feat, dim=0).mean(0) diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py index 3f875729772..c2a3ca69609 100644 --- a/lerobot/common/policies/vqbet/configuration_vqbet.py +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -171,21 +171,20 @@ def validate_features(self) -> None: raise ValueError("You must provide only one image among the inputs.") if self.crop_shape is not None: - for image_ft in self.image_features: + for key, image_ft in self.image_features.items(): if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]: raise ValueError( f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} " f"for `crop_shape` and {image_ft.shape} for " - f"`{image_ft.key}`." + f"`{key}`." ) # Check that all input images have the same shape. - first_image_ft = next(iter(self.image_features)) - for image_ft in self.image_features: + first_image_key, first_image_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): if image_ft.shape != first_image_ft.shape: raise ValueError( - f"`{image_ft.key}` does not match `{first_image_ft.key}`, but we " - "expect all image shapes to match." + f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match." ) @property diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index aadbd29603c..7f19b19e08b 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -65,9 +65,13 @@ def __init__( config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, dataset_stats) - self.normalize_targets = Normalize(config.output_features, dataset_stats) - self.unnormalize_outputs = Unnormalize(config.output_features, dataset_stats) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) self.vqbet = VQBeTModel(config) @@ -135,9 +139,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack( - [batch[ft.key] for ft in self.config.image_features], dim=-4 - ) + batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) @@ -163,9 +165,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack( - [batch[ft.key] for ft in self.config.image_features], dim=-4 - ) + batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch = self.normalize_targets(batch) # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181) if not self.vqbet.action_head.vqvae_model.discretized.item(): @@ -703,10 +703,9 @@ def __init__(self, config: VQBeTConfig): # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the # height and width from `config.image_features`. - dummy_shape_h_w = ( - config.crop_shape if config.crop_shape is not None else config.image_features[0].shape[1:] - ) - dummy_shape = (1, config.image_features[0].shape[0], *dummy_shape_h_w) + images_shape = next(iter(config.image_features.values())).shape + dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:] + dummy_shape = (1, images_shape[0], *dummy_shape_h_w) feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:] self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 45af5907f68..cbd5e8b3915 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -19,6 +19,7 @@ import platform import random from contextlib import contextmanager +from copy import copy from datetime import datetime, timezone from pathlib import Path from typing import Any, Generator @@ -199,3 +200,13 @@ def log_say(text, play_sounds, blocking=False): if play_sounds: say(text, blocking) + + +def get_channel_first_image_shape(image_shape: tuple) -> tuple: + shape = copy(image_shape) + if shape[2] < shape[0] and shape[2] < shape[1]: # (h, w, c) -> (c, h, w) + shape = (shape[2], shape[0], shape[1]) + elif not (shape[0] < shape[1] and shape[0] < shape[2]): + raise ValueError(image_shape) + + return shape diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py index 90bfa008758..abdcaec48b3 100644 --- a/lerobot/configs/policies.py +++ b/lerobot/configs/policies.py @@ -3,34 +3,21 @@ from copy import copy from dataclasses import dataclass, field from pathlib import Path -from pprint import pformat from typing import Type, TypeVar import draccus -import gymnasium as gym from huggingface_hub import ModelHubMixin, hf_hub_download from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.errors import HfHubHTTPError -from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata -from lerobot.common.datasets.utils import flatten_dict, get_nested_item -from lerobot.common.envs.configs import EnvConfig from lerobot.common.optim.optimizers import OptimizerConfig from lerobot.common.optim.schedulers import LRSchedulerConfig -from lerobot.configs.types import FeatureType, NormalizationMode +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -# Generic variable that is either ModelHubMixin or a subclass thereof +# Generic variable that is either PretrainedConfig or a subclass thereof T = TypeVar("T", bound="PretrainedConfig") -@dataclass -class PolicyFeature: - key: str - type: FeatureType - shape: list | tuple - normalization_mode: NormalizationMode - - @dataclass class PretrainedConfig(draccus.ChoiceRegistry, ModelHubMixin, abc.ABC): """ @@ -52,6 +39,9 @@ class PretrainedConfig(draccus.ChoiceRegistry, ModelHubMixin, abc.ABC): n_obs_steps: int = 1 normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict) + input_features: dict[str, PolicyFeature] = field(default_factory=dict) + output_features: dict[str, PolicyFeature] = field(default_factory=dict) + def __post_init__(self): self.type = self.get_choice_name(self.__class__) self.pretrained_path = None @@ -81,17 +71,29 @@ def validate_features(self) -> None: raise NotImplementedError @property - def input_features(self) -> list[PolicyFeature]: - input_features = [] - for ft in [self.robot_state_feature, self.env_state_feature, *self.image_features]: - if ft is not None: - input_features.append(ft) + def robot_state_feature(self) -> PolicyFeature | None: + for _, ft in self.input_features.items(): + if ft.type is FeatureType.STATE: + return ft + return None - return input_features + @property + def env_state_feature(self) -> PolicyFeature | None: + for _, ft in self.input_features.items(): + if ft.type is FeatureType.ENV: + return ft + return None + + @property + def image_features(self) -> dict[str, PolicyFeature]: + return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL} @property - def output_features(self) -> list[PolicyFeature]: - return [self.action_feature] + def action_feature(self) -> PolicyFeature | None: + for _, ft in self.output_features.items(): + if ft.type is FeatureType.ACTION: + return ft + return None def _save_pretrained(self, save_directory: Path) -> None: to_save = copy(self) @@ -141,135 +143,3 @@ def from_pretrained( cli_overrides = model_kwargs.pop("cli_overrides", []) instance = draccus.parse(cls, config_file, args=[]) return draccus.parse(instance.__class__, config_file, args=cli_overrides) - - def parse_features_from_dataset(self, ds_meta: LeRobotDatasetMetadata): - # TODO(aliberts): Implement PolicyFeature in LeRobotDataset and remove the need for this - robot_state_features = [] - env_state_features = [] - action_features = [] - image_features = [] - - for key in ds_meta.features: - if key in ds_meta.camera_keys: - shape = ds_meta.features[key]["shape"] - names = ds_meta.features[key]["names"] - if len(shape) != 3: - raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") - # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. - if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) - shape = (shape[2], shape[0], shape[1]) - image_features.append( - PolicyFeature( - key=key, - type=FeatureType.VISUAL, - shape=shape, - normalization_mode=self.normalization_mapping[FeatureType.VISUAL], - ) - ) - elif key == "observation.environment_state": - env_state_features.append( - PolicyFeature( - key=key, - type=FeatureType.ENV, - shape=ds_meta.features[key]["shape"], - normalization_mode=self.normalization_mapping[FeatureType.ENV], - ) - ) - elif key.startswith("observation"): - robot_state_features.append( - PolicyFeature( - key=key, - type=FeatureType.STATE, - shape=ds_meta.features[key]["shape"], - normalization_mode=self.normalization_mapping[FeatureType.STATE], - ) - ) - elif key == "action": - action_features.append( - PolicyFeature( - key=key, - type=FeatureType.ACTION, - shape=ds_meta.features[key]["shape"], - normalization_mode=self.normalization_mapping[FeatureType.ACTION], - ) - ) - - if len(robot_state_features) > 1: - raise ValueError( - "Found multiple features for the robot's state. Please select only one or concatenate them." - f"Robot state features found:\n{pformat(robot_state_features)}" - ) - - if len(env_state_features) > 1: - raise ValueError( - "Found multiple features for the env's state. Please select only one or concatenate them." - f"Env state features found:\n{pformat(env_state_features)}" - ) - - if len(action_features) > 1: - raise ValueError( - "Found multiple features for the action. Please select only one or concatenate them." - f"Action features found:\n{pformat(action_features)}" - ) - - self.robot_state_feature = robot_state_features[0] if len(robot_state_features) == 1 else None - self.env_state_feature = env_state_features[0] if len(env_state_features) == 1 else None - self.action_feature = action_features[0] if len(action_features) == 1 else None - self.image_features = image_features - - def parse_features_from_env(self, env: gym.Env, env_cfg: EnvConfig): - robot_state_features = [] - env_state_features = [] - action_features = [] - image_features = [] - - flat_dict = flatten_dict(env_cfg.feature_types) - - for key, _type in flat_dict.items(): - env_ft = ( - env.action_space - if _type is FeatureType.ACTION - else get_nested_item(env.observation_space, key) - ) - shape = env_ft.shape[1:] - if _type is FeatureType.VISUAL: - h, w, c = shape - if not c < h and c < w: - raise ValueError( - f"Expect channel last images for visual feature {key} of {env_cfg.type} env, but instead got {shape=}" - ) - shape = (c, h, w) - - feature = PolicyFeature( - key=key, - type=_type, - shape=shape, - normalization_mode=self.normalization_mapping[_type], - ) - if _type is FeatureType.VISUAL: - image_features.append(feature) - elif _type is FeatureType.STATE: - robot_state_features.append(feature) - elif _type is FeatureType.ENV: - env_state_features.append(feature) - elif _type is FeatureType.ACTION: - action_features.append(feature) - - # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is - # (need to also refactor preprocess_observation and externalize normalization from policies) - for ft in image_features: - if len(ft.key.split("/")) > 1: - ft.key = f"observation.images.{ft.key.split('/')[-1]}" - elif len(ft.key.split("/")) == 1: - image_features[0].key = "observation.image" - - if len(robot_state_features) == 1: - robot_state_features[0].key = "observation.state" - - if len(env_state_features) == 1: - env_state_features[0].key = "observation.environment_state" - - self.robot_state_feature = robot_state_features[0] if len(robot_state_features) == 1 else None - self.env_state_feature = env_state_features[0] if len(env_state_features) == 1 else None - self.action_feature = action_features[0] if len(action_features) == 1 else None - self.image_features = image_features diff --git a/lerobot/configs/types.py b/lerobot/configs/types.py index a5f6ac4fba2..f31f437b1e3 100644 --- a/lerobot/configs/types.py +++ b/lerobot/configs/types.py @@ -1,5 +1,6 @@ # Note: We subclass str so that serialization is straightforward # https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json +from dataclasses import dataclass from enum import Enum from typing import Any, Protocol @@ -18,3 +19,9 @@ class NormalizationMode(str, Enum): class DictLike(Protocol): def __getitem__(self, key: Any) -> Any: ... + + +@dataclass +class PolicyFeature: + type: FeatureType + shape: tuple diff --git a/tests/test_policies.py b/tests/test_policies.py index bfb82d024d2..4d05af93c39 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -25,7 +25,7 @@ from lerobot import available_policies from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.utils import cycle +from lerobot.common.datasets.utils import cycle, dataset_to_policy_features from lerobot.common.envs.factory import make_env, make_env_config from lerobot.common.envs.utils import preprocess_observation from lerobot.common.optim.factory import make_optimizer_and_scheduler @@ -39,9 +39,8 @@ from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import seeded_context from lerobot.configs.default import DatasetConfig -from lerobot.configs.policies import PolicyFeature from lerobot.configs.training import TrainPipelineConfig -from lerobot.configs.types import FeatureType, NormalizationMode +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from tests.scripts.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel @@ -240,7 +239,11 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str): """Check that the policy can be instantiated with defaults.""" policy_cls = get_policy_class(policy_name) policy_cfg = make_policy_config(policy_name) - policy_cfg.parse_features_from_dataset(dummy_dataset_metadata) + features = dataset_to_policy_features(dummy_dataset_metadata.features) + policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + policy_cfg.input_features = { + key: ft for key, ft in features.items() if key not in policy_cfg.output_features + } policy_cls(policy_cfg) @@ -248,7 +251,11 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str): def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str): policy_cls = get_policy_class(policy_name) policy_cfg = make_policy_config(policy_name) - policy_cfg.parse_features_from_dataset(dummy_dataset_metadata) + features = dataset_to_policy_features(dummy_dataset_metadata.features) + policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + policy_cfg.input_features = { + key: ft for key, ft in features.items() if key not in policy_cfg.output_features + } policy = policy_cls(policy_cfg) save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}" policy.save_pretrained(save_dir) @@ -266,28 +273,28 @@ def test_normalize(insert_temporal_dim): expected. """ - input_features = [ - PolicyFeature( - key="observation.image", + input_features = { + "observation.image": PolicyFeature( type=FeatureType.VISUAL, - normalization_mode=NormalizationMode.MEAN_STD, - shape=[3, 96, 96], + shape=(3, 96, 96), ), - PolicyFeature( - key="observation.state", + "observation.state": PolicyFeature( type=FeatureType.STATE, - normalization_mode=NormalizationMode.MIN_MAX, - shape=[10], + shape=(10,), ), - ] - output_features = [ - PolicyFeature( - key="action", + } + output_features = { + "action": PolicyFeature( type=FeatureType.ACTION, - normalization_mode=NormalizationMode.MIN_MAX, - shape=[5], + shape=(5,), ), - ] + } + + norm_map = { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } dataset_stats = { "observation.image": { @@ -330,30 +337,30 @@ def test_normalize(insert_temporal_dim): output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1) # test without stats - normalize = Normalize(input_features, stats=None) + normalize = Normalize(input_features, norm_map, stats=None) with pytest.raises(AssertionError): normalize(input_batch) # test with stats - normalize = Normalize(input_features, stats=dataset_stats) + normalize = Normalize(input_features, norm_map, stats=dataset_stats) normalize(input_batch) # test loading pretrained models - new_normalize = Normalize(input_features, stats=None) + new_normalize = Normalize(input_features, norm_map, stats=None) new_normalize.load_state_dict(normalize.state_dict()) new_normalize(input_batch) # test without stats - unnormalize = Unnormalize(output_features, stats=None) + unnormalize = Unnormalize(output_features, norm_map, stats=None) with pytest.raises(AssertionError): unnormalize(output_batch) # test with stats - unnormalize = Unnormalize(output_features, stats=dataset_stats) + unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats) unnormalize(output_batch) # test loading pretrained models - new_unnormalize = Unnormalize(output_features, stats=None) + new_unnormalize = Unnormalize(output_features, norm_map, stats=None) new_unnormalize.load_state_dict(unnormalize.state_dict()) unnormalize(output_batch) @@ -487,7 +494,3 @@ def test_act_temporal_ensembler(): assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max")) # Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error. assert torch.allclose(online_avg, offline_avg, atol=1e-4) - - -if __name__ == "__main__": - test_act_temporal_ensembler() From 34104dcf4dee7aec806e639d4db5a1769ddca437 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 16 Jan 2025 17:43:32 +0100 Subject: [PATCH 2/2] Fix --- lerobot/common/envs/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 06abee3f896..30bbaf39688 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -77,7 +77,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: if len(ft.shape) != 3: raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})") - shape = get_channel_first_image_shape(ft.shape, raise_if_not_channel_first=True) + shape = get_channel_first_image_shape(ft.shape) feature = PolicyFeature(type=ft.type, shape=shape) else: feature = ft