Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/lerobot/policies/pi0fast/modeling_pi0fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,14 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()

def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss_dict = self.model.forward(batch)
return loss_dict["loss"], loss_dict
loss, loss_dict = self.model.forward(batch)
return loss, loss_dict


def block_causal_update_causal_mask(
Expand Down Expand Up @@ -745,8 +745,8 @@ def forward(self, batch: dict[str, Tensor]):
loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)

# Return loss dictionary
loss_dict = {"ce_loss": loss.item(), "loss": loss}
return loss_dict
loss_dict = {"ce_loss": loss.item(), "loss": loss.item()}
return loss, loss_dict

def decode_actions_with_fast(
self,
Expand Down
Loading