Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions source/isaaclab/isaaclab/envs/mdp/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,14 +596,16 @@ def randomize(data: torch.Tensor, params: tuple[float, float]) -> torch.Tensor:
actuator_indices = slice(None)
if isinstance(actuator.joint_indices, slice):
global_indices = slice(None)
elif isinstance(actuator.joint_indices, torch.Tensor):
global_indices = actuator.joint_indices.to(self.asset.device)
else:
global_indices = torch.tensor(actuator.joint_indices, device=self.asset.device)
raise TypeError("Actuator joint indices must be a slice or a torch.Tensor.")
elif isinstance(actuator.joint_indices, slice):
# we take the joints defined in the asset config
global_indices = actuator_indices = torch.tensor(self.asset_cfg.joint_ids, device=self.asset.device)
global_indices = torch.tensor(self.asset_cfg.joint_ids, device=self.asset.device)
else:
# we take the intersection of the actuator joints and the asset config joints
actuator_joint_indices = torch.tensor(actuator.joint_indices, device=self.asset.device)
actuator_joint_indices = actuator.joint_indices.to(self.asset.device)
asset_joint_ids = torch.tensor(self.asset_cfg.joint_ids, device=self.asset.device)
# the indices of the joints in the actuator that have to be randomized
actuator_indices = torch.nonzero(torch.isin(actuator_joint_indices, asset_joint_ids)).view(-1)
Expand Down