Skip to content
32 changes: 32 additions & 0 deletions src/lerobot/policies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import torch
from torch import nn

from lerobot.datasets.utils import build_dataset_frame
from lerobot.utils.constants import OBS_STR


def populate_queues(
queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
Expand Down Expand Up @@ -85,3 +88,32 @@ def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str])
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
if unexpected_keys:
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")


def build_inference_frame(
    observation: dict[str, torch.Tensor],
    ds_features: dict[str, dict],
    device: torch.device,
    task: str | None = None,
    robot_type: str | None = None,
) -> dict[str, torch.Tensor]:
    """Build an inference frame from a raw observation.

    The raw observation is first passed through ``build_dataset_frame`` with
    ``ds_features`` and the ``OBS_STR`` prefix. Every entry of the resulting
    frame is then converted with ``torch.from_numpy``, given a leading batch
    dimension and moved to ``device``. Entries whose key contains ``"image"``
    are additionally cast to float32, divided by 255 and permuted with
    ``(2, 0, 1)``.

    Args:
        observation: Raw observation forwarded to ``build_dataset_frame``.
            NOTE(review): the frame values are fed to ``torch.from_numpy``,
            so they are presumably NumPy arrays rather than tensors as the
            annotation suggests - confirm against ``build_dataset_frame``.
        ds_features: Feature description forwarded to ``build_dataset_frame``.
        device: Device every converted tensor is moved to.
        task: Task string stored under the ``"task"`` key; ``None`` or an
            empty value is stored as ``""``.
        robot_type: Robot type stored under the ``"robot_type"`` key; ``None``
            or an empty value is stored as ``""``.

    Returns:
        The frame with converted tensors plus the ``"task"`` and
        ``"robot_type"`` string entries.
    """
    # Keep only the entries selected by the dataset features.
    frame = build_dataset_frame(ds_features, observation, prefix=OBS_STR)

    for key in frame:
        tensor = torch.from_numpy(frame[key])
        if "image" in key:
            # Presumably (H, W, C) uint8 images -> float32 in [0, 1], channel first;
            # TODO confirm the incoming image layout.
            tensor = (tensor.type(torch.float32) / 255).permute(2, 0, 1).contiguous()

        # Inference expects a batch dimension in front.
        frame[key] = tensor.unsqueeze(0).to(device)

    frame["task"] = task or ""
    frame["robot_type"] = robot_type or ""

    return frame
12 changes: 2 additions & 10 deletions src/lerobot/utils/control_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import build_inference_frame
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.robots import Robot

Expand Down Expand Up @@ -102,16 +103,7 @@ def predict_action(
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)

observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else ""
observation = build_inference_frame(observation, device)

observation = preprocessor(observation)

Expand Down