81 changes: 56 additions & 25 deletions lerobot/common/policies/act/modeling_act.py
@@ -119,9 +119,12 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            images = [batch[key] for key in self.config.image_features]
+
+            if all(img.shape == images[0].shape for img in images):
+                batch["observation.images"] = torch.stack(images, dim=-4)
+            else:
+                batch["observation.images"] = images
 
         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
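For context, a minimal sketch (outside the diff) of what this preprocessing rule does; the same check is applied in both select_action and forward, and the camera shapes below are purely illustrative, not taken from the PR:

import torch

# Two hypothetical cameras with different resolutions, shaped (batch, channels, height, width).
images = [
    torch.zeros(2, 3, 96, 96),
    torch.zeros(2, 3, 480, 640),
]

# Same rule as the change above: stack into a single tensor only when every camera shares a
# shape, otherwise keep a plain list and let the model handle each camera separately.
if all(img.shape == images[0].shape for img in images):
    observation_images = torch.stack(images, dim=-4)  # (batch, num_cams, channels, height, width)
else:
    observation_images = images  # list of per-camera tensors

print(type(observation_images))  # <class 'list'> here, since the shapes differ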
@@ -149,9 +152,13 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            images = [batch[key] for key in self.config.image_features]
+
+            if all(img.shape == images[0].shape for img in images):
+                batch["observation.images"] = torch.stack(images, dim=-4)
+            else:
+                batch["observation.images"] = images
+
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
@@ -413,11 +420,13 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
                 "actions must be provided when using the variational objective in training mode."
             )
 
-        batch_size = (
-            batch["observation.images"]
-            if "observation.images" in batch
-            else batch["observation.environment_state"]
-        ).shape[0]
+        if "observation.images" in batch:
+            if isinstance(batch["observation.images"], list):
+                batch_size = batch["observation.images"][0].shape[0]
+            else:
+                batch_size = batch["observation.images"].shape[0]
+        else:
+            batch_size = batch["observation.environment_state"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
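For reference, a small sketch of this batch-size lookup under both representations; the tensors below are placeholders, not data from the PR:

import torch

batch_stacked = {"observation.images": torch.zeros(8, 2, 3, 96, 96)}  # (batch, num_cams, C, H, W)
batch_listed = {"observation.images": [torch.zeros(8, 3, 96, 96), torch.zeros(8, 3, 480, 640)]}

for batch in (batch_stacked, batch_listed):
    imgs = batch["observation.images"]
    # A list holds one tensor per camera, so the batch dimension is read from the first entry.
    batch_size = imgs[0].shape[0] if isinstance(imgs, list) else imgs.shape[0]
    print(batch_size)  # 8 in both cases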
@@ -490,20 +499,42 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
             all_cam_features = []
             all_cam_pos_embeds = []
 
-            for cam_index in range(batch["observation.images"].shape[-4]):
-                cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
-                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
-                # buffer
-                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
-                cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
-                all_cam_features.append(cam_features)
-                all_cam_pos_embeds.append(cam_pos_embed)
-            # Concatenate camera observation feature maps and positional embeddings along the width dimension,
-            # and move to (sequence, batch, dim).
-            all_cam_features = torch.cat(all_cam_features, axis=-1)
-            encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
-            all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
-            encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
+            if isinstance(batch["observation.images"], list):
+                all_tokens = []
+                all_pos_tokens = []
+
+                # For a list of images, the H and W may vary but H*W is constant.
+                for img in batch["observation.images"]:
+                    cam_features = self.backbone(img)["feature_map"]
+                    cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
+                    cam_features = self.encoder_img_feat_input_proj(cam_features)
+
+                    # Rearrange features to (sequence, batch, dim).
+                    tokens = einops.rearrange(cam_features, "b c h w -> (h w) b c")
+                    pos_tokens = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
+
+                    all_tokens.append(tokens)
+                    all_pos_tokens.append(pos_tokens)
+
+                encoder_in_tokens.extend(torch.cat(all_tokens, axis=0))
+                encoder_in_pos_embed.extend(torch.cat(all_pos_tokens, axis=0))
+
+            else:
+                for cam_index in range(batch["observation.images"].shape[-4]):
+                    cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
+                    # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
+                    # buffer
+                    cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
+                    cam_features = self.encoder_img_feat_input_proj(cam_features)
+                    all_cam_features.append(cam_features)
+                    all_cam_pos_embeds.append(cam_pos_embed)
+
+                # Concatenate camera observation feature maps and positional embeddings along the width dimension,
+                # and move to (sequence, batch, dim).
+                all_cam_features = torch.cat(all_cam_features, axis=-1)
+                all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
+                encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
+                encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
 
         # Stack all tokens along the sequence dimension.
         encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
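To see what the new list branch produces, here is a self-contained sketch of the token path, with a stride-32 convolution standing in for the ResNet feature-map backbone and a 1x1 convolution standing in for encoder_img_feat_input_proj; the modules, channel counts, and image sizes are illustrative assumptions, not code from the PR:

import torch
import einops

backbone = torch.nn.Conv2d(3, 512, kernel_size=32, stride=32)  # stand-in: downsamples by 32
input_proj = torch.nn.Conv2d(512, 256, kernel_size=1)          # stand-in: projects to model dim

images = [torch.zeros(2, 3, 96, 96), torch.zeros(2, 3, 128, 160)]  # two cameras, different sizes

all_tokens = []
for img in images:
    feat = input_proj(backbone(img))  # (batch, 256, h, w) with h = H // 32, w = W // 32
    all_tokens.append(einops.rearrange(feat, "b c h w -> (h w) b c"))

# Per-camera token sequences are concatenated along the sequence dimension: 3*3 = 9 tokens from
# the first camera and 4*5 = 20 from the second (the diff's comment notes that in its setting the
# per-camera token count stays constant; the concatenation itself does not require that).
tokens = torch.cat(all_tokens, dim=0)
print(tokens.shape)  # torch.Size([29, 2, 256])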