diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index a34aa34f98..feff4bdc0d 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -392,6 +392,11 @@ def prepare_language(self, batch) -> tuple[Tensor, Tensor]: """Tokenize the text input""" device = batch[OBS_STATE].device tasks = batch["task"] + if isinstance(tasks, str): + tasks = [tasks] + + if len(tasks) == 1: + tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] # PaliGemma prompt has to end with a new line tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]