Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lerobot/policies/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def make_pre_post_processors(
config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
)

elif isinstance(policy_cfg.type, ACTConfig):
elif isinstance(policy_cfg, ACTConfig):
from lerobot.policies.act.processor_act import make_act_pre_post_processors

processors = make_act_pre_post_processors(
Expand Down
102 changes: 42 additions & 60 deletions src/lerobot/processor/device_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,32 +35,30 @@ class DeviceProcessor(ProcessorStep):

device: str = "cpu"
float_dtype: str | None = None
_device: torch.device | None = None

DTYPE_MAPPING = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"bfloat16": torch.bfloat16,
"half": torch.float16,
"float": torch.float32,
"double": torch.float64,
}

def __post_init__(self):
self._device = get_safe_torch_device(self.device)
self.device = self._device.type
self._device: torch.device = get_safe_torch_device(self.device)
self.device = self._device.type # cuda might have changed to cuda:1
self.non_blocking = "cuda" in str(self.device)

# Validate and convert float_dtype string to torch dtype
if self.float_dtype is not None:
dtype_mapping = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"bfloat16": torch.bfloat16,
"half": torch.float16,
"float": torch.float32,
"double": torch.float64,
}

if self.float_dtype not in dtype_mapping:
available_dtypes = list(dtype_mapping.keys())
if self.float_dtype not in self.DTYPE_MAPPING:
raise ValueError(
f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}"
f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
)

self._target_float_dtype = dtype_mapping[self.float_dtype]
self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
else:
self._target_float_dtype = None

Expand Down Expand Up @@ -93,51 +91,35 @@ def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor

def __call__(self, transition: EnvTransition) -> EnvTransition:
# Create a copy of the transition
new_transition = transition.copy()

# Process observation tensors
observation = transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_observation = {
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()
}
new_transition[TransitionKey.OBSERVATION] = new_observation

# Process action tensor
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
new_transition[TransitionKey.ACTION] = self._process_tensor(action)

# Process reward tensor
reward = transition.get(TransitionKey.REWARD)
if reward is not None and isinstance(reward, torch.Tensor):
new_transition[TransitionKey.REWARD] = self._process_tensor(reward)

# Process done tensor
done = transition.get(TransitionKey.DONE)
if done is not None and isinstance(done, torch.Tensor):
new_transition[TransitionKey.DONE] = self._process_tensor(done)

# Process truncated tensor
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is not None and isinstance(truncated, torch.Tensor):
new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated)

# Process complementary data tensors
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is not None:
new_complementary_data = {}

# Process all items in complementary_data
for key, value in complementary_data.items():
if isinstance(value, torch.Tensor):
new_complementary_data[key] = self._process_tensor(value)
else:
new_complementary_data[key] = value

new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
simple_tensor_keys = [
TransitionKey.ACTION,
TransitionKey.REWARD,
TransitionKey.DONE,
TransitionKey.TRUNCATED,
]

dict_tensor_keys = [
TransitionKey.OBSERVATION,
TransitionKey.COMPLEMENTARY_DATA,
]

# Process simple tensors
for key in simple_tensor_keys:
value = transition.get(key)
if isinstance(value, torch.Tensor):
new_transition[key] = self._process_tensor(value)

# Process dictionary-like tensors
for key in dict_tensor_keys:
data_dict = transition.get(key)
if data_dict is not None:
new_data_dict = {
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
for k, v in data_dict.items()
}
new_transition[key] = new_data_dict

return new_transition

Expand Down
1 change: 1 addition & 0 deletions src/lerobot/processor/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ def __repr__(self) -> str:
def __post_init__(self):
for i, step in enumerate(self.steps):
if not callable(step):
# TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated
raise TypeError(
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
)
Expand Down