Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/lerobot/policies/pi0/modeling_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
[Jax code](https://github.com/Physical-Intelligence/openpi)

Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Disclaimer: It is not expected to perform as well as the original implementation.

Install pi0 extra dependencies:
```bash
Expand Down Expand Up @@ -260,6 +261,16 @@ def reset(self):
def get_optim_params(self) -> dict:
return self.parameters()

@classmethod
def from_pretrained(cls, *args, **kwargs):
"""Override the from_pretrained method to display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n"
" It is not expected to perform as well as the original implementation. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
return super().from_pretrained(*args, **kwargs)

@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
Expand Down
11 changes: 11 additions & 0 deletions src/lerobot/policies/pi0fast/modeling_pi0fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
[Jax code](https://github.com/Physical-Intelligence/openpi)

Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Disclaimer: It is not expected to perform as well as the original implementation.

Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
```bash
Expand Down Expand Up @@ -162,6 +163,16 @@ def reset(self):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)

@classmethod
def from_pretrained(cls, *args, **kwargs):
"""Override the from_pretrained method to display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n"
" It is not expected to perform as well as the original implementation. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
return super().from_pretrained(*args, **kwargs)

def get_optim_params(self) -> dict:
return self.parameters()

Expand Down
Loading