fix pi0 prepare_language will raise an error if the task is a string #1625
Pull Request Overview
This PR fixes a bug in the Pi0 model where string-based task inputs during inference cause tokenization errors. The fix ensures string tasks are properly wrapped in a list and replicated to match the batch size.
- Adds type checking to handle single string task inputs
- Implements batch size replication for single task strings
- Aligns Pi0 behavior with smol-vla model handling
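The behavior summarized above can be sketched in a few lines. This is a minimal illustration, not the PR's actual code: the helper name `normalize_tasks` and its `batch_size` parameter are hypothetical, and plain Python lists stand in for the real batch.

```python
from typing import List, Sequence, Union


def normalize_tasks(tasks: Union[str, Sequence[str]], batch_size: int) -> List[str]:
    """Return one task string per batch element (hypothetical helper)."""
    if isinstance(tasks, str):
        # Wrap a bare string so it is treated as one task, not as a sequence of characters.
        tasks = [tasks]
    tasks = list(tasks)
    if len(tasks) == 1:
        # Replicate the single task so there is one entry per batch element.
        tasks = tasks * batch_size
    return tasks


replicated = normalize_tasks("pick up the cube", batch_size=3)
unchanged = normalize_tasks(["open drawer", "close drawer"], batch_size=2)
```

With these inputs, `replicated` holds three copies of the task string and `unchanged` is returned as given.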
if isinstance(tasks, str):
    tasks = [tasks]

if len(tasks) == 1:
Copilot AI · Jul 30, 2025
The condition len(tasks) == 1 is always true once the preceding block has converted a string into a single-element list, and it is also true when the input was already a one-element list. As a result, any single task is replicated to match the batch size, even when the original input was a one-element list that should not be replicated.
Suggested change:
- if isinstance(tasks, str):
-     tasks = [tasks]
- if len(tasks) == 1:
+ was_string = isinstance(tasks, str)  # Track if the input was originally a string
+ if was_string:
+     tasks = [tasks]
+ if was_string and len(tasks) == 1:  # Only replicate if the input was originally a string
if isinstance(tasks, str):
    tasks = [tasks]

if len(tasks) == 1:
Copilot AI · Jul 30, 2025
The logic assumes that when len(tasks) == 1, it should replicate the task across the batch dimension. However, this doesn't distinguish between a single string input (which should be replicated) and a legitimate single-element list (which may not need replication). Consider checking the original input type or batch size mismatch instead.
Suggested change:
- if isinstance(tasks, str):
-     tasks = [tasks]
- if len(tasks) == 1:
+ was_string = isinstance(tasks, str)  # Track if the input was originally a string
+ if was_string:
+     tasks = [tasks]
+ if was_string or len(tasks) == 1 and len(tasks) != batch[OBS_STATE].shape[0]:
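One way to read the reviewer's advice (check the original input type or a batch size mismatch) is sketched below. It is an illustration under assumptions, not the code that was merged: `expand_tasks` is a hypothetical helper, and `batch_size` stands in for `batch[OBS_STATE].shape[0]`.

```python
from typing import List, Sequence, Union


def expand_tasks(tasks: Union[str, Sequence[str]], batch_size: int) -> List[str]:
    """Replicate only bare-string inputs; validate list inputs (hypothetical helper)."""
    was_string = isinstance(tasks, str)  # Remember the original input type.
    if was_string:
        # A bare string is one task shared by every element of the batch.
        return [tasks] * batch_size
    tasks = list(tasks)
    if len(tasks) != batch_size:
        # A list is expected to carry exactly one task per batch element.
        raise ValueError(
            f"expected {batch_size} tasks, got {len(tasks)}"
        )
    return tasks


from_string = expand_tasks("stack the blocks", batch_size=2)
from_list = expand_tasks(["stack the blocks"], batch_size=1)
```

Raising on a length mismatch is a design choice of this sketch; silently replicating a one-element list, as the original condition does, is the more permissive alternative.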
Hey @captainfffsama, this issue will be fixed once #1431 and #1452 are merged 😄
Fix: Handle string-based task inputs during Pi0 inference
Description:
When the task input is a string during inference with record.py, it was incorrectly split, leading to tokenization errors. This PR wraps the string in a list to ensure correct processing, aligning the behavior with smol-vla.
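The splitting described above follows from how Python iterates over strings. The snippet below is a small illustration, assuming the downstream code loops over the task batch item by item (the source does not show that code); the variable names are made up for the example.

```python
task = "pick up"

# Iterating over a bare string yields individual characters.
items_from_string = [item for item in task]

# Wrapping the string in a list yields the whole task as a single item.
items_from_list = [item for item in [task]]

print(len(items_from_string), len(items_from_list))
```

Here the bare string produces seven one-character "tasks", while the wrapped version produces a single task, which is what a tokenizer expecting a list of sentences would need.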