Skip to content
113 changes: 113 additions & 0 deletions src/lerobot/policies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@

import logging
from collections import deque
from typing import Any

import numpy as np
import torch
from torch import nn

from lerobot.datasets.utils import build_dataset_frame
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR


def populate_queues(
queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
Expand Down Expand Up @@ -85,3 +91,110 @@ def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str])
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
if unexpected_keys:
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")


# TODO(Steven): Move this function to a proper preprocessor step
def prepare_observation_for_inference(
observation: dict[str, np.ndarray],
device: torch.device,
task: str | None = None,
robot_type: str | None = None,
) -> RobotObservation:
"""Converts observation data to model-ready PyTorch tensors.

This function takes a dictionary of NumPy arrays, performs necessary
preprocessing, and prepares it for model inference. The steps include:
1. Converting NumPy arrays to PyTorch tensors.
2. Normalizing and permuting image data (if any).
3. Adding a batch dimension to each tensor.
4. Moving all tensors to the specified compute device.
5. Adding task and robot type information to the dictionary.

Args:
observation: A dictionary mapping observation names (str) to NumPy
array data. For images, the format is expected to be (H, W, C).
device: The PyTorch device (e.g., 'cpu' or 'cuda') to which the
tensors will be moved.
task: An optional string identifier for the current task.
robot_type: An optional string identifier for the robot being used.

Returns:
A dictionary where values are PyTorch tensors preprocessed for
inference, residing on the target device. Image tensors are reshaped
to (C, H, W) and normalized to a [0, 1] range.
"""
for name in observation:
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)

observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else ""

return observation


def build_inference_frame(
observation: dict[str, Any],
device: torch.device,
ds_features: dict[str, dict],
task: str | None = None,
robot_type: str | None = None,
) -> RobotObservation:
"""Constructs a model-ready observation tensor dict from a raw observation.

This utility function orchestrates the process of converting a raw,
unstructured observation from an environment into a structured,
tensor-based format suitable for passing to a policy model.

Args:
observation: The raw observation dictionary, which may contain
superfluous keys.
device: The target PyTorch device for the final tensors.
ds_features: A configuration dictionary that specifies which features
to extract from the raw observation.
task: An optional string identifier for the current task.
robot_type: An optional string identifier for the robot being used.

Returns:
A dictionary of preprocessed tensors ready for model inference.
"""
# Extracts the correct keys from the incoming raw observation
observation = build_dataset_frame(ds_features, observation, prefix=OBS_STR)

# Performs the necessary conversions to the observation
observation = prepare_observation_for_inference(observation, device, task, robot_type)

return observation


def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict]) -> RobotAction:
"""Converts a policy's output tensor into a dictionary of named actions.

This function translates the numerical output from a policy model into a
human-readable and robot-consumable format, where each dimension of the
action tensor is mapped to a named motor or actuator command.

Args:
action_tensor: A PyTorch tensor representing the policy's action,
typically with a batch dimension (e.g., shape [1, action_dim]).
ds_features: A configuration dictionary containing metadata, including
the names corresponding to each index of the action tensor.

Returns:
A dictionary mapping action names (e.g., "joint_1_motor") to their
corresponding floating-point values, ready to be sent to a robot
controller.
"""
# TODO(Steven): Check if these steps are already in all postprocessor policies
action_tensor = action_tensor.squeeze(0)
action_tensor = action_tensor.to("cpu")

action_names = ds_features[ACTION]["names"]
act_processed_policy: RobotAction = {
f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
}
return act_processed_policy
6 changes: 2 additions & 4 deletions src/lerobot/scripts/lerobot_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
PolicyAction,
PolicyProcessorPipeline,
Expand Down Expand Up @@ -316,10 +317,7 @@ def record_loop(
robot_type=robot.robot_type,
)

action_names = dataset.features[ACTION]["names"]
act_processed_policy: RobotAction = {
f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
}
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)

elif policy is None and isinstance(teleop, Teleoperator):
act = teleop.get_action()
Expand Down
19 changes: 2 additions & 17 deletions src/lerobot/utils/control_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.robots import Robot

Expand Down Expand Up @@ -102,17 +103,7 @@ def predict_action(
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)

observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else ""

observation = prepare_observation_for_inference(observation, device, task, robot_type)
observation = preprocessor(observation)

# Compute the next action with the policy
Expand All @@ -121,12 +112,6 @@ def predict_action(

action = postprocessor(action)

# Remove batch dimension
action = action.squeeze(0)

# Move to cpu, if not already the case
action = action.to("cpu")

return action


Expand Down