Merged
2 changes: 1 addition & 1 deletion examples/advanced/1_train_act_pusht/act_pusht.yaml
@@ -80,7 +80,7 @@ policy:
n_vae_encoder_layers: 4

# Inference.
-temporal_ensemble_momentum: null
+temporal_ensemble_coeff: null

# Training and loss computation.
dropout: 0.1
14 changes: 6 additions & 8 deletions lerobot/common/policies/act/configuration_act.py
@@ -76,12 +76,10 @@ class ACTConfig:
documentation in the policy class).
latent_dim: The VAE's latent dimension.
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
-temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
-    actions for a given time step over multiple policy invocations. Updates are calculated as:
-    x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
-    parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
-    formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
-    require `n_action_steps == 1` (since we need to query the policy every step anyway).
+temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
+    ensembling. Defaults to None, which means temporal ensembling is not used. `n_action_steps`
+    must be 1 when using this feature, as inference needs to happen at every step to form an
+    ensemble. For more information on how ensembling works, please see `ACTTemporalEnsembler`.
dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
@@ -139,7 +137,7 @@ class ACTConfig:
n_vae_encoder_layers: int = 4

# Inference.
-temporal_ensemble_momentum: float | None = None
+temporal_ensemble_coeff: float | None = None

# Training and loss computation.
dropout: float = 0.1
@@ -151,7 +149,7 @@ def __post_init__(self):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
-if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
+if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
raise NotImplementedError(
"`n_action_steps` must be 1 when using temporal ensembling. This is "
"because the policy needs to be queried every step to compute the ensembled action."
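Aside, for intuition on the rename (not part of the diff): a minimal sketch relating the new `temporal_ensemble_coeff` weighting to the old `temporal_ensemble_momentum` EMA, using only the α = exp(-m) relationship stated in the removed docstring above.

```python
import math

import torch

coeff = 0.01  # default used by the original ACT work, per the docstring above
chunk_size = 100

# New parameterization: one weight per chunk position, w_i = exp(-coeff * i).
# ACTTemporalEnsembler normalizes these by their cumulative sum during updates.
weights = torch.exp(-coeff * torch.arange(chunk_size))
print(weights[:3])  # tensor([1.0000, 0.9900, 0.9802])

# Old parameterization: a single EMA momentum α. Per the removed docstring,
# the two formulations agree when α = exp(-coeff).
alpha = math.exp(-coeff)
print(f"{alpha:.2f}")  # 0.99, the α ≈ 0.99 mentioned in the old docstring
```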
114 changes: 96 additions & 18 deletions lerobot/common/policies/act/modeling_act.py
@@ -77,12 +77,15 @@ def __init__(

self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

+if config.temporal_ensemble_coeff is not None:
+    self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)

self.reset()

def reset(self):
"""This should be called whenever the environment is reset."""
-if self.config.temporal_ensemble_momentum is not None:
-    self._ensembled_actions = None
+if self.config.temporal_ensemble_coeff is not None:
+    self.temporal_ensembler.reset()
else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)

@@ -100,24 +103,12 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
if len(self.expected_image_keys) > 0:
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)

-# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
-# the first action.
-if self.config.temporal_ensemble_momentum is not None:
+# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
+# we are ensembling over.
+if self.config.temporal_ensemble_coeff is not None:
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
-if self._ensembled_actions is None:
-    # Initializes `self._ensembled_action` to the sequence of actions predicted during the first
-    # time step of the episode.
-    self._ensembled_actions = actions.clone()
-else:
-    # self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
-    # the EMA update for those entries.
-    alpha = self.config.temporal_ensemble_momentum
-    self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
-    # The last action, which has no prior moving average, needs to get concatenated onto the end.
-    self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
-# "Consume" the first action.
-action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
+action = self.temporal_ensembler.update(actions)
return action

# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
@@ -162,6 +153,93 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
return loss_dict


class ACTTemporalEnsembler:
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
"""Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.

The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i), where w₀ is the weight for the oldest action.
They are then normalized to sum to 1 by dividing by Σwᵢ.

Here we use an online method for computing the average rather than caching a history of actions in
order to compute the average offline. For a simple 1D sequence it looks something like:

```
import torch

seq = torch.linspace(8, 8.5, 100)
print(seq)

m = 0.01
exp_weights = torch.exp(-m * torch.arange(len(seq)))
print(exp_weights)

# Calculate offline
avg = (exp_weights * seq).sum() / exp_weights.sum()
print("offline", avg)

# Calculate online
for i, item in enumerate(seq):
if i == 0:
avg = item
continue
avg *= exp_weights[:i].sum()
avg += item * exp_weights[i]
avg /= exp_weights[:i+1].sum()
print("online", avg)
```

NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
results in older actions being weighted more heavily than newer actions (if you know why, please submit
a PR with the explanation).
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()

def reset(self):
"""Resets the online computation variables."""
self.ensembled_actions = None
# (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
self.ensembled_actions_count = None

def update(self, actions: Tensor) -> Tensor:
"""
Takes a (batch, chunk_size, action_dim) sequence of actions, updates the temporal ensemble for all
time steps, and pops/returns the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
if self.ensembled_actions is None:
# Initializes `self.ensembled_actions` to the sequence of actions predicted during the first
# time step of the episode.
self.ensembled_actions = actions.clone()
# Note: The last dimension is unsqueezed so that we can broadcast properly in the tensor
# operations below.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions_count = torch.cat(
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
self.ensembled_actions[:, 0],
self.ensembled_actions[:, 1:],
self.ensembled_actions_count[1:],
)
return action


class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.

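For readers skimming the diff, here is a minimal usage sketch of the new class (not from the PR; the import path, constructor signature, and shapes follow the code added above):

```python
import torch

from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler

batch_size, chunk_size, action_dim = 2, 100, 6
ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff=0.01, chunk_size=chunk_size)

for step in range(3):
    # In the policy these come from `self.model(batch)[0]`; random stand-ins here.
    actions = torch.rand(batch_size, chunk_size, action_dim)
    # `update` folds the new chunk into the running ensemble and pops the next action.
    action = ensembler.update(actions)
    assert action.shape == (batch_size, action_dim)

ensembler.reset()  # the policy calls this whenever the environment is reset
```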
2 changes: 1 addition & 1 deletion lerobot/configs/policy/act.yaml
@@ -75,7 +75,7 @@ policy:
n_vae_encoder_layers: 4

# Inference.
-temporal_ensemble_momentum: null
+temporal_ensemble_coeff: null

# Training and loss computation.
dropout: 0.1
2 changes: 1 addition & 1 deletion lerobot/configs/policy/act_real.yaml
@@ -107,7 +107,7 @@ policy:
n_vae_encoder_layers: 4

# Inference.
-temporal_ensemble_momentum: null
+temporal_ensemble_coeff: null

# Training and loss computation.
dropout: 0.1
2 changes: 1 addition & 1 deletion lerobot/configs/policy/act_real_no_state.yaml
@@ -103,7 +103,7 @@ policy:
n_vae_encoder_layers: 4

# Inference.
-temporal_ensemble_momentum: null
+temporal_ensemble_coeff: null

# Training and loss computation.
dropout: 0.1
63 changes: 62 additions & 1 deletion tests/test_policies.py
@@ -16,6 +16,7 @@
import inspect
from pathlib import Path

+import einops
import pytest
import torch
from huggingface_hub import PyTorchModelHubMixin
@@ -26,14 +27,15 @@
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
+from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.common.policies.factory import (
_policy_cfg_from_hydra_cfg,
get_policy_and_config_classes,
make_policy,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -390,3 +392,62 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
for key in saved_actions:
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()


def test_act_temporal_ensembler():
"""Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
temporal_ensemble_coeff = 0.01
chunk_size = 100
episode_length = 101
ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
# A batch of arbitrary sequences of 1D actions that we wish to compute the average over. We'll keep the
# "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
with seeded_context(0):
# Dimension is (batch, episode_length, chunk_size, action_dim(=1))
# Stepping through the episode_length dim is like running inference at each rollout step and getting
# a different action chunk.
batch_seq = torch.stack(
[
torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
],
dim=0,
).unsqueeze(-1) # unsqueeze for action dim
batch_size = batch_seq.shape[0]
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
# dimension of `batch_seq`.
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)

# Simulate stepping through a rollout and computing a batch of actions with model on each step.
for i in range(episode_length):
# Mock a batch of actions.
actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
online_avg = ensembler.update(actions)
# Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
# Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
# What we want to do is take diagonal slices across it starting from the left.
[Inline review comment from the PR author] FYI: I think this gets a little hairy for a "simple" test, but I really wanted to make sure it's properly checked. I hope the explanation is enough to make the reviewer feel comfortable that this test is doing what it's supposed to. Perhaps it's enough to know that we do the same thing with two approaches and get the same answer.

# eg: chunk_size=4, episode_length=6
# ┌───────┐
# │0 1 2 3│
# │1 2 3 4│
# │2 3 4 5│
# │3 4 5 6│
# │4 5 6 7│
# │5 6 7 8│
# └───────┘
chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
offline_avg = (
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
)
# Sanity check. The average should be between the extrema.
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and expecting 0.01% error.
assert torch.allclose(online_avg, offline_avg, atol=1e-4)


if __name__ == "__main__":
test_act_temporal_ensembler()
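To make the test's diagonal-slicing comment concrete, here is a standalone sketch (not from the PR) that uses the same index arithmetic as the test with the toy sizes from the grid comment, printing which (episode step, chunk offset) pairs contribute to each ensembled time step:

```python
import torch

chunk_size, episode_length = 4, 6  # matches the example grid in the test comment
for i in range(episode_length):
    # Same two lines as in the test: walk the diagonal of the
    # (episode_length, chunk_size) grid that ends at the current step.
    chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
    episode_step_indices = torch.arange(i + 1)[-len(chunk_indices):]
    pairs = list(zip(episode_step_indices.tolist(), chunk_indices.tolist()))
    print(f"time step {i}: predicted at (episode step, chunk offset) = {pairs}")

# Output:
# time step 0: [(0, 0)]
# time step 1: [(0, 1), (1, 0)]
# ...
# time step 5: [(2, 3), (3, 2), (4, 1), (5, 0)]
```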