Fix ACT temporal ensembling #319
@@ -16,6 +16,7 @@
 import inspect
 from pathlib import Path
 
+import einops
 import pytest
 import torch
 from huggingface_hub import PyTorchModelHubMixin
@@ -26,14 +27,15 @@
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import preprocess_observation
+from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
 from lerobot.common.policies.factory import (
     _policy_cfg_from_hydra_cfg,
     get_policy_and_config_classes,
     make_policy,
 )
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config, seeded_context
 from lerobot.scripts.train import make_optimizer_and_scheduler
 from tests.scripts.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -390,3 +392,62 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
         assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
     for key in saved_actions:
         assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
+
+
+def test_act_temporal_ensembler():
+    """Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
+    temporal_ensemble_coeff = 0.01
+    chunk_size = 100
+    episode_length = 101
+    ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
+    # A batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the
+    # "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
+    with seeded_context(0):
+        # Dimension is (batch, episode_length, chunk_size, action_dim(=1))
+        # Stepping through the episode_length dim is like running inference at each rollout step and getting
+        # a different action chunk.
+        batch_seq = torch.stack(
+            [
+                torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
+                torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
+                torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
+            ],
+            dim=0,
+        ).unsqueeze(-1)  # unsqueeze for action dim
+    batch_size = batch_seq.shape[0]
+    # Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
+    # dimension of `batch_seq`.
+    weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
+
+    # Simulate stepping through a rollout and computing a batch of actions with the model on each step.
+    for i in range(episode_length):
+        # Mock a batch of actions.
+        actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
+        online_avg = ensembler.update(actions)
+        # Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
+        # Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
+        # What we want to do is take diagonal slices across it starting from the left.

Contributor (Author) commented on this line: FYI: I think this gets a little hairy for a "simple" test, but I really wanted to make sure it's properly checked. I hope the explanation is enough to make the reviewer feel comfortable that this test is doing what it's supposed to. Perhaps it's enough to know that we do the same thing with two approaches and get the same answer.

+        # eg: chunk_size=4, episode_length=6
+        # ┌───────┐
+        # │0 1 2 3│
+        # │1 2 3 4│
+        # │2 3 4 5│
+        # │3 4 5 6│
+        # │4 5 6 7│
+        # │5 6 7 8│
+        # └───────┘
+        chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
+        episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
+        seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
+        offline_avg = (
+            einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
+        )
+        # Sanity check. The average should be between the extrema.
+        assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
+        assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
+        # Selected atol=1e-4 keeping in mind actions in [-1, 1] and expecting 0.01% error.
+        assert torch.allclose(online_avg, offline_avg, atol=1e-4)
+
+
+if __name__ == "__main__":
+    test_act_temporal_ensembler()
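
For readers unfamiliar with the ensembler being tested, below is a minimal, hypothetical sketch of an online exponentially-weighted ensembler with the semantics this test relies on. It is not the lerobot implementation of ACTTemporalEnsembler; the class name, buffer layout, and internals are assumptions made only to illustrate what the offline diagonal-slice average is being compared against. The weighting scheme (exp(-coeff * k), with k = 0 for the oldest prediction made for a timestep) and the update() signature (a (batch, chunk_size, action_dim) chunk in, the ensembled action for the current step out) are taken from the test above.

import torch


class OnlineEnsemblerSketch:
    """Illustrative stand-in for ACTTemporalEnsembler; internals are assumed, not copied."""

    def __init__(self, temporal_ensemble_coeff: float, chunk_size: int):
        self.chunk_size = chunk_size
        # weights[k] multiplies the k-th prediction (oldest first) made for a given timestep.
        self.weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
        self.reset()

    def reset(self):
        self.weighted_sums = None  # (batch, chunk_size, dim): running sum of a_k * w_k per upcoming timestep
        self.weight_totals = None  # (batch, chunk_size, 1): running sum of w_k per upcoming timestep
        self.counts = None  # (chunk_size,): predictions received so far per upcoming timestep

    def update(self, actions):
        """actions: (batch, chunk_size, dim) chunk covering the current and next timesteps.
        Returns the ensembled action for the current timestep, shape (batch, dim)."""
        batch, chunk, dim = actions.shape
        if self.weighted_sums is None:
            self.weighted_sums = torch.zeros(batch, chunk, dim)
            self.weight_totals = torch.zeros(batch, chunk, 1)
            self.counts = torch.zeros(chunk, dtype=torch.long)
        # The newest prediction for slot j is that timestep's counts[j]-th prediction overall.
        w = self.weights[self.counts].view(1, chunk, 1)
        self.weighted_sums = self.weighted_sums + actions * w
        self.weight_totals = self.weight_totals + w
        out = self.weighted_sums[:, 0] / self.weight_totals[:, 0]
        # Shift the buffers forward one timestep and open a fresh slot at the far end.
        self.weighted_sums = torch.cat([self.weighted_sums[:, 1:], torch.zeros(batch, 1, dim)], dim=1)
        self.weight_totals = torch.cat([self.weight_totals[:, 1:], torch.zeros(batch, 1, 1)], dim=1)
        self.counts = torch.cat([self.counts[1:] + 1, torch.zeros(1, dtype=torch.long)])
        return out


# Example usage mirroring the test loop (hypothetical, not part of the PR):
# ensembler = OnlineEnsemblerSketch(0.01, 100)
# for i in range(episode_length):
#     online_avg = ensembler.update(batch_seq[:, i])

Stepping this sketch through the same batch_seq and comparing it against the offline diagonal-slice average should agree to within the same tolerance, which is essentially the equivalence that test_act_temporal_ensembler verifies for the real ACTTemporalEnsembler.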