diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index 0e53bd3497..23dc339572 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -234,14 +234,14 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) - loss_dict = self.model.forward(batch) - return loss_dict["loss"], loss_dict + loss, loss_dict = self.model.forward(batch) + return loss, loss_dict def block_causal_update_causal_mask( @@ -745,8 +745,8 @@ def forward(self, batch: dict[str, Tensor]): loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1) # Return loss dictionary - loss_dict = {"ce_loss": loss.item(), "loss": loss} - return loss_dict + loss_dict = {"ce_loss": loss.item(), "loss": loss.item()} + return loss, loss_dict def decode_actions_with_fast( self,