@@ -119,9 +119,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]
 
         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
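A minimal sketch of why the list form replaces torch.stack (camera keys and shapes are made up for illustration): per-camera tensors no longer need matching spatial sizes.

import torch

# Hypothetical batch with two cameras at different resolutions.
batch = {
    "observation.images.top": torch.randn(2, 3, 480, 640),
    "observation.images.wrist": torch.randn(2, 3, 240, 320),
}
image_keys = ["observation.images.top", "observation.images.wrist"]

images = [batch[key] for key in image_keys]  # works for any mix of resolutions
# torch.stack(images, dim=-4)               # the old path would raise here: shapes differ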
@@ -149,9 +147,8 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
@@ -413,11 +410,10 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
                 "actions must be provided when using the variational objective in training mode."
             )
 
-        batch_size = (
-            batch["observation.images"]
-            if "observation.images" in batch
-            else batch["observation.environment_state"]
-        ).shape[0]
+        if "observation.images" in batch:
+            batch_size = batch["observation.images"][0].shape[0]
+        else:
+            batch_size = batch["observation.environment_state"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
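Since "observation.images" is now a plain list, the batch size is read from the first camera tensor. A quick sketch (shapes invented) showing that every element shares the same leading batch dimension:

import torch

images = [torch.randn(8, 3, 480, 640), torch.randn(8, 3, 240, 320)]  # assumed per-camera tensors
batch_size = images[0].shape[0]
assert all(img.shape[0] == batch_size for img in images)
print(batch_size)  # 8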
@@ -490,20 +486,21 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
             all_cam_features = []
             all_cam_pos_embeds = []
 
-            for cam_index in range(batch["observation.images"].shape[-4]):
-                cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
-                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward; precompute and use
-                # buffer
+            # For a list of images, the H and W may vary but H*W is constant.
+            for img in batch["observation.images"]:
+                cam_features = self.backbone(img)["feature_map"]
                 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
-                cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
+                cam_features = self.encoder_img_feat_input_proj(cam_features)
+
+                # Rearrange features to (sequence, batch, dim).
+                cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
+                cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
+
                 all_cam_features.append(cam_features)
                 all_cam_pos_embeds.append(cam_pos_embed)
-            # Concatenate camera observation feature maps and positional embeddings along the width dimension,
-            # and move to (sequence, batch, dim).
-            all_cam_features = torch.cat(all_cam_features, axis=-1)
-            encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
-            all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
-            encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
+
+            encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
+            encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
 
         # Stack all tokens along the sequence dimension.
         encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
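To illustrate why per-camera concatenation along the sequence dimension works, here is a small sketch (feature-map sizes are made up): each camera's (B, C, h, w) feature map becomes h*w tokens, and tokens from cameras with different h and w concatenate cleanly because they share the channel dimension C.

import einops
import torch

feat_a = torch.randn(2, 512, 15, 20)  # camera A backbone features: (B, C, h, w)
feat_b = torch.randn(2, 512, 10, 30)  # camera B: different h, w; same C and same h*w here
tokens = torch.cat(
    [einops.rearrange(f, "b c h w -> (h w) b c") for f in (feat_a, feat_b)], dim=0
)
print(tokens.shape)  # torch.Size([600, 2, 512]) = (15*20 + 10*30) tokens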