lerobot/common/constants.py: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
# keys
OBS_ENV = "observation.environment_state"
OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"
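
As a usage sketch (not part of the diff), these constants give policies and envs one canonical set of feature keys; the shapes below are hypothetical.

import torch

from lerobot.common.constants import ACTION, OBS_IMAGES, OBS_ROBOT

camera_key = f"{OBS_IMAGES}.top"  # per-camera keys nest under "observation.images"
batch = {
    OBS_ROBOT: torch.zeros(14),  # proprioceptive state ("observation.state")
    camera_key: torch.zeros(3, 480, 640),  # channel-first image
    ACTION: torch.zeros(14),
}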
lerobot/common/datasets/utils.py: 32 additions & 1 deletion
@@ -35,7 +35,7 @@
from torchvision import transforms

from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.configs.types import DictLike
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature

DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk

@@ -302,6 +302,37 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}


def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
# TODO(aliberts): Implement "type" in dataset features and simplify this
policy_features = {}
for key, ft in features.items():
shape = ft["shape"]
if ft["dtype"] in ["image", "video"]:
type = FeatureType.VISUAL
if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")

names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == "observation.environment_state":
type = FeatureType.ENV
elif key.startswith("observation"):
type = FeatureType.STATE
elif key == "action":
type = FeatureType.ACTION
else:
continue

policy_features[key] = PolicyFeature(
type=type,
shape=shape,
)

return policy_features


def create_empty_dataset_info(
codebase_version: str,
fps: int,
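A sketch of the new helper on a LeRobotDataset-style features dict; the entries are illustrative, including a ported dataset that still uses the legacy "channel" name.

from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.configs.types import FeatureType

features = {
    "observation.state": {"dtype": "float32", "shape": (14,), "names": None},
    "observation.images.top": {
        "dtype": "video",
        "shape": (480, 640, 3),
        "names": ["height", "width", "channel"],  # legacy name triggers (h, w, c) -> (c, h, w)
    },
    "action": {"dtype": "float32", "shape": (14,), "names": None},
}
policy_features = dataset_to_policy_features(features)
assert policy_features["observation.images.top"].shape == (3, 480, 640)
assert policy_features["observation.state"].type is FeatureType.STATE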
lerobot/common/envs/configs.py: 58 additions & 30 deletions
@@ -3,15 +3,17 @@

import draccus

from lerobot.configs.types import FeatureType
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
from lerobot.configs.types import FeatureType, PolicyFeature


@dataclass
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
n_envs: int | None = None
task: str | None = None
fps: int = 30
feature_types: dict = field(default_factory=dict)
features: dict[str, PolicyFeature] = field(default_factory=dict)
features_map: dict[str, str] = field(default_factory=dict)

@property
def type(self) -> str:
@@ -28,17 +30,28 @@ class AlohaEnv(EnvConfig):
task: str = "AlohaInsertion-v0"
fps: int = 50
episode_length: int = 400
feature_types: dict = field(
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"agent_pos": FeatureType.STATE,
"pixels": {
"top": FeatureType.VISUAL,
},
"action": FeatureType.ACTION,
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_ROBOT,
"top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top",
}
)
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"

def __post_init__(self):
if self.obs_type == "pixels":
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))

@property
def gym_kwargs(self) -> dict:
@@ -55,25 +68,30 @@ class PushtEnv(EnvConfig):
task: str = "PushT-v0"
fps: int = 10
episode_length: int = 300
feature_types: dict = field(
default_factory=lambda: {
"agent_pos": FeatureType.STATE,
"pixels": FeatureType.VISUAL,
"action": FeatureType.ACTION,
}
)
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
visualization_width: int = 384
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_ROBOT,
"environment_state": OBS_ENV,
"pixels": OBS_IMAGE,
}
)

def __post_init__(self):
if self.obs_type == "environment_state_agent_pos":
self.feature_types = {
"agent_pos": FeatureType.STATE,
"environment_state": FeatureType.ENV,
"action": FeatureType.ACTION,
}
if self.obs_type == "pixels_agent_pos":
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
elif self.obs_type == "environment_state_agent_pos":
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))

@property
def gym_kwargs(self) -> dict:
@@ -91,17 +109,27 @@ class XarmEnv(EnvConfig):
task: str = "XarmLift-v0"
fps: int = 15
episode_length: int = 200
feature_types: dict = field(
default_factory=lambda: {
"agent_pos": FeatureType.STATE,
"pixels": FeatureType.VISUAL,
"action": FeatureType.ACTION,
}
)
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
visualization_width: int = 384
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_ROBOT,
"pixels": OBS_IMAGE,
}
)

def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))

@property
def gym_kwargs(self) -> dict:
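A sketch of how these configs resolve features: the static defaults declare what is always present, and __post_init__ fills in the observation features implied by obs_type. The values follow the defaults above.

from lerobot.common.envs.configs import AlohaEnv

env_cfg = AlohaEnv()  # obs_type defaults to "pixels_agent_pos"
assert env_cfg.features["action"].shape == (14,)
assert env_cfg.features["agent_pos"].shape == (14,)  # added by __post_init__
assert env_cfg.features["pixels/top"].shape == (480, 640, 3)  # added by __post_init__
assert env_cfg.features_map["pixels/top"] == "observation.images.top"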
lerobot/common/envs/utils.py: 25 additions & 0 deletions
@@ -18,6 +18,10 @@
import torch
from torch import Tensor

from lerobot.common.envs.configs import EnvConfig
from lerobot.common.utils.utils import get_channel_first_image_shape
from lerobot.configs.types import FeatureType, PolicyFeature


def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
@@ -36,6 +40,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
imgs = {"observation.image": observations["pixels"]}

for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img)

# sanity check that images are channel last
@@ -61,3 +66,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return return_observations


def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies)
policy_features = {}
for key, ft in env_cfg.features.items():
if ft.type is FeatureType.VISUAL:
if len(ft.shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")

shape = get_channel_first_image_shape(ft.shape)
feature = PolicyFeature(type=ft.type, shape=shape)
else:
feature = ft

policy_key = env_cfg.features_map[key]
policy_features[policy_key] = feature

return policy_features
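
Continuing the sketch, the new helper remaps env feature keys through features_map and converts visual shapes to channel-first:

from lerobot.common.envs.configs import PushtEnv
from lerobot.common.envs.utils import env_to_policy_features

env_cfg = PushtEnv()  # default obs_type adds a (384, 384, 3) "pixels" feature
policy_features = env_to_policy_features(env_cfg)
assert policy_features["observation.image"].shape == (3, 384, 384)  # channel-first
assert policy_features["observation.state"].shape == (2,)
assert policy_features["action"].shape == (2,)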
lerobot/common/policies/act/modeling_act.py: 10 additions & 6 deletions
@@ -68,9 +68,13 @@ def __init__(
config.validate_features()
self.config = config

self.normalize_inputs = Normalize(config.input_features, dataset_stats)
self.normalize_targets = Normalize(config.output_features, dataset_stats)
self.unnormalize_outputs = Unnormalize(config.output_features, dataset_stats)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)

self.model = ACT(config)

@@ -121,7 +125,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[ft.key] for ft in self.config.image_features], dim=-4
[batch[key] for key in self.config.image_features], dim=-4
)

# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
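
Since config.image_features is now a dict, iterating it yields keys that index the batch directly. A toy illustration of the stacking, with hypothetical camera keys:

import torch

image_keys = ["observation.images.top", "observation.images.wrist"]  # hypothetical
batch = {k: torch.zeros(2, 3, 480, 640) for k in image_keys}
stacked = torch.stack([batch[k] for k in image_keys], dim=-4)
assert stacked.shape == (2, 2, 3, 480, 640)  # (batch, n_cameras, c, h, w)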
@@ -151,7 +155,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[ft.key] for ft in self.config.image_features], dim=-4
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -411,7 +415,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor]]:
"""
if self.config.use_vae and self.training:
assert (
self.config.action_feature.key in batch
"action" in batch
), "actions must be provided when using the variational objective in training mode."

batch_size = (
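The normalization modules now take the config's normalization_mapping explicitly instead of reading a mode from each feature. A hypothetical construction, assuming NormalizationMode is exposed from lerobot.configs.types in this refactor:

from lerobot.common.policies.normalize import Normalize
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

input_features = {"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}  # one mode per FeatureType
normalize_inputs = Normalize(input_features, norm_map, None)  # stats can also come from a checkpoint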
lerobot/common/policies/diffusion/configuration_diffusion.py: 5 additions & 6 deletions
@@ -208,21 +208,20 @@ def validate_features(self) -> None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")

if self.crop_shape is not None:
for image_ft in self.image_features:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{image_ft.key}`."
f"`{key}`."
)

# Check that all input images have the same shape.
first_image_ft = next(iter(self.image_features))
for image_ft in self.image_features:
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{image_ft.key}` does not match `{first_image_ft.key}`, but we "
"expect all image shapes to match."
f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
)

@property
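With image_features now a key-to-PolicyFeature mapping, both checks iterate .items(). A compact restatement of the two invariants, with hypothetical keys:

from lerobot.configs.types import FeatureType, PolicyFeature

image_features = {
    "observation.images.top": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 480, 640)),
    "observation.images.wrist": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 480, 640)),
}
crop_shape = (440, 560)  # must fit within every image's (h, w)
first_key, first_ft = next(iter(image_features.items()))
for key, ft in image_features.items():
    assert crop_shape[0] <= ft.shape[1] and crop_shape[1] <= ft.shape[2], key
    assert ft.shape == first_ft.shape, f"{key} must match {first_key}"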
lerobot/common/policies/diffusion/modeling_diffusion.py: 16 additions & 12 deletions
@@ -34,6 +34,7 @@
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn

from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import (
@@ -74,9 +75,13 @@ def __init__(
config.validate_features()
self.config = config

self.normalize_inputs = Normalize(config.input_features, dataset_stats)
self.normalize_targets = Normalize(config.output_features, dataset_stats)
self.unnormalize_outputs = Unnormalize(config.output_features, dataset_stats)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)

# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -125,7 +130,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[ft.key] for ft in self.config.image_features], dim=-4
[batch[key] for key in self.config.image_features], dim=-4
)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -149,7 +154,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[ft.key] for ft in self.config.image_features], dim=-4
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
@@ -237,8 +242,8 @@ def conditional_sample(

def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch[self.config.robot_state_feature.key].shape[:2]
global_cond_feats = [batch[self.config.robot_state_feature.key]]
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
global_cond_feats = [batch[OBS_ROBOT]]
# Extract image features.
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
@@ -268,7 +273,7 @@ def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
global_cond_feats.append(img_features)

if self.config.env_state_feature:
global_cond_feats.append(batch[self.config.env_state_feature.key])
global_cond_feats.append(batch[OBS_ENV])

# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
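
The conditioning path now indexes the batch with the shared OBS_ROBOT / OBS_ENV constants rather than per-config feature keys. A minimal shape walk-through with assumed sizes:

import torch

from lerobot.common.constants import OBS_ENV, OBS_ROBOT

batch = {OBS_ROBOT: torch.zeros(8, 2, 14), OBS_ENV: torch.zeros(8, 2, 16)}  # (B, n_obs_steps, dim)
global_cond = torch.cat([batch[OBS_ROBOT], batch[OBS_ENV]], dim=-1).flatten(start_dim=1)
assert global_cond.shape == (8, 2 * (14 + 16))  # (B, global_cond_dim)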
@@ -482,10 +487,9 @@ def __init__(self, config: DiffusionConfig):
# height and width from `config.image_features`.

# Note: we have a check in the config class to make sure all images have the same shape.
dummy_shape_h_w = (
config.crop_shape if config.crop_shape is not None else config.image_features[0].shape[1:]
)
dummy_shape = (1, config.image_features[0].shape[0], *dummy_shape_h_w)
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
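
A worked example of the dummy-shape computation above, with assumed values:

# images_shape is channel-first, e.g. (3, 96, 96); crop_shape is an optional (h, w).
images_shape = (3, 96, 96)
crop_shape = (84, 84)
dummy_shape_h_w = crop_shape if crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
assert dummy_shape == (1, 3, 84, 84)  # (batch, c, h, w) probed through the backbone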