Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions lerobot/common/policies/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import logging

import torch
from torch import nn

from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
Expand Down Expand Up @@ -76,7 +75,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:

def make_policy(
cfg: PreTrainedConfig,
device: str | torch.device,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy:
Expand All @@ -88,15 +86,14 @@ def make_policy(
Args:
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
be loaded with the weights from that path.
device (str): the device to load the policy onto.
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
provided if ds_meta is not. Defaults to None.

Raises:
ValueError: Either ds_meta or env and env_cfg must be provided.
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)

Returns:
PreTrainedPolicy: _description_
Expand All @@ -111,7 +108,7 @@ def make_policy(
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
# slower than running natively on MPS.
if cfg.type == "vqbet" and str(device) == "mps":
if cfg.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
Expand Down Expand Up @@ -145,7 +142,7 @@ def make_policy(
# Make a fresh policy.
policy = policy_cls(**kwargs)

policy.to(device)
policy.to(cfg.device)
assert isinstance(policy, nn.Module)

# policy = torch.compile(policy, mode="reduce-overhead")
Expand Down
1 change: 1 addition & 0 deletions lerobot/common/policies/pi0/configuration_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class PI0Config(PreTrainedConfig):
def __post_init__(self):
super().__post_init__()

# TODO(Steven): Validate device and amp? in all policy configs?
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def main():

cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, device, ds_meta=dataset.meta)
policy = make_policy(cfg, ds_meta=dataset.meta)

# policy = torch.compile(policy, mode="reduce-overhead")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main():

cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, device, dataset_meta)
policy = make_policy(cfg, dataset_meta)

# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
# loss_dict["loss"].backward()
Expand Down
7 changes: 3 additions & 4 deletions lerobot/common/policies/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def from_pretrained(
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
map_location: str = "cpu",
strict: bool = False,
**kwargs,
) -> T:
Expand All @@ -98,7 +97,7 @@ def from_pretrained(
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
else:
try:
model_file = hf_hub_download(
Expand All @@ -112,13 +111,13 @@ def from_pretrained(
token=token,
local_files_only=local_files_only,
)
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e

policy.to(map_location)
policy.to(config.device)
policy.eval()
return policy

Expand Down
29 changes: 0 additions & 29 deletions lerobot/common/robot_devices/control_configs.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import logging
from dataclasses import dataclass
from pathlib import Path

import draccus

from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig


@dataclass
Expand Down Expand Up @@ -43,11 +40,6 @@ class RecordControlConfig(ControlConfig):
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
policy: PreTrainedConfig | None = None
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
device: str | None = None # cuda | cpu | mps
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int | None = None
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
Expand Down Expand Up @@ -90,27 +82,6 @@ def __post_init__(self):
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path

# When no device or use_amp are given, use the one from training config.
if self.device is None or self.use_amp is None:
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
if self.device is None:
self.device = train_cfg.device
if self.use_amp is None:
self.use_amp = train_cfg.use_amp

# Automatically switch to available device if necessary
if not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device

# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False


@ControlConfig.register_subclass("replay")
@dataclass
Expand Down
3 changes: 2 additions & 1 deletion lerobot/common/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def auto_select_torch_device() -> torch.device:
return torch.device("cpu")


# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
match try_device:
Expand Down Expand Up @@ -92,7 +93,7 @@ def is_torch_device_available(try_device: str) -> bool:
elif try_device == "cpu":
return True
else:
raise ValueError(f"Unknown device '{try_device}.")
raise ValueError(f"Unknown device '{try_device}. Supported devices are: cuda, mps or cpu.")


def is_amp_available(device: str):
Expand Down
33 changes: 0 additions & 33 deletions lerobot/configs/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from pathlib import Path

from lerobot.common import envs, policies # noqa: F401
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig


@dataclass
Expand All @@ -21,11 +19,6 @@ class EvalPipelineConfig:
policy: PreTrainedConfig | None = None
output_dir: Path | None = None
job_name: str | None = None
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
device: str | None = None # cuda | cpu | mps
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
seed: int | None = 1000

def __post_init__(self):
Expand All @@ -36,27 +29,6 @@ def __post_init__(self):
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path

# When no device or use_amp are given, use the one from training config.
if self.device is None or self.use_amp is None:
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
if self.device is None:
self.device = train_cfg.device
if self.use_amp is None:
self.use_amp = train_cfg.use_amp

# Automatically switch to available device if necessary
if not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device

# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False

else:
logging.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
Expand All @@ -73,11 +45,6 @@ def __post_init__(self):
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/eval") / eval_dir

if self.device is None:
raise ValueError("Set one of the following device: cuda, cpu or mps")
elif self.device == "cuda" and self.use_amp is None:
raise ValueError("Set 'use_amp' to True or False.")

@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
Expand Down
19 changes: 19 additions & 0 deletions lerobot/configs/policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
Expand All @@ -12,6 +13,7 @@
from lerobot.common.optim.optimizers import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

# Generic variable that is either PreTrainedConfig or a subclass thereof
Expand Down Expand Up @@ -40,8 +42,25 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)

# TODO(Steven): Should we implement a getter for these?
device: str | None = None # cuda | cpu | mp
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False

def __post_init__(self):
self.pretrained_path = None
if not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device

# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False

@property
def type(self) -> str:
Expand Down
18 changes: 0 additions & 18 deletions lerobot/configs/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime as dt
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
Expand All @@ -13,7 +12,6 @@
from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
Expand All @@ -35,10 +33,6 @@ class TrainPipelineConfig(HubMixin):
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption.
resume: bool = False
device: str | None = None # cuda | cpu | mp
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: int | None = 1000
Expand All @@ -61,18 +55,6 @@ def __post_init__(self):
self.checkpoint_path = None

def validate(self):
if not self.device:
logging.warning("No device specified, trying to infer device automatically")
device = auto_select_torch_device()
self.device = device.type

# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False

# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
if policy_path:
Expand Down
6 changes: 3 additions & 3 deletions lerobot/scripts/control_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def record(
)

# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

if not robot.is_connected:
robot.connect()
Expand Down Expand Up @@ -285,8 +285,8 @@ def record(
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
policy=policy,
device=cfg.device,
use_amp=cfg.use_amp,
device=policy.device,
use_amp=policy.use_amp,
fps=cfg.fps,
single_task=cfg.single_task,
)
Expand Down
6 changes: 3 additions & 3 deletions lerobot/scripts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))

# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
device = get_safe_torch_device(cfg.policy.device, log=True)

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
Expand All @@ -470,14 +470,14 @@ def eval_main(cfg: EvalPipelineConfig):
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

logging.info("Making policy.")

policy = make_policy(
cfg=cfg.policy,
device=device,
env_cfg=cfg.env,
)
policy.eval()

with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy(
env,
policy,
Expand Down
Loading
Loading