diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 1f735ace00b..d4f962b0b47 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -36,6 +36,7 @@ The following modules are available in the ``isaaclab`` extension: lab/isaaclab.sim.converters lab/isaaclab.sim.schemas lab/isaaclab.sim.spawners + lab/isaaclab.sim.views lab/isaaclab.sim.utils diff --git a/docs/source/api/lab/isaaclab.sim.views.rst b/docs/source/api/lab/isaaclab.sim.views.rst new file mode 100644 index 00000000000..3a5f9bdecfe --- /dev/null +++ b/docs/source/api/lab/isaaclab.sim.views.rst @@ -0,0 +1,17 @@ +isaaclab.sim.views +================== + +.. automodule:: isaaclab.sim.views + + .. rubric:: Classes + + .. autosummary:: + + XformPrimView + +XForm Prim View +--------------- + +.. autoclass:: XformPrimView + :members: + :show-inheritance: diff --git a/scripts/benchmarks/benchmark_view_comparison.py b/scripts/benchmarks/benchmark_view_comparison.py new file mode 100644 index 00000000000..6775a40d070 --- /dev/null +++ b/scripts/benchmarks/benchmark_view_comparison.py @@ -0,0 +1,491 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Benchmark script comparing XformPrimView vs PhysX RigidBodyView for transform operations. + +This script tests the performance of batched transform operations using: + +- Isaac Lab's XformPrimView (USD-based) +- PhysX RigidBodyView (PhysX tensors-based, as used in RigidObject) + +Note: + XformPrimView operates on USD attributes directly (useful for non-physics prims), + while RigidBodyView requires rigid body physics components and operates on PhysX tensors. + This benchmark helps understand the performance trade-offs between the two approaches. + +Usage: + # Basic benchmark + ./isaaclab.sh -p scripts/benchmarks/benchmark_view_comparison.py --num_envs 1024 --device cuda:0 --headless + + # With profiling enabled (for snakeviz visualization) + ./isaaclab.sh -p scripts/benchmarks/benchmark_view_comparison.py --num_envs 1024 --profile --headless + + # Then visualize with snakeviz: + snakeviz profile_results/xform_view_benchmark.prof + snakeviz profile_results/physx_view_benchmark.prof +""" + +from __future__ import annotations + +"""Launch Isaac Sim Simulator first.""" + +import argparse + +from isaaclab.app import AppLauncher + +# parse the arguments +args_cli = argparse.Namespace() + +parser = argparse.ArgumentParser(description="Benchmark XformPrimView vs PhysX RigidBodyView performance.") + +parser.add_argument("--num_envs", type=int, default=100, help="Number of environments to simulate.") +parser.add_argument("--num_iterations", type=int, default=50, help="Number of iterations for each test.") +parser.add_argument( + "--profile", + action="store_true", + help="Enable profiling with cProfile. Results saved as .prof files for snakeviz visualization.", +) +parser.add_argument( + "--profile-dir", + type=str, + default="./profile_results", + help="Directory to save profile results. Default: ./profile_results", +) + +AppLauncher.add_app_launcher_args(parser) +args_cli = parser.parse_args() + +# launch omniverse app +app_launcher = AppLauncher(args_cli) +simulation_app = app_launcher.app + +"""Rest everything follows.""" + +import cProfile +import time +import torch + +from isaacsim.core.simulation_manager import SimulationManager + +import isaaclab.sim as sim_utils +import isaaclab.utils.math as math_utils +from isaaclab.sim.views import XformPrimView + + +@torch.no_grad() +def benchmark_view(view_type: str, num_iterations: int) -> tuple[dict[str, float], dict[str, torch.Tensor]]: + """Benchmark the specified view class. + + Args: + view_type: Type of view to benchmark ("xform" or "physx"). + num_iterations: Number of iterations to run. + + Returns: + A tuple of (timing_results, computed_results) where: + - timing_results: Dictionary containing timing results for various operations + - computed_results: Dictionary containing the computed values for validation + """ + timing_results = {} + computed_results = {} + + # Setup scene + print(" Setting up scene") + # Clear stage + sim_utils.create_new_stage() + # Create simulation context + start_time = time.perf_counter() + sim = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=args_cli.device)) + stage = sim_utils.get_current_stage() + + print(f" Time taken to create simulation context: {time.perf_counter() - start_time:.4f} seconds") + + # create a rigid object + object_cfg = sim_utils.ConeCfg( + radius=0.15, + height=0.5, + rigid_props=sim_utils.RigidBodyPropertiesCfg(), + mass_props=sim_utils.MassPropertiesCfg(mass=1.0), + collision_props=sim_utils.CollisionPropertiesCfg(), + visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.0, 1.0, 0.0)), + ) + # Create prims + for i in range(args_cli.num_envs): + sim_utils.create_prim(f"/World/Env_{i}", "Xform", stage=stage, translation=(i * 2.0, 0.0, 0.0)) + object_cfg.func(f"/World/Env_{i}/Object", object_cfg, translation=(0.0, 0.0, 1.0)) + + # Play simulation + sim.reset() + + # Pattern to match all prims + pattern = "/World/Env_.*/Object" if view_type == "xform" else "/World/Env_*/Object" + print(f" Pattern: {pattern}") + + # Create view based on type + start_time = time.perf_counter() + if view_type == "xform": + view = XformPrimView(pattern, device=args_cli.device, validate_xform_ops=False) + num_prims = view.count + view_name = "XformPrimView" + else: # physx + physics_sim_view = SimulationManager.get_physics_sim_view() + view = physics_sim_view.create_rigid_body_view(pattern) + num_prims = view.count + view_name = "PhysX RigidBodyView" + timing_results["init"] = time.perf_counter() - start_time + # prepare indices for benchmarking + all_indices = torch.arange(num_prims, device=args_cli.device) + + print(f" {view_name} managing {num_prims} prims") + + # Benchmark get_world_poses + start_time = time.perf_counter() + for _ in range(num_iterations): + if view_type == "xform": + positions, orientations = view.get_world_poses() + else: # physx + transforms = view.get_transforms() + positions = transforms[:, :3] + orientations = transforms[:, 3:7] + # Convert quaternion from xyzw to wxyz + orientations = math_utils.convert_quat(orientations, to="wxyz") + timing_results["get_world_poses"] = (time.perf_counter() - start_time) / num_iterations + + # Store initial world poses + computed_results["initial_world_positions"] = positions.clone() + computed_results["initial_world_orientations"] = orientations.clone() + + # Benchmark set_world_poses + new_positions = positions.clone() + new_positions[:, 2] += 0.5 + start_time = time.perf_counter() + for _ in range(num_iterations): + if view_type == "xform": + view.set_world_poses(new_positions, orientations) + else: # physx + # Convert quaternion from wxyz to xyzw for PhysX + orientations_xyzw = math_utils.convert_quat(orientations, to="xyzw") + new_transforms = torch.cat([new_positions, orientations_xyzw], dim=-1) + view.set_transforms(new_transforms, indices=all_indices) + timing_results["set_world_poses"] = (time.perf_counter() - start_time) / num_iterations + + # Get world poses after setting to verify + if view_type == "xform": + positions_after_set, orientations_after_set = view.get_world_poses() + else: # physx + transforms_after = view.get_transforms() + positions_after_set = transforms_after[:, :3] + orientations_after_set = math_utils.convert_quat(transforms_after[:, 3:7], to="wxyz") + computed_results["world_positions_after_set"] = positions_after_set.clone() + computed_results["world_orientations_after_set"] = orientations_after_set.clone() + + # close simulation + sim.clear() + sim.clear_all_callbacks() + sim.clear_instance() + + return timing_results, computed_results + + +def compare_results( + results_dict: dict[str, dict[str, torch.Tensor]], tolerance: float = 1e-4 +) -> dict[str, dict[str, dict[str, float]]]: + """Compare computed results across implementations. + + Args: + results_dict: Dictionary mapping implementation names to their computed values. + tolerance: Tolerance for numerical comparison. + + Returns: + Nested dictionary: {comparison_pair: {metric: {stats}}} + """ + comparison_stats = {} + impl_names = list(results_dict.keys()) + + # Compare each pair of implementations + for i, impl1 in enumerate(impl_names): + for impl2 in impl_names[i + 1 :]: + pair_key = f"{impl1}_vs_{impl2}" + comparison_stats[pair_key] = {} + + computed1 = results_dict[impl1] + computed2 = results_dict[impl2] + + for key in computed1.keys(): + if key not in computed2: + continue + + val1 = computed1[key] + val2 = computed2[key] + + # Skip zero tensors (not applicable tests) + if torch.all(val1 == 0) or torch.all(val2 == 0): + continue + + # Compute differences + diff = torch.abs(val1 - val2) + max_diff = torch.max(diff).item() + mean_diff = torch.mean(diff).item() + + # Check if within tolerance + all_close = torch.allclose(val1, val2, atol=tolerance, rtol=0) + + comparison_stats[pair_key][key] = { + "max_diff": max_diff, + "mean_diff": mean_diff, + "all_close": all_close, + } + + return comparison_stats + + +def print_comparison_results(comparison_stats: dict[str, dict[str, dict[str, float]]], tolerance: float): + """Print comparison results. + + Args: + comparison_stats: Nested dictionary containing comparison statistics. + tolerance: Tolerance used for comparison. + """ + for pair_key, pair_stats in comparison_stats.items(): + if not pair_stats: # Skip if no comparable results + continue + + # Format the pair key for display + impl1, impl2 = pair_key.split("_vs_") + display_impl1 = impl1.replace("_", " ").title() + display_impl2 = impl2.replace("_", " ").title() + comparison_title = f"{display_impl1} vs {display_impl2}" + + # Check if all results match + all_match = all(stats["all_close"] for stats in pair_stats.values()) + + if all_match: + # Compact output when everything matches + print("\n" + "=" * 100) + print(f"RESULT COMPARISON: {comparison_title}") + print("=" * 100) + print(f"✓ All computed values match within tolerance ({tolerance})") + print("=" * 100) + else: + # Detailed output when there are mismatches + print("\n" + "=" * 100) + print(f"RESULT COMPARISON: {comparison_title}") + print("=" * 100) + print(f"{'Computed Value':<40} {'Max Diff':<15} {'Mean Diff':<15} {'Match':<10}") + print("-" * 100) + + for key, stats in pair_stats.items(): + # Format the key for display + display_key = key.replace("_", " ").title() + match_str = "✓ Yes" if stats["all_close"] else "✗ No" + + print(f"{display_key:<40} {stats['max_diff']:<15.6e} {stats['mean_diff']:<15.6e} {match_str:<10}") + + print("=" * 100) + print(f"\n✗ Some results differ beyond tolerance ({tolerance})") + print(f" This may indicate implementation differences between {display_impl1} and {display_impl2}") + + print() + + +def print_results(results_dict: dict[str, dict[str, float]], num_prims: int, num_iterations: int): + """Print benchmark results in a formatted table. + + Args: + results_dict: Dictionary mapping implementation names to their timing results. + num_prims: Number of prims tested. + num_iterations: Number of iterations run. + """ + print("\n" + "=" * 100) + print(f"BENCHMARK RESULTS: {num_prims} prims, {num_iterations} iterations") + print("=" * 100) + + impl_names = list(results_dict.keys()) + # Format names for display + display_names = [name.replace("_", " ").title() for name in impl_names] + + # Calculate column width + col_width = 20 + + # Print header + header = f"{'Operation':<30}" + for display_name in display_names: + header += f" {display_name + ' (ms)':<{col_width}}" + print(header) + print("-" * 100) + + # Print each operation + operations = [ + ("Initialization", "init"), + ("Get World Poses", "get_world_poses"), + ("Set World Poses", "set_world_poses"), + ] + + for op_name, op_key in operations: + row = f"{op_name:<30}" + for impl_name in impl_names: + impl_time = results_dict[impl_name].get(op_key, 0) * 1000 # Convert to ms + row += f" {impl_time:>{col_width - 1}.4f}" + print(row) + + print("=" * 100) + + # Calculate and print total time (excluding N/A operations) + total_row = f"{'Total Time':<30}" + for impl_name in impl_names: + if impl_name == "physx_view": + # Exclude local pose operations for PhysX + total_time = ( + results_dict[impl_name].get("init", 0) * 1000 + + results_dict[impl_name].get("get_world_poses", 0) * 1000 + + results_dict[impl_name].get("set_world_poses", 0) * 1000 + ) + else: + total_time = sum(results_dict[impl_name].values()) * 1000 + total_row += f" {total_time:>{col_width - 1}.4f}" + print(f"\n{total_row}") + + # Calculate speedups relative to XformPrimView + if "xform_view" in impl_names: + print("\n" + "=" * 100) + print("SPEEDUP vs XformPrimView") + print("=" * 100) + print(f"{'Operation':<30}", end="") + for display_name in display_names: + if "xform" not in display_name.lower(): + print(f" {display_name + ' Speedup':<{col_width}}", end="") + print() + print("-" * 100) + + xform_results = results_dict["xform_view"] + for op_name, op_key in operations: + print(f"{op_name:<30}", end="") + xform_time = xform_results.get(op_key, 0) + for impl_name, display_name in zip(impl_names, display_names): + if impl_name != "xform_view": + impl_time = results_dict[impl_name].get(op_key, 0) + if xform_time > 0 and impl_time > 0: + speedup = impl_time / xform_time + print(f" {speedup:>{col_width - 1}.2f}x", end="") + else: + print(f" {'N/A':>{col_width}}", end="") + print() + + # Overall speedup (only world pose operations) + print("=" * 100) + print(f"{'Overall Speedup (World Ops)':<30}", end="") + total_xform = ( + xform_results.get("init", 0) + + xform_results.get("get_world_poses", 0) + + xform_results.get("set_world_poses", 0) + ) + for impl_name, display_name in zip(impl_names, display_names): + if impl_name != "xform_view": + total_impl = ( + results_dict[impl_name].get("init", 0) + + results_dict[impl_name].get("get_world_poses", 0) + + results_dict[impl_name].get("set_world_poses", 0) + ) + if total_xform > 0 and total_impl > 0: + overall_speedup = total_impl / total_xform + print(f" {overall_speedup:>{col_width - 1}.2f}x", end="") + else: + print(f" {'N/A':>{col_width}}", end="") + print() + + print("\n" + "=" * 100) + print("\nNotes:") + print(" - Times are averaged over all iterations") + print(" - Speedup = (PhysX View time) / (XformPrimView time)") + print(" - Speedup > 1.0 means XformPrimView is faster") + print(" - Speedup < 1.0 means PhysX View is faster") + print(" - PhysX View requires rigid body physics components") + print(" - XformPrimView works with any Xform prim (physics or non-physics)") + print(" - PhysX View does not support local pose operations directly") + print() + + +def main(): + """Main benchmark function.""" + print("=" * 100) + print("View Comparison Benchmark - XformPrimView vs PhysX RigidBodyView") + print("=" * 100) + print("Configuration:") + print(f" Number of environments: {args_cli.num_envs}") + print(f" Iterations per test: {args_cli.num_iterations}") + print(f" Device: {args_cli.device}") + print(f" Profiling: {'Enabled' if args_cli.profile else 'Disabled'}") + if args_cli.profile: + print(f" Profile directory: {args_cli.profile_dir}") + print() + + # Create profile directory if profiling is enabled + if args_cli.profile: + import os + + os.makedirs(args_cli.profile_dir, exist_ok=True) + + # Dictionary to store all results + all_timing_results = {} + all_computed_results = {} + profile_files = {} + + # Implementations to benchmark + implementations = [ + ("xform_view", "XformPrimView", "xform"), + ("physx_view", "PhysX RigidBodyView", "physx"), + ] + + # Benchmark each implementation + for impl_key, impl_name, view_type in implementations: + print(f"Benchmarking {impl_name}...") + + if args_cli.profile: + profiler = cProfile.Profile() + profiler.enable() + + timing, computed = benchmark_view(view_type=view_type, num_iterations=args_cli.num_iterations) + + if args_cli.profile: + profiler.disable() + profile_file = f"{args_cli.profile_dir}/{impl_key}_benchmark.prof" + profiler.dump_stats(profile_file) + profile_files[impl_key] = profile_file + print(f" Profile saved to: {profile_file}") + + all_timing_results[impl_key] = timing + all_computed_results[impl_key] = computed + + print(" Done!") + print() + + # Print timing results + print_results(all_timing_results, args_cli.num_envs, args_cli.num_iterations) + + # Compare computed results + print("\nComparing computed results across implementations...") + comparison_stats = compare_results(all_computed_results, tolerance=1e-4) + print_comparison_results(comparison_stats, tolerance=1e-4) + + # Print profiling instructions if enabled + if args_cli.profile: + print("\n" + "=" * 100) + print("PROFILING RESULTS") + print("=" * 100) + print("Profile files have been saved. To visualize with snakeviz, run:") + for impl_key, profile_file in profile_files.items(): + impl_display = impl_key.replace("_", " ").title() + print(f" # {impl_display}") + print(f" snakeviz {profile_file}") + print("\nAlternatively, use pstats to analyze in terminal:") + print(" python -m pstats ") + print("=" * 100) + print() + + # Clean up + sim_utils.SimulationContext.clear_instance() + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmarks/benchmark_xform_prim_view.py b/scripts/benchmarks/benchmark_xform_prim_view.py new file mode 100644 index 00000000000..e4fd4c95d5b --- /dev/null +++ b/scripts/benchmarks/benchmark_xform_prim_view.py @@ -0,0 +1,509 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Benchmark script comparing XformPrimView implementations across different APIs. + +This script tests the performance of batched transform operations using: +- Isaac Lab's XformPrimView implementation +- Isaac Sim's XformPrimView implementation (legacy) +- Isaac Sim Experimental's XformPrim implementation (latest) + +Usage: + # Basic benchmark (all APIs) + ./isaaclab.sh -p scripts/benchmarks/benchmark_xform_prim_view.py --num_envs 1024 --device cuda:0 --headless + + # With profiling enabled (for snakeviz visualization) + ./isaaclab.sh -p scripts/benchmarks/benchmark_xform_prim_view.py --num_envs 1024 --profile --headless + + # Then visualize with snakeviz: + snakeviz profile_results/isaaclab_XformPrimView.prof + snakeviz profile_results/isaacsim_XformPrimView.prof + snakeviz profile_results/isaacsim_experimental_XformPrim.prof +""" + +from __future__ import annotations + +"""Launch Isaac Sim Simulator first.""" + +import argparse + +from isaaclab.app import AppLauncher + +# parse the arguments +args_cli = argparse.Namespace() + +parser = argparse.ArgumentParser(description="This script can help you benchmark the performance of XformPrimView.") + +parser.add_argument("--num_envs", type=int, default=100, help="Number of environments to simulate.") +parser.add_argument("--num_iterations", type=int, default=50, help="Number of iterations for each test.") +parser.add_argument( + "--profile", + action="store_true", + help="Enable profiling with cProfile. Results saved as .prof files for snakeviz visualization.", +) +parser.add_argument( + "--profile-dir", + type=str, + default="./profile_results", + help="Directory to save profile results. Default: ./profile_results", +) + +AppLauncher.add_app_launcher_args(parser) +args_cli = parser.parse_args() + +# launch omniverse app +app_launcher = AppLauncher(args_cli) +simulation_app = app_launcher.app + +"""Rest everything follows.""" + +import cProfile +import time +import torch +from typing import Literal + +from isaacsim.core.prims import XFormPrim as IsaacSimXformPrimView +from isaacsim.core.utils.extensions import enable_extension + +# compare against latest Isaac Sim implementation +enable_extension("isaacsim.core.experimental.prims") +from isaacsim.core.experimental.prims import XformPrim as IsaacSimExperimentalXformPrimView + +import isaaclab.sim as sim_utils +from isaaclab.sim.views import XformPrimView as IsaacLabXformPrimView + + +@torch.no_grad() +def benchmark_xform_prim_view( + api: Literal["isaaclab", "isaacsim", "isaacsim-exp"], num_iterations: int +) -> tuple[dict[str, float], dict[str, torch.Tensor]]: + """Benchmark the Xform view class from Isaac Lab, Isaac Sim, or Isaac Sim Experimental. + + Args: + api: Which API to benchmark ("isaaclab", "isaacsim", or "isaacsim-exp"). + num_iterations: Number of iterations to run. + + Returns: + A tuple of (timing_results, computed_results) where: + - timing_results: Dictionary containing timing results for various operations + - computed_results: Dictionary containing the computed values for validation + """ + timing_results = {} + computed_results = {} + + # Setup scene + print(" Setting up scene") + # Clear stage + sim_utils.create_new_stage() + # Create simulation context + start_time = time.perf_counter() + sim = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=args_cli.device)) + stage = sim_utils.get_current_stage() + + print(f" Time taken to create simulation context: {time.perf_counter() - start_time} seconds") + + # Create prims + prim_paths = [] + for i in range(args_cli.num_envs): + sim_utils.create_prim(f"/World/Env_{i}", "Xform", stage=stage, translation=(i * 2.0, 0.0, 1.0)) + sim_utils.create_prim(f"/World/Env_{i}/Object", "Xform", stage=stage, translation=(0.0, 0.0, 0.0)) + prim_paths.append(f"/World/Env_{i}/Object") + # Play simulation + sim.reset() + + # Pattern to match all prims + pattern = "/World/Env_.*/Object" + print(f" Pattern: {pattern}") + + # Create view + start_time = time.perf_counter() + if api == "isaaclab": + xform_view = IsaacLabXformPrimView(pattern, device=args_cli.device, validate_xform_ops=False) + elif api == "isaacsim": + xform_view = IsaacSimXformPrimView(pattern, reset_xform_properties=False) + elif api == "isaacsim-exp": + xform_view = IsaacSimExperimentalXformPrimView(pattern) + else: + raise ValueError(f"Invalid API: {api}") + timing_results["init"] = time.perf_counter() - start_time + + if api in ("isaaclab", "isaacsim"): + num_prims = xform_view.count + elif api == "isaacsim-exp": + num_prims = len(xform_view.prims) + print(f" XformView managing {num_prims} prims") + + # Benchmark get_world_poses + start_time = time.perf_counter() + for _ in range(num_iterations): + positions, orientations = xform_view.get_world_poses() + # Ensure tensors are torch tensors + if not isinstance(positions, torch.Tensor): + positions = torch.tensor(positions, dtype=torch.float32) + if not isinstance(orientations, torch.Tensor): + orientations = torch.tensor(orientations, dtype=torch.float32) + + timing_results["get_world_poses"] = (time.perf_counter() - start_time) / num_iterations + + # Store initial world poses + computed_results["initial_world_positions"] = positions.clone() + computed_results["initial_world_orientations"] = orientations.clone() + + # Benchmark set_world_poses + new_positions = positions.clone() + new_positions[:, 2] += 0.1 + start_time = time.perf_counter() + for _ in range(num_iterations): + if api in ("isaaclab", "isaacsim"): + xform_view.set_world_poses(new_positions, orientations) + elif api == "isaacsim-exp": + xform_view.set_world_poses(new_positions.cpu().numpy(), orientations.cpu().numpy()) + timing_results["set_world_poses"] = (time.perf_counter() - start_time) / num_iterations + + # Get world poses after setting to verify + positions_after_set, orientations_after_set = xform_view.get_world_poses() + if not isinstance(positions_after_set, torch.Tensor): + positions_after_set = torch.tensor(positions_after_set, dtype=torch.float32) + if not isinstance(orientations_after_set, torch.Tensor): + orientations_after_set = torch.tensor(orientations_after_set, dtype=torch.float32) + computed_results["world_positions_after_set"] = positions_after_set.clone() + computed_results["world_orientations_after_set"] = orientations_after_set.clone() + + # Benchmark get_local_poses + start_time = time.perf_counter() + for _ in range(num_iterations): + translations, orientations_local = xform_view.get_local_poses() + # Ensure tensors are torch tensors + if not isinstance(translations, torch.Tensor): + translations = torch.tensor(translations, dtype=torch.float32, device=args_cli.device) + if not isinstance(orientations_local, torch.Tensor): + orientations_local = torch.tensor(orientations_local, dtype=torch.float32, device=args_cli.device) + + timing_results["get_local_poses"] = (time.perf_counter() - start_time) / num_iterations + + # Store initial local poses + computed_results["initial_local_translations"] = translations.clone() + computed_results["initial_local_orientations"] = orientations_local.clone() + + # Benchmark set_local_poses + new_translations = translations.clone() + new_translations[:, 2] += 0.1 + start_time = time.perf_counter() + for _ in range(num_iterations): + if api in ("isaaclab", "isaacsim"): + xform_view.set_local_poses(new_translations, orientations_local) + elif api == "isaacsim-exp": + xform_view.set_local_poses(new_translations.cpu().numpy(), orientations_local.cpu().numpy()) + timing_results["set_local_poses"] = (time.perf_counter() - start_time) / num_iterations + + # Get local poses after setting to verify + translations_after_set, orientations_local_after_set = xform_view.get_local_poses() + if not isinstance(translations_after_set, torch.Tensor): + translations_after_set = torch.tensor(translations_after_set, dtype=torch.float32) + if not isinstance(orientations_local_after_set, torch.Tensor): + orientations_local_after_set = torch.tensor(orientations_local_after_set, dtype=torch.float32) + computed_results["local_translations_after_set"] = translations_after_set.clone() + computed_results["local_orientations_after_set"] = orientations_local_after_set.clone() + + # Benchmark combined get operation + start_time = time.perf_counter() + for _ in range(num_iterations): + positions, orientations = xform_view.get_world_poses() + translations, local_orientations = xform_view.get_local_poses() + timing_results["get_both"] = (time.perf_counter() - start_time) / num_iterations + + # close simulation + sim.clear() + sim.clear_all_callbacks() + sim.clear_instance() + + return timing_results, computed_results + + +def compare_results( + results_dict: dict[str, dict[str, torch.Tensor]], tolerance: float = 1e-4 +) -> dict[str, dict[str, dict[str, float]]]: + """Compare computed results across multiple implementations. + + Args: + results_dict: Dictionary mapping API names to their computed values. + tolerance: Tolerance for numerical comparison. + + Returns: + Nested dictionary: {comparison_pair: {metric: {stats}}}, e.g., + {"isaaclab_vs_isaacsim": {"initial_world_positions": {"max_diff": 0.001, ...}}} + """ + comparison_stats = {} + api_names = list(results_dict.keys()) + + # Compare each pair of APIs + for i, api1 in enumerate(api_names): + for api2 in api_names[i + 1 :]: + pair_key = f"{api1}_vs_{api2}" + comparison_stats[pair_key] = {} + + computed1 = results_dict[api1] + computed2 = results_dict[api2] + + for key in computed1.keys(): + if key not in computed2: + print(f" Warning: Key '{key}' not found in {api2} results") + continue + + val1 = computed1[key] + val2 = computed2[key] + + # Compute differences + diff = torch.abs(val1 - val2) + max_diff = torch.max(diff).item() + mean_diff = torch.mean(diff).item() + + # Check if within tolerance + all_close = torch.allclose(val1, val2, atol=tolerance, rtol=0) + + comparison_stats[pair_key][key] = { + "max_diff": max_diff, + "mean_diff": mean_diff, + "all_close": all_close, + } + + return comparison_stats + + +def print_comparison_results(comparison_stats: dict[str, dict[str, dict[str, float]]], tolerance: float): + """Print comparison results across implementations. + + Args: + comparison_stats: Nested dictionary containing comparison statistics for each API pair. + tolerance: Tolerance used for comparison. + """ + for pair_key, pair_stats in comparison_stats.items(): + # Format the pair key for display (e.g., "isaaclab_vs_isaacsim" -> "Isaac Lab vs Isaac Sim") + api1, api2 = pair_key.split("_vs_") + display_api1 = api1.replace("-", " ").title() + display_api2 = api2.replace("-", " ").title() + comparison_title = f"{display_api1} vs {display_api2}" + + # Check if all results match + all_match = all(stats["all_close"] for stats in pair_stats.values()) + + if all_match: + # Compact output when everything matches + print("\n" + "=" * 100) + print(f"RESULT COMPARISON: {comparison_title}") + print("=" * 100) + print(f"✓ All computed values match within tolerance ({tolerance})") + print("=" * 100) + else: + # Detailed output when there are mismatches + print("\n" + "=" * 100) + print(f"RESULT COMPARISON: {comparison_title}") + print("=" * 100) + print(f"{'Computed Value':<40} {'Max Diff':<15} {'Mean Diff':<15} {'Match':<10}") + print("-" * 100) + + for key, stats in pair_stats.items(): + # Format the key for display + display_key = key.replace("_", " ").title() + match_str = "✓ Yes" if stats["all_close"] else "✗ No" + + print(f"{display_key:<40} {stats['max_diff']:<15.6e} {stats['mean_diff']:<15.6e} {match_str:<10}") + + print("=" * 100) + print(f"\n✗ Some results differ beyond tolerance ({tolerance})") + print(f" This may indicate implementation differences between {display_api1} and {display_api2}") + + print() + + +def print_results(results_dict: dict[str, dict[str, float]], num_prims: int, num_iterations: int): + """Print benchmark results in a formatted table. + + Args: + results_dict: Dictionary mapping API names to their timing results. + num_prims: Number of prims tested. + num_iterations: Number of iterations run. + """ + print("\n" + "=" * 100) + print(f"BENCHMARK RESULTS: {num_prims} prims, {num_iterations} iterations") + print("=" * 100) + + api_names = list(results_dict.keys()) + # Format API names for display + display_names = [name.replace("-", " ").replace("_", " ").title() for name in api_names] + + # Calculate column width based on number of APIs + col_width = 20 + + # Print header + header = f"{'Operation':<25}" + for display_name in display_names: + header += f" {display_name + ' (ms)':<{col_width}}" + print(header) + print("-" * 100) + + # Print each operation + operations = [ + ("Initialization", "init"), + ("Get World Poses", "get_world_poses"), + ("Set World Poses", "set_world_poses"), + ("Get Local Poses", "get_local_poses"), + ("Set Local Poses", "set_local_poses"), + ("Get Both (World+Local)", "get_both"), + ] + + for op_name, op_key in operations: + row = f"{op_name:<25}" + for api_name in api_names: + api_time = results_dict[api_name].get(op_key, 0) * 1000 # Convert to ms + row += f" {api_time:>{col_width - 1}.4f}" + print(row) + + print("=" * 100) + + # Calculate and print total time + total_row = f"{'Total Time':<25}" + for api_name in api_names: + total_time = sum(results_dict[api_name].values()) * 1000 + total_row += f" {total_time:>{col_width - 1}.4f}" + print(f"\n{total_row}") + + # Calculate speedups relative to Isaac Lab + if "isaaclab" in api_names: + print("\n" + "=" * 100) + print("SPEEDUP vs Isaac Lab") + print("=" * 100) + print(f"{'Operation':<25}", end="") + for display_name in display_names: + if "isaaclab" not in display_name.lower(): + print(f" {display_name + ' Speedup':<{col_width}}", end="") + print() + print("-" * 100) + + isaaclab_results = results_dict["isaaclab"] + for op_name, op_key in operations: + print(f"{op_name:<25}", end="") + isaaclab_time = isaaclab_results.get(op_key, 0) + for api_name, display_name in zip(api_names, display_names): + if api_name != "isaaclab": + api_time = results_dict[api_name].get(op_key, 0) + if isaaclab_time > 0 and api_time > 0: + speedup = api_time / isaaclab_time + print(f" {speedup:>{col_width - 1}.2f}x", end="") + else: + print(f" {'N/A':>{col_width}}", end="") + print() + + # Overall speedup + print("=" * 100) + print(f"{'Overall Speedup':<25}", end="") + total_isaaclab = sum(isaaclab_results.values()) + for api_name, display_name in zip(api_names, display_names): + if api_name != "isaaclab": + total_api = sum(results_dict[api_name].values()) + if total_isaaclab > 0 and total_api > 0: + overall_speedup = total_api / total_isaaclab + print(f" {overall_speedup:>{col_width - 1}.2f}x", end="") + else: + print(f" {'N/A':>{col_width}}", end="") + print() + + print("\n" + "=" * 100) + print("\nNotes:") + print(" - Times are averaged over all iterations") + print(" - Speedup = (Other API time) / (Isaac Lab time)") + print(" - Speedup > 1.0 means Isaac Lab is faster") + print(" - Speedup < 1.0 means the other API is faster") + print() + + +def main(): + """Main benchmark function.""" + print("=" * 100) + print("XformPrimView Benchmark - Comparing Multiple APIs") + print("=" * 100) + print("Configuration:") + print(f" Number of environments: {args_cli.num_envs}") + print(f" Iterations per test: {args_cli.num_iterations}") + print(f" Device: {args_cli.device}") + print(f" Profiling: {'Enabled' if args_cli.profile else 'Disabled'}") + if args_cli.profile: + print(f" Profile directory: {args_cli.profile_dir}") + print() + + # Create profile directory if profiling is enabled + if args_cli.profile: + import os + + os.makedirs(args_cli.profile_dir, exist_ok=True) + + # Dictionary to store all results + all_timing_results = {} + all_computed_results = {} + profile_files = {} + + # APIs to benchmark + apis_to_test = [ + ("isaaclab", "Isaac Lab XformPrimView"), + ("isaacsim", "Isaac Sim XformPrimView (Legacy)"), + ("isaacsim-exp", "Isaac Sim Experimental XformPrim"), + ] + + # Benchmark each API + for api_key, api_name in apis_to_test: + print(f"Benchmarking {api_name}...") + + if args_cli.profile: + profiler = cProfile.Profile() + profiler.enable() + + # Cast api_key to Literal type for type checker + timing, computed = benchmark_xform_prim_view( + api=api_key, # type: ignore[arg-type] + num_iterations=args_cli.num_iterations, + ) + + if args_cli.profile: + profiler.disable() + profile_file = f"{args_cli.profile_dir}/{api_key.replace('-', '_')}_benchmark.prof" + profiler.dump_stats(profile_file) + profile_files[api_key] = profile_file + print(f" Profile saved to: {profile_file}") + + all_timing_results[api_key] = timing + all_computed_results[api_key] = computed + + print(" Done!") + print() + + # Print timing results + print_results(all_timing_results, args_cli.num_envs, args_cli.num_iterations) + + # Compare computed results + print("\nComparing computed results across APIs...") + comparison_stats = compare_results(all_computed_results, tolerance=1e-6) + print_comparison_results(comparison_stats, tolerance=1e-4) + + # Print profiling instructions if enabled + if args_cli.profile: + print("\n" + "=" * 100) + print("PROFILING RESULTS") + print("=" * 100) + print("Profile files have been saved. To visualize with snakeviz, run:") + for api_key, profile_file in profile_files.items(): + api_display = api_key.replace("-", " ").title() + print(f" # {api_display}") + print(f" snakeviz {profile_file}") + print("\nAlternatively, use pstats to analyze in terminal:") + print(" python -m pstats ") + print("=" * 100) + print() + + # Clean up + sim_utils.SimulationContext.clear_instance() + + +if __name__ == "__main__": + main() diff --git a/source/isaaclab/isaaclab/scene/interactive_scene.py b/source/isaaclab/isaaclab/scene/interactive_scene.py index 775df1640f4..33abcaed7f9 100644 --- a/source/isaaclab/isaaclab/scene/interactive_scene.py +++ b/source/isaaclab/isaaclab/scene/interactive_scene.py @@ -10,7 +10,6 @@ import carb from isaacsim.core.cloner import GridCloner -from isaacsim.core.prims import XFormPrim from pxr import PhysxSchema import isaaclab.sim as sim_utils @@ -30,6 +29,7 @@ from isaaclab.sensors import ContactSensorCfg, FrameTransformerCfg, SensorBase, SensorBaseCfg from isaaclab.sim import SimulationContext from isaaclab.sim.utils.stage import get_current_stage, get_current_stage_id +from isaaclab.sim.views import XformPrimView from isaaclab.terrains import TerrainImporter, TerrainImporterCfg from isaaclab.utils.version import get_isaac_sim_version @@ -406,11 +406,11 @@ def surface_grippers(self) -> dict[str, SurfaceGripper]: return self._surface_grippers @property - def extras(self) -> dict[str, XFormPrim]: + def extras(self) -> dict[str, XformPrimView]: """A dictionary of miscellaneous simulation objects that neither inherit from assets nor sensors. - The keys are the names of the miscellaneous objects, and the values are the `XFormPrim`_ - of the corresponding prims. + The keys are the names of the miscellaneous objects, and the values are the + :class:`~isaaclab.sim.views.XformPrimView` instances of the corresponding prims. As an example, lights or other props in the scene that do not have any attributes or properties that you want to alter at runtime can be added to this dictionary. @@ -419,8 +419,6 @@ def extras(self) -> dict[str, XFormPrim]: These are not reset or updated by the scene. They are mainly other prims that are not necessarily handled by the interactive scene, but are useful to be accessed by the user. - .. _XFormPrim: https://docs.isaacsim.omniverse.nvidia.com/latest/py/source/extensions/isaacsim.core.prims/docs/index.html#isaacsim.core.prims.XFormPrim - """ return self._extras @@ -779,7 +777,7 @@ def _add_entities_from_cfg(self): ) # store xform prim view corresponding to this asset # all prims in the scene are Xform prims (i.e. have a transform component) - self._extras[asset_name] = XFormPrim(asset_cfg.prim_path, reset_xform_properties=False) + self._extras[asset_name] = XformPrimView(asset_cfg.prim_path, device=self.device, stage=self.stage) else: raise ValueError(f"Unknown asset config type for {asset_name}: {asset_cfg}") # store global collision paths diff --git a/source/isaaclab/isaaclab/sensors/camera/camera.py b/source/isaaclab/isaaclab/sensors/camera/camera.py index 90f8bdef955..d5773bf24f9 100644 --- a/source/isaaclab/isaaclab/sensors/camera/camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/camera.py @@ -17,11 +17,11 @@ import carb import omni.kit.commands import omni.usd -from isaacsim.core.prims import XFormPrim from pxr import Sdf, UsdGeom import isaaclab.sim as sim_utils import isaaclab.utils.sensors as sensor_utils +from isaaclab.sim.views import XformPrimView from isaaclab.utils import to_camel_case from isaaclab.utils.array import convert_to_torch from isaaclab.utils.math import ( @@ -405,8 +405,7 @@ def _initialize_impl(self): # Initialize parent class super()._initialize_impl() # Create a view for the sensor - self._view = XFormPrim(self.cfg.prim_path, reset_xform_properties=False) - self._view.initialize() + self._view = XformPrimView(self.cfg.prim_path, device=self._device, stage=self.stage) # Check that sizes are correct if self._view.count != self._num_envs: raise RuntimeError( @@ -424,9 +423,9 @@ def _initialize_impl(self): self._rep_registry: dict[str, list[rep.annotators.Annotator]] = {name: list() for name in self.cfg.data_types} # Convert all encapsulated prims to Camera - for cam_prim_path in self._view.prim_paths: - # Get camera prim - cam_prim = self.stage.GetPrimAtPath(cam_prim_path) + for cam_prim in self._view.prims: + # Obtain the prim path + cam_prim_path = cam_prim.GetPath().pathString # Check if prim is a camera if not cam_prim.IsA(UsdGeom.Camera): raise RuntimeError(f"Prim at path '{cam_prim_path}' is not a Camera.") diff --git a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py index 2f9f996b4f4..3fb1f343ff7 100644 --- a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py @@ -14,9 +14,9 @@ from typing import TYPE_CHECKING, Any import carb -from isaacsim.core.prims import XFormPrim from pxr import UsdGeom +from isaaclab.sim.views import XformPrimView from isaaclab.utils.warp.kernels import reshape_tiled_image from ..sensor_base import SensorBase @@ -150,8 +150,7 @@ def _initialize_impl(self): # Initialize parent class SensorBase._initialize_impl(self) # Create a view for the sensor - self._view = XFormPrim(self.cfg.prim_path, reset_xform_properties=False) - self._view.initialize() + self._view = XformPrimView(self.cfg.prim_path, device=self._device, stage=self.stage) # Check that sizes are correct if self._view.count != self._num_envs: raise RuntimeError( @@ -165,20 +164,19 @@ def _initialize_impl(self): self._frame = torch.zeros(self._view.count, device=self._device, dtype=torch.long) # Convert all encapsulated prims to Camera - for cam_prim_path in self._view.prim_paths: + cam_prim_paths = [] + for cam_prim in self._view.prims: # Get camera prim - cam_prim = self.stage.GetPrimAtPath(cam_prim_path) + cam_prim_path = cam_prim.GetPath().pathString # Check if prim is a camera if not cam_prim.IsA(UsdGeom.Camera): raise RuntimeError(f"Prim at path '{cam_prim_path}' is not a Camera.") # Add to list - sensor_prim = UsdGeom.Camera(cam_prim) - self._sensor_prims.append(sensor_prim) + self._sensor_prims.append(UsdGeom.Camera(cam_prim)) + cam_prim_paths.append(cam_prim_path) # Create replicator tiled render product - rp = rep.create.render_product_tiled( - cameras=self._view.prim_paths, tile_resolution=(self.cfg.width, self.cfg.height) - ) + rp = rep.create.render_product_tiled(cameras=cam_prim_paths, tile_resolution=(self.cfg.width, self.cfg.height)) self._render_product_paths = [rp.path] # Define the annotators based on requested data types diff --git a/source/isaaclab/isaaclab/sensors/contact_sensor/__init__.py b/source/isaaclab/isaaclab/sensors/contact_sensor/__init__.py index 511a660b3d9..94b402d41a3 100644 --- a/source/isaaclab/isaaclab/sensors/contact_sensor/__init__.py +++ b/source/isaaclab/isaaclab/sensors/contact_sensor/__init__.py @@ -3,7 +3,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Sub-module for rigid contact sensor based on :class:`isaacsim.core.prims.RigidContactView`.""" +"""Sub-module for rigid contact sensor.""" from .contact_sensor import ContactSensor from .contact_sensor_cfg import ContactSensorCfg diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster.py b/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster.py index 817d09674e2..cf09ce27a9a 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster.py @@ -15,9 +15,9 @@ from typing import TYPE_CHECKING, ClassVar import omni.physics.tensors.impl.api as physx -from isaacsim.core.prims import XFormPrim import isaaclab.sim as sim_utils +from isaaclab.sim.views import XformPrimView from isaaclab.utils.math import matrix_from_quat, quat_mul from isaaclab.utils.mesh import PRIMITIVE_MESH_TYPES, create_trimesh_from_geom_mesh, create_trimesh_from_geom_shape from isaaclab.utils.warp import convert_to_warp_mesh, raycast_dynamic_meshes @@ -78,7 +78,7 @@ class MultiMeshRayCaster(RayCaster): mesh_offsets: dict[str, tuple[torch.Tensor, torch.Tensor]] = {} - mesh_views: ClassVar[dict[str, XFormPrim | physx.ArticulationView | physx.RigidBodyView]] = {} + mesh_views: ClassVar[dict[str, XformPrimView | physx.ArticulationView | physx.RigidBodyView]] = {} """A dictionary to store mesh views for raycasting, shared across all instances. The keys correspond to the prim path for the mesh views, and values are the corresponding view objects. diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/ray_cast_utils.py b/source/isaaclab/isaaclab/sensors/ray_caster/ray_cast_utils.py index a5a62fea183..543276e8ea2 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/ray_cast_utils.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/ray_cast_utils.py @@ -10,13 +10,13 @@ import torch import omni.physics.tensors.impl.api as physx -from isaacsim.core.prims import XFormPrim +from isaaclab.sim.views import XformPrimView from isaaclab.utils.math import convert_quat def obtain_world_pose_from_view( - physx_view: XFormPrim | physx.ArticulationView | physx.RigidBodyView, + physx_view: XformPrimView | physx.ArticulationView | physx.RigidBodyView, env_ids: torch.Tensor, clone: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -34,7 +34,7 @@ def obtain_world_pose_from_view( Raises: NotImplementedError: If the prim view is not of the supported type. """ - if isinstance(physx_view, XFormPrim): + if isinstance(physx_view, XformPrimView): pos_w, quat_w = physx_view.get_world_poses(env_ids) elif isinstance(physx_view, physx.ArticulationView): pos_w, quat_w = physx_view.get_root_transforms()[env_ids].split([3, 4], dim=-1) diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py b/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py index 30db49f7f11..d7aa07419d4 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py @@ -14,13 +14,13 @@ from typing import TYPE_CHECKING, ClassVar import omni -from isaacsim.core.prims import XFormPrim from isaacsim.core.simulation_manager import SimulationManager from pxr import UsdGeom, UsdPhysics import isaaclab.sim as sim_utils import isaaclab.utils.math as math_utils from isaaclab.markers import VisualizationMarkers +from isaaclab.sim.views import XformPrimView from isaaclab.terrains.trimesh.utils import make_plane from isaaclab.utils.math import quat_apply, quat_apply_yaw from isaaclab.utils.warp import convert_to_warp_mesh, raycast_mesh @@ -333,7 +333,7 @@ def _debug_vis_callback(self, event): def _obtain_trackable_prim_view( self, target_prim_path: str - ) -> tuple[XFormPrim | any, tuple[torch.Tensor, torch.Tensor]]: + ) -> tuple[XformPrimView | any, tuple[torch.Tensor, torch.Tensor]]: """Obtain a prim view that can be used to track the pose of the parget prim. The target prim path is a regex expression that matches one or more mesh prims. While we can track its @@ -376,7 +376,7 @@ def _obtain_trackable_prim_view( new_root_prim = current_prim.GetParent() current_path_expr = current_path_expr.rsplit("/", 1)[0] if not new_root_prim.IsValid(): - prim_view = XFormPrim(target_prim_path, reset_xform_properties=False) + prim_view = XformPrimView(target_prim_path, device=self._device, stage=self.stage) current_path_expr = target_prim_path logger.warning( f"The prim at path {target_prim_path} which is used for raycasting is not a physics prim." diff --git a/source/isaaclab/isaaclab/sim/__init__.py b/source/isaaclab/isaaclab/sim/__init__.py index 438ebc121ac..1dc920f4e10 100644 --- a/source/isaaclab/isaaclab/sim/__init__.py +++ b/source/isaaclab/isaaclab/sim/__init__.py @@ -32,3 +32,4 @@ from .simulation_context import SimulationContext, build_simulation_context # noqa: F401, F403 from .spawners import * # noqa: F401, F403 from .utils import * # noqa: F401, F403 +from .views import * # noqa: F401, F403 diff --git a/source/isaaclab/isaaclab/sim/views/__init__.py b/source/isaaclab/isaaclab/sim/views/__init__.py new file mode 100644 index 00000000000..eb5bea7690c --- /dev/null +++ b/source/isaaclab/isaaclab/sim/views/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Views for manipulating USD prims.""" + +from .xform_prim_view import XformPrimView diff --git a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py new file mode 100644 index 00000000000..6049948da81 --- /dev/null +++ b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py @@ -0,0 +1,600 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import numpy as np +import torch +from collections.abc import Sequence + +from pxr import Gf, Sdf, Usd, UsdGeom, Vt + +import isaaclab.sim as sim_utils +import isaaclab.utils.math as math_utils + + +class XformPrimView: + """Optimized batched interface for reading and writing transforms of multiple USD prims. + + This class provides efficient batch operations for getting and setting poses (position and orientation) + of multiple prims at once using torch tensors. It is designed for scenarios where you need to manipulate + many prims simultaneously, such as in multi-agent simulations or large-scale procedural generation. + + The class supports both world-space and local-space pose operations: + + - **World poses**: Positions and orientations in the global world frame + - **Local poses**: Positions and orientations relative to each prim's parent + + .. warning:: + **Fabric and Physics Simulation:** + + This view operates directly on USD attributes. When **Fabric** (NVIDIA's USD runtime optimization) + is enabled, physics simulation updates are written to Fabric's internal representation and + **not propagated back to USD attributes**. This causes the following issues: + + - Reading poses via :func:`get_world_poses()` or :func:`get_local_poses()` will return + **stale USD data** which does not reflect the actual physics state + - Writing poses via :func:`set_world_poses()` or :func:`set_local_poses()` will update USD, + but **physics simulation will not see these changes**. + + **Solution:** + For prims with physics components (rigid bodies, articulations), use :mod:`isaaclab.assets` + classes (e.g., :class:`~isaaclab.assets.RigidObject`, :class:`~isaaclab.assets.Articulation`) + which use PhysX tensor APIs that work correctly with Fabric. + + **When to use XformPrimView:** + + - Non-physics prims (markers, visual elements, cameras without physics) + - Setting initial poses before simulation starts + - Non-Fabric workflows + + For more information on Fabric, please refer to the `Fabric documentation`_. + + .. _Fabric documentation: https://docs.omniverse.nvidia.com/kit/docs/usdrt/latest/docs/usd_fabric_usdrt.html + + .. note:: + **Performance Considerations:** + + * Tensor operations are performed on the specified device (CPU/CUDA) + * USD write operations use ``Sdf.ChangeBlock`` for batched updates + * Getting poses involves USD API calls and cannot be fully accelerated on GPU + * For maximum performance, minimize get/set operations within tight loops + + .. note:: + **Transform Requirements:** + + All prims in the view must be Xformable and have standardized transform operations: + ``[translate, orient, scale]``. Non-standard prims will raise a ValueError during + initialization if :attr:`validate_xform_ops` is True. Please use the function + :func:`isaaclab.sim.utils.standardize_xform_ops` to prepare prims before using this view. + + .. warning:: + This class operates at the USD default time code. Any animation or time-sampled data + will not be affected by write operations. For animated transforms, you need to handle + time-sampled keyframes separately. + """ + + def __init__( + self, prim_path: str, device: str = "cpu", validate_xform_ops: bool = True, stage: Usd.Stage | None = None + ): + """Initialize the view with matching prims. + + This method searches the USD stage for all prims matching the provided path pattern, + validates that they are Xformable with standard transform operations, and stores + references for efficient batch operations. + + We generally recommend to validate the xform operations, as it ensures that the prims are in a consistent state + and have the standard transform operations (translate, orient, scale in that order). + However, if you are sure that the prims are in a consistent state, you can set this to False to improve + performance. This can save around 45-50% of the time taken to initialize the view. + + Args: + prim_path: USD prim path pattern to match prims. Supports wildcards (``*``) and + regex patterns (e.g., ``"/World/Env_.*/Robot"``). See + :func:`isaaclab.sim.utils.find_matching_prims` for pattern syntax. + device: Device to place the tensors on. Can be ``"cpu"`` or CUDA devices like + ``"cuda:0"``. Defaults to ``"cpu"``. + validate_xform_ops: Whether to validate that the prims have standard xform operations. + Defaults to True. + stage: USD stage to search for prims. Defaults to None, in which case the current active stage + from the simulation context is used. + + Raises: + ValueError: If any matched prim is not Xformable or doesn't have standardized + transform operations (translate, orient, scale in that order). + """ + stage = sim_utils.get_current_stage() if stage is None else stage + + # Store configuration + self._prim_path = prim_path + self._device = device + + # Find and validate matching prims + self._prims: list[Usd.Prim] = sim_utils.find_matching_prims(prim_path, stage=stage) + + # Create indices buffer + # Since we iterate over the indices, we need to use range instead of torch tensor + self._ALL_INDICES = list(range(len(self._prims))) + + # Validate all prims have standard xform operations + if validate_xform_ops: + for prim in self._prims: + if not sim_utils.validate_standard_xform_ops(prim): + raise ValueError( + f"Prim at path '{prim.GetPath().pathString}' is not a xformable prim with standard transform" + f" operations [translate, orient, scale]. Received type: '{prim.GetTypeName()}'." + " Use sim_utils.standardize_xform_ops() to prepare the prim." + ) + + """ + Properties. + """ + + @property + def count(self) -> int: + """Number of prims in this view. + + Returns: + The number of prims being managed by this view. + """ + return len(self._prims) + + @property + def device(self) -> str: + """Device where tensors are allocated (cpu or cuda).""" + return self._device + + @property + def prims(self) -> list[Usd.Prim]: + """List of USD prims being managed by this view.""" + return self._prims + + @property + def prim_paths(self) -> list[str]: + """List of prim paths (as strings) for all prims being managed by this view. + + This property converts each prim to its path string representation. The conversion is + performed lazily on first access and cached for subsequent accesses. + + Note: + For most use cases, prefer using :attr:`prims` directly as it provides direct access + to the USD prim objects without the conversion overhead. This property is mainly useful + for logging, debugging, or when string paths are explicitly required. + + Returns: + List of prim paths (as strings) in the same order as :attr:`prims`. + """ + # we cache it the first time it is accessed. + # we don't compute it in constructor because it is expensive and we don't need it most of the time. + # users should usually deal with prims directly as they typically need to access the prims directly. + if not hasattr(self, "_prim_paths"): + self._prim_paths = [prim.GetPath().pathString for prim in self._prims] + return self._prim_paths + + """ + Operations - Setters. + """ + + def set_world_poses( + self, + positions: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ): + """Set world-space poses for prims in the view. + + This method sets the position and/or orientation of each prim in world space. The world pose + is computed by considering the prim's parent transforms. If a prim has a parent, this method + will convert the world pose to the appropriate local pose before setting it. + + Note: + This operation writes to USD at the default time code. Any animation data will not be affected. + + Args: + positions: World-space positions as a tensor of shape (M, 3) where M is the number of prims + to set (either all prims if indices is None, or the number of indices provided). + Defaults to None, in which case positions are not modified. + orientations: World-space orientations as quaternions (w, x, y, z) with shape (M, 4). + Defaults to None, in which case orientations are not modified. + indices: Indices of prims to set poses for. Defaults to None, in which case poses are set + for all prims in the view. + + Raises: + ValueError: If positions shape is not (M, 3) or orientations shape is not (M, 4). + ValueError: If the number of poses doesn't match the number of indices provided. + """ + # Resolve indices + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + # Convert to list if it is a tensor array + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + # Validate inputs + if positions is not None: + if positions.shape != (len(indices_list), 3): + raise ValueError( + f"Expected positions shape ({len(indices_list)}, 3), got {positions.shape}. " + "Number of positions must match the number of prims in the view." + ) + positions_array = Vt.Vec3dArray.FromNumpy(positions.cpu().numpy()) + else: + positions_array = None + if orientations is not None: + if orientations.shape != (len(indices_list), 4): + raise ValueError( + f"Expected orientations shape ({len(indices_list)}, 4), got {orientations.shape}. " + "Number of orientations must match the number of prims in the view." + ) + # Vt expects quaternions in xyzw order + orientations_array = Vt.QuatdArray.FromNumpy(math_utils.convert_quat(orientations, to="xyzw").cpu().numpy()) + else: + orientations_array = None + + # Create xform cache instance + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + + # Set poses for each prim + # We use Sdf.ChangeBlock to minimize notification overhead. + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + # Get prim + prim = self._prims[prim_idx] + # Get parent prim for local space conversion + parent_prim = prim.GetParent() + + # Determine what to set + world_pos = positions_array[idx] if positions_array is not None else None + world_quat = orientations_array[idx] if orientations_array is not None else None + + # Convert world pose to local if we have a valid parent + # Note: We don't use :func:`isaaclab.sim.utils.transforms.convert_world_pose_to_local` + # here since it isn't optimized for batch operations. + if parent_prim.IsValid() and parent_prim.GetPath() != Sdf.Path.absoluteRootPath: + # Get current world pose if we're only setting one component + if positions_array is None or orientations_array is None: + # get prim xform + prim_tf = xform_cache.GetLocalToWorldTransform(prim) + # sanitize quaternion + # this is needed, otherwise the quaternion might be non-normalized + prim_tf.Orthonormalize() + # populate desired world transform + if world_pos is not None: + prim_tf.SetTranslateOnly(world_pos) + if world_quat is not None: + prim_tf.SetRotateOnly(world_quat) + else: + # Both position and orientation are provided, create new transform + prim_tf = Gf.Matrix4d() + prim_tf.SetTranslateOnly(world_pos) + prim_tf.SetRotateOnly(world_quat) + + # Convert to local space + parent_world_tf = xform_cache.GetLocalToWorldTransform(parent_prim) + local_tf = prim_tf * parent_world_tf.GetInverse() + local_pos = local_tf.ExtractTranslation() + local_quat = local_tf.ExtractRotationQuat() + else: + # No parent or parent is root, world == local + local_pos = world_pos + local_quat = world_quat + + # Get or create the standard transform operations + if local_pos is not None: + prim.GetAttribute("xformOp:translate").Set(local_pos) + if local_quat is not None: + prim.GetAttribute("xformOp:orient").Set(local_quat) + + def set_local_poses( + self, + translations: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ): + """Set local-space poses for prims in the view. + + This method sets the position and/or orientation of each prim in local space (relative to + their parent prims). This is useful when you want to directly manipulate the prim's transform + attributes without considering the parent hierarchy. + + Note: + This operation writes to USD at the default time code. Any animation data will not be affected. + + Args: + translations: Local-space translations as a tensor of shape (M, 3) where M is the number of prims + to set (either all prims if indices is None, or the number of indices provided). + Defaults to None, in which case translations are not modified. + orientations: Local-space orientations as quaternions (w, x, y, z) with shape (M, 4). + Defaults to None, in which case orientations are not modified. + indices: Indices of prims to set poses for. Defaults to None, in which case poses are set + for all prims in the view. + + Raises: + ValueError: If translations shape is not (M, 3) or orientations shape is not (M, 4). + ValueError: If the number of poses doesn't match the number of indices provided. + """ + # Resolve indices + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + # Convert to list if it is a tensor array + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + # Validate inputs + if translations is not None: + if translations.shape != (len(indices_list), 3): + raise ValueError( + f"Expected translations shape ({len(indices_list)}, 3), got {translations.shape}. " + "Number of translations must match the number of prims in the view." + ) + translations_array = Vt.Vec3dArray.FromNumpy(translations.cpu().numpy()) + else: + translations_array = None + if orientations is not None: + if orientations.shape != (len(indices_list), 4): + raise ValueError( + f"Expected orientations shape ({len(indices_list)}, 4), got {orientations.shape}. " + "Number of orientations must match the number of prims in the view." + ) + # Vt expects quaternions in xyzw order + orientations_array = Vt.QuatdArray.FromNumpy(math_utils.convert_quat(orientations, to="xyzw").cpu().numpy()) + else: + orientations_array = None + # Set local poses for each prim + # We use Sdf.ChangeBlock to minimize notification overhead. + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + # Get prim + prim = self._prims[prim_idx] + # Set attributes if provided + if translations_array is not None: + prim.GetAttribute("xformOp:translate").Set(translations_array[idx]) + if orientations_array is not None: + prim.GetAttribute("xformOp:orient").Set(orientations_array[idx]) + + def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None): + """Set scales for prims in the view. + + This method sets the scale of each prim in the view. + + Args: + scales: Scales as a tensor of shape (M, 3) where M is the number of prims + to set (either all prims if indices is None, or the number of indices provided). + indices: Indices of prims to set scales for. Defaults to None, in which case scales are set + for all prims in the view. + + Raises: + ValueError: If scales shape is not (M, 3). + """ + # Resolve indices + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + # Convert to list if it is a tensor array + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + # Validate inputs + if scales.shape != (len(indices_list), 3): + raise ValueError(f"Expected scales shape ({len(indices_list)}, 3), got {scales.shape}.") + + scales_array = Vt.Vec3dArray.FromNumpy(scales.cpu().numpy()) + # Set scales for each prim + # We use Sdf.ChangeBlock to minimize notification overhead. + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + # Get prim + prim = self._prims[prim_idx] + # Set scale attribute + prim.GetAttribute("xformOp:scale").Set(scales_array[idx]) + + def set_visibility(self, visibility: torch.Tensor, indices: Sequence[int] | None = None): + """Set visibility for prims in the view. + + This method sets the visibility of each prim in the view. + + Args: + visibility: Visibility as a boolean tensor of shape (M,) where M is the + number of prims to set (either all prims if indices is None, or the number of indices provided). + indices: Indices of prims to set visibility for. Defaults to None, in which case visibility is set + for all prims in the view. + + Raises: + ValueError: If visibility shape is not (M,). + """ + # Resolve indices + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + # Validate inputs + if visibility.shape != (len(indices_list),): + raise ValueError(f"Expected visibility shape ({len(indices_list)},), got {visibility.shape}.") + + # Set visibility for each prim + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + # Convert prim to imageable + imageable = UsdGeom.Imageable(self._prims[prim_idx]) + # Set visibility + if visibility[idx]: + imageable.MakeVisible() + else: + imageable.MakeInvisible() + + """ + Operations - Getters. + """ + + def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """Get world-space poses for prims in the view. + + This method retrieves the position and orientation of each prim in world space by computing + the full transform hierarchy from the prim to the world root. + + Note: + Scale and skew are ignored. The returned poses contain only translation and rotation. + + Args: + indices: Indices of prims to get poses for. Defaults to None, in which case poses are retrieved + for all prims in the view. + + Returns: + A tuple of (positions, orientations) where: + + - positions: Torch tensor of shape (M, 3) containing world-space positions (x, y, z), + where M is the number of prims queried. + - orientations: Torch tensor of shape (M, 4) containing world-space quaternions (w, x, y, z) + """ + # Resolve indices + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + # Convert to list if it is a tensor array + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + # Create buffers + positions = Vt.Vec3dArray(len(indices_list)) + orientations = Vt.QuatdArray(len(indices_list)) + # Create xform cache instance + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + + # Note: We don't use :func:`isaaclab.sim.utils.transforms.resolve_prim_pose` + # here since it isn't optimized for batch operations. + for idx, prim_idx in enumerate(indices_list): + # Get prim + prim = self._prims[prim_idx] + # get prim xform + prim_tf = xform_cache.GetLocalToWorldTransform(prim) + # sanitize quaternion + # this is needed, otherwise the quaternion might be non-normalized + prim_tf.Orthonormalize() + # extract position and orientation + positions[idx] = prim_tf.ExtractTranslation() + orientations[idx] = prim_tf.ExtractRotationQuat() + + # move to torch tensors + positions = torch.tensor(np.array(positions), dtype=torch.float32, device=self._device) + orientations = torch.tensor(np.array(orientations), dtype=torch.float32, device=self._device) + # underlying data is in xyzw order, convert to wxyz order + orientations = math_utils.convert_quat(orientations, to="wxyz") + + return positions, orientations # type: ignore + + def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """Get local-space poses for prims in the view. + + This method retrieves the position and orientation of each prim in local space (relative to + their parent prims). These are the raw transform values stored on each prim. + + Note: + Scale is ignored. The returned poses contain only translation and rotation. + + Args: + indices: Indices of prims to get poses for. Defaults to None, in which case poses are retrieved + for all prims in the view. + + Returns: + A tuple of (translations, orientations) where: + + - translations: Torch tensor of shape (M, 3) containing local-space translations (x, y, z), + where M is the number of prims queried. + - orientations: Torch tensor of shape (M, 4) containing local-space quaternions (w, x, y, z) + """ + # Resolve indices + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + # Convert to list if it is a tensor array + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + # Create buffers + translations = Vt.Vec3dArray(len(indices_list)) + orientations = Vt.QuatdArray(len(indices_list)) + # Create xform cache instance + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + + # Note: We don't use :func:`isaaclab.sim.utils.transforms.resolve_prim_pose` + # here since it isn't optimized for batch operations. + for idx, prim_idx in enumerate(indices_list): + # Get prim + prim = self._prims[prim_idx] + # get prim xform + prim_tf = xform_cache.GetLocalTransformation(prim)[0] + # sanitize quaternion + # this is needed, otherwise the quaternion might be non-normalized + prim_tf.Orthonormalize() + # extract position and orientation + translations[idx] = prim_tf.ExtractTranslation() + orientations[idx] = prim_tf.ExtractRotationQuat() + + # move to torch tensors + translations = torch.tensor(np.array(translations), dtype=torch.float32, device=self._device) + orientations = torch.tensor(np.array(orientations), dtype=torch.float32, device=self._device) + # underlying data is in xyzw order, convert to wxyz order + orientations = math_utils.convert_quat(orientations, to="wxyz") + + return translations, orientations # type: ignore + + def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: + """Get scales for prims in the view. + + This method retrieves the scale of each prim in the view. + + Args: + indices: Indices of prims to get scales for. Defaults to None, in which case scales are retrieved + for all prims in the view. + + Returns: + A tensor of shape (M, 3) containing the scales of each prim, where M is the number of prims queried. + """ + # Resolve indices + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + # Convert to list if it is a tensor array + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + # Create buffers + scales = Vt.Vec3dArray(len(indices_list)) + + for idx, prim_idx in enumerate(indices_list): + # Get prim + prim = self._prims[prim_idx] + scales[idx] = prim.GetAttribute("xformOp:scale").Get() + + # Convert to tensor + return torch.tensor(np.array(scales), dtype=torch.float32, device=self._device) + + def get_visibility(self, indices: Sequence[int] | None = None) -> torch.Tensor: + """Get visibility for prims in the view. + + This method retrieves the visibility of each prim in the view. + + Args: + indices: Indices of prims to get visibility for. Defaults to None, in which case visibility is retrieved + for all prims in the view. + + Returns: + A tensor of shape (M,) containing the visibility of each prim, where M is the number of prims queried. + The tensor is of type bool. + """ + # Resolve indices + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + # Convert to list if it is a tensor array + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + # Create buffers + visibility = torch.zeros(len(indices_list), dtype=torch.bool, device=self._device) + + for idx, prim_idx in enumerate(indices_list): + # Get prim + imageable = UsdGeom.Imageable(self._prims[prim_idx]) + # Get visibility + visibility[idx] = imageable.ComputeVisibility() != UsdGeom.Tokens.invisible + + return visibility diff --git a/source/isaaclab/test/sim/test_views_xform_prim.py b/source/isaaclab/test/sim/test_views_xform_prim.py new file mode 100644 index 00000000000..1e01de61ced --- /dev/null +++ b/source/isaaclab/test/sim/test_views_xform_prim.py @@ -0,0 +1,1373 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Launch Isaac Sim Simulator first.""" + +from isaaclab.app import AppLauncher + +# launch omniverse app +simulation_app = AppLauncher(headless=True).app + +"""Rest everything follows.""" + +import pytest +import torch + +try: + from isaacsim.core.prims import XFormPrim as _IsaacSimXformPrimView +except (ModuleNotFoundError, ImportError): + _IsaacSimXformPrimView = None + +import isaaclab.sim as sim_utils +from isaaclab.sim.views import XformPrimView as XformPrimView +from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR + + +@pytest.fixture(autouse=True) +def test_setup_teardown(): + """Create a blank new stage for each test.""" + # Setup: Create a new stage + sim_utils.create_new_stage() + sim_utils.update_stage() + + # Yield for the test + yield + + # Teardown: Clear stage after each test + sim_utils.clear_stage() + + +""" +Helper functions. +""" + + +def _prepare_indices(index_type, target_indices, num_prims, device): + """Helper function to prepare indices based on type.""" + if index_type == "list": + return target_indices, target_indices + elif index_type == "torch_tensor": + return torch.tensor(target_indices, dtype=torch.int64, device=device), target_indices + elif index_type == "slice_none": + return slice(None), list(range(num_prims)) + else: + raise ValueError(f"Unknown index type: {index_type}") + + +""" +Tests - Initialization. +""" + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_xform_prim_view_initialization_single_prim(device): + """Test XformPrimView initialization with a single prim.""" + # check if CUDA is available + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # Create a single xform prim + stage = sim_utils.get_current_stage() + sim_utils.create_prim("/World/Object", "Xform", translation=(1.0, 2.0, 3.0), stage=stage) + + # Create view + view = XformPrimView("/World/Object", device=device) + + # Verify properties + assert view.count == 1 + assert view.prim_paths == ["/World/Object"] + assert view.device == device + assert len(view.prims) == 1 + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_xform_prim_view_initialization_multiple_prims(device): + """Test XformPrimView initialization with multiple prims using pattern matching.""" + # check if CUDA is available + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # Create multiple prims + num_prims = 10 + stage = sim_utils.get_current_stage() + for i in range(num_prims): + sim_utils.create_prim(f"/World/Env_{i}/Object", "Xform", translation=(i * 2.0, 0.0, 1.0), stage=stage) + + # Create view with pattern + view = XformPrimView("/World/Env_.*/Object", device=device) + + # Verify properties + assert view.count == num_prims + assert view.device == device + assert len(view.prims) == num_prims + assert view.prim_paths == [f"/World/Env_{i}/Object" for i in range(num_prims)] + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_xform_prim_view_initialization_multiple_prims_order(device): + """Test XformPrimView initialization with multiple prims using pattern matching with multiple objects per prim. + + This test validates that XformPrimView respects USD stage traversal order, which is based on + creation order (depth-first search), NOT alphabetical/lexical sorting. This is an important + edge case that ensures deterministic prim ordering that matches USD's internal representation. + + The test creates prims in a deliberately non-alphabetical order (1, 0, A, a, 2) and verifies + that they are retrieved in creation order, not sorted order (0, 1, 2, A, a). + """ + # check if CUDA is available + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # Create multiple prims + num_prims = 10 + stage = sim_utils.get_current_stage() + + # NOTE: Prims are created in a specific order to test that XformPrimView respects + # USD stage traversal order (DFS based on creation order), NOT alphabetical/lexical order. + # This is an important edge case: children under the same parent are returned in the + # order they were created, not sorted by name. + + # First batch: Create Object_1, Object_0, Object_A for each environment + # (intentionally non-alphabetical: 1, 0, A instead of 0, 1, A) + for i in range(num_prims): + sim_utils.create_prim(f"/World/Env_{i}/Object_1", "Xform", translation=(i * 2.0, -2.0, 1.0), stage=stage) + sim_utils.create_prim(f"/World/Env_{i}/Object_0", "Xform", translation=(i * 2.0, 2.0, 1.0), stage=stage) + sim_utils.create_prim(f"/World/Env_{i}/Object_A", "Xform", translation=(i * 2.0, 0.0, -1.0), stage=stage) + + # Second batch: Create Object_a, Object_2 for each environment + # (created after the first batch to verify traversal is depth-first per environment) + for i in range(num_prims): + sim_utils.create_prim(f"/World/Env_{i}/Object_a", "Xform", translation=(i * 2.0, 2.0, -1.0), stage=stage) + sim_utils.create_prim(f"/World/Env_{i}/Object_2", "Xform", translation=(i * 2.0, 2.0, 1.0), stage=stage) + + # Create view with pattern + view = XformPrimView("/World/Env_.*/Object_.*", device=device) + + # Expected ordering: DFS traversal by environment, with children in creation order + # For each Env_i, we expect: Object_1, Object_0, Object_A, Object_a, Object_2 + # (matches creation order, NOT alphabetical: would be 0, 1, 2, A, a if sorted) + expected_prim_paths_ordering = [] + for i in range(num_prims): + expected_prim_paths_ordering.append(f"/World/Env_{i}/Object_1") + expected_prim_paths_ordering.append(f"/World/Env_{i}/Object_0") + expected_prim_paths_ordering.append(f"/World/Env_{i}/Object_A") + expected_prim_paths_ordering.append(f"/World/Env_{i}/Object_a") + expected_prim_paths_ordering.append(f"/World/Env_{i}/Object_2") + + # Verify properties + assert view.count == num_prims * 5 + assert view.device == device + assert len(view.prims) == num_prims * 5 + assert view.prim_paths == expected_prim_paths_ordering + + # Additional validation: Verify ordering is NOT alphabetical + # If it were alphabetical, Object_0 would come before Object_1 + alphabetical_order = [] + for i in range(num_prims): + alphabetical_order.append(f"/World/Env_{i}/Object_0") + alphabetical_order.append(f"/World/Env_{i}/Object_1") + alphabetical_order.append(f"/World/Env_{i}/Object_2") + alphabetical_order.append(f"/World/Env_{i}/Object_A") + alphabetical_order.append(f"/World/Env_{i}/Object_a") + + assert view.prim_paths != alphabetical_order, ( + "Prim paths should follow creation order, not alphabetical order. " + "This test validates that USD stage traversal respects creation order." + ) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_xform_prim_view_initialization_invalid_prim(device): + """Test XformPrimView initialization fails for non-xformable prims.""" + # check if CUDA is available + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create a prim with non-standard xform operations + stage.DefinePrim("/World/InvalidPrim", "Xform") + + # XformPrimView should raise ValueError because prim doesn't have standard operations + with pytest.raises(ValueError, match="not a xformable prim"): + XformPrimView("/World/InvalidPrim", device=device) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_xform_prim_view_initialization_empty_pattern(device): + """Test XformPrimView initialization with pattern that matches no prims.""" + # check if CUDA is available + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + sim_utils.create_new_stage() + + # Create view with pattern that matches nothing + view = XformPrimView("/World/NonExistent_.*", device=device) + + # Should have zero count + assert view.count == 0 + assert len(view.prims) == 0 + + +""" +Tests - Getters. +""" + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_get_world_poses(device): + """Test getting world poses from XformPrimView.""" + if device.startswith("cuda") and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims with known world poses + expected_positions = [(1.0, 2.0, 3.0), (4.0, 5.0, 6.0), (7.0, 8.0, 9.0)] + expected_orientations = [(1.0, 0.0, 0.0, 0.0), (0.7071068, 0.0, 0.0, 0.7071068), (0.7071068, 0.7071068, 0.0, 0.0)] + + for i, (pos, quat) in enumerate(zip(expected_positions, expected_orientations)): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", translation=pos, orientation=quat, stage=stage) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Get world poses + positions, orientations = view.get_world_poses() + + # Verify shapes + assert positions.shape == (3, 3) + assert orientations.shape == (3, 4) + + # Convert expected values to tensors + expected_positions_tensor = torch.tensor(expected_positions, dtype=torch.float32, device=device) + expected_orientations_tensor = torch.tensor(expected_orientations, dtype=torch.float32, device=device) + + # Verify positions + torch.testing.assert_close(positions, expected_positions_tensor, atol=1e-5, rtol=0) + + # Verify orientations (allow for quaternion sign ambiguity) + try: + torch.testing.assert_close(orientations, expected_orientations_tensor, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(orientations, -expected_orientations_tensor, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_get_local_poses(device): + """Test getting local poses from XformPrimView.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create parent and child prims + sim_utils.create_prim("/World/Parent", "Xform", translation=(10.0, 0.0, 0.0), stage=stage) + + # Children with different local poses + expected_local_positions = [(1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (0.0, 0.0, 3.0)] + expected_local_orientations = [ + (1.0, 0.0, 0.0, 0.0), + (0.7071068, 0.0, 0.0, 0.7071068), + (0.7071068, 0.7071068, 0.0, 0.0), + ] + + for i, (pos, quat) in enumerate(zip(expected_local_positions, expected_local_orientations)): + sim_utils.create_prim(f"/World/Parent/Child_{i}", "Xform", translation=pos, orientation=quat, stage=stage) + + # Create view + view = XformPrimView("/World/Parent/Child_.*", device=device) + + # Get local poses + translations, orientations = view.get_local_poses() + + # Verify shapes + assert translations.shape == (3, 3) + assert orientations.shape == (3, 4) + + # Convert expected values to tensors + expected_translations_tensor = torch.tensor(expected_local_positions, dtype=torch.float32, device=device) + expected_orientations_tensor = torch.tensor(expected_local_orientations, dtype=torch.float32, device=device) + + # Verify translations + torch.testing.assert_close(translations, expected_translations_tensor, atol=1e-5, rtol=0) + + # Verify orientations (allow for quaternion sign ambiguity) + try: + torch.testing.assert_close(orientations, expected_orientations_tensor, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(orientations, -expected_orientations_tensor, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_get_scales(device): + """Test getting scales from XformPrimView.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims with different scales + expected_scales = [(1.0, 1.0, 1.0), (2.0, 2.0, 2.0), (1.0, 2.0, 3.0)] + + for i, scale in enumerate(expected_scales): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", scale=scale, stage=stage) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Get scales + scales = view.get_scales() + + # Verify shape and values + assert scales.shape == (3, 3) + expected_scales_tensor = torch.tensor(expected_scales, dtype=torch.float32, device=device) + torch.testing.assert_close(scales, expected_scales_tensor, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_get_visibility(device): + """Test getting visibility when all prims are visible.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims (default is visible) + num_prims = 5 + for i in range(num_prims): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", translation=(float(i), 0.0, 0.0), stage=stage) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Get visibility + visibility = view.get_visibility() + + # Verify shape and values + assert visibility.shape == (num_prims,) + assert visibility.dtype == torch.bool + assert torch.all(visibility), "All prims should be visible by default" + + +""" +Tests - Setters. +""" + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_set_world_poses(device): + """Test setting world poses in XformPrimView.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims + num_prims = 5 + for i in range(num_prims): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", translation=(0.0, 0.0, 0.0), stage=stage) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Set new world poses + new_positions = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]], device=device + ) + new_orientations = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.7071068, 0.0, 0.0, 0.7071068], + [0.7071068, 0.7071068, 0.0, 0.0], + [0.9238795, 0.3826834, 0.0, 0.0], + [0.7071068, 0.0, 0.7071068, 0.0], + ], + device=device, + ) + + view.set_world_poses(new_positions, new_orientations) + + # Get the poses back + retrieved_positions, retrieved_orientations = view.get_world_poses() + + # Verify they match + torch.testing.assert_close(retrieved_positions, new_positions, atol=1e-5, rtol=0) + # Check quaternions (allow sign flip) + try: + torch.testing.assert_close(retrieved_orientations, new_orientations, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(retrieved_orientations, -new_orientations, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_set_world_poses_only_positions(device): + """Test setting only positions, leaving orientations unchanged.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims with specific orientations + initial_quat = (0.7071068, 0.0, 0.0, 0.7071068) # 90 deg around Z + for i in range(3): + sim_utils.create_prim( + f"/World/Object_{i}", "Xform", translation=(0.0, 0.0, 0.0), orientation=initial_quat, stage=stage + ) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Get initial orientations + _, initial_orientations = view.get_world_poses() + + # Set only positions + new_positions = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]], device=device) + view.set_world_poses(positions=new_positions, orientations=None) + + # Get poses back + retrieved_positions, retrieved_orientations = view.get_world_poses() + + # Positions should be updated + torch.testing.assert_close(retrieved_positions, new_positions, atol=1e-5, rtol=0) + + # Orientations should be unchanged + try: + torch.testing.assert_close(retrieved_orientations, initial_orientations, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(retrieved_orientations, -initial_orientations, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_set_world_poses_only_orientations(device): + """Test setting only orientations, leaving positions unchanged.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims with specific positions + for i in range(3): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", translation=(float(i), 0.0, 0.0), stage=stage) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Get initial positions + initial_positions, _ = view.get_world_poses() + + # Set only orientations + new_orientations = torch.tensor( + [[0.7071068, 0.0, 0.0, 0.7071068], [0.7071068, 0.7071068, 0.0, 0.0], [0.9238795, 0.3826834, 0.0, 0.0]], + device=device, + ) + view.set_world_poses(positions=None, orientations=new_orientations) + + # Get poses back + retrieved_positions, retrieved_orientations = view.get_world_poses() + + # Positions should be unchanged + torch.testing.assert_close(retrieved_positions, initial_positions, atol=1e-5, rtol=0) + + # Orientations should be updated + try: + torch.testing.assert_close(retrieved_orientations, new_orientations, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(retrieved_orientations, -new_orientations, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_set_world_poses_with_hierarchy(device): + """Test setting world poses correctly handles parent transformations.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create parent prims + for i in range(3): + parent_pos = (i * 10.0, 0.0, 0.0) + parent_quat = (0.7071068, 0.0, 0.0, 0.7071068) # 90 deg around Z + sim_utils.create_prim( + f"/World/Parent_{i}", "Xform", translation=parent_pos, orientation=parent_quat, stage=stage + ) + # Create child prims + sim_utils.create_prim(f"/World/Parent_{i}/Child", "Xform", translation=(0.0, 0.0, 0.0), stage=stage) + + # Create view for children + view = XformPrimView("/World/Parent_.*/Child", device=device) + + # Set world poses for children + desired_world_positions = torch.tensor([[5.0, 5.0, 0.0], [15.0, 5.0, 0.0], [25.0, 5.0, 0.0]], device=device) + desired_world_orientations = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], device=device + ) + + view.set_world_poses(desired_world_positions, desired_world_orientations) + + # Get world poses back + retrieved_positions, retrieved_orientations = view.get_world_poses() + + # Should match desired world poses + torch.testing.assert_close(retrieved_positions, desired_world_positions, atol=1e-4, rtol=0) + try: + torch.testing.assert_close(retrieved_orientations, desired_world_orientations, atol=1e-4, rtol=0) + except AssertionError: + torch.testing.assert_close(retrieved_orientations, -desired_world_orientations, atol=1e-4, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_set_local_poses(device): + """Test setting local poses in XformPrimView.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create parent + sim_utils.create_prim("/World/Parent", "Xform", translation=(5.0, 5.0, 5.0), stage=stage) + + # Create children + num_prims = 4 + for i in range(num_prims): + sim_utils.create_prim(f"/World/Parent/Child_{i}", "Xform", translation=(0.0, 0.0, 0.0), stage=stage) + + # Create view + view = XformPrimView("/World/Parent/Child_.*", device=device) + + # Set new local poses + new_translations = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0], [4.0, 4.0, 4.0]], device=device) + new_orientations = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.7071068, 0.0, 0.0, 0.7071068], + [0.7071068, 0.7071068, 0.0, 0.0], + [0.9238795, 0.3826834, 0.0, 0.0], + ], + device=device, + ) + + view.set_local_poses(new_translations, new_orientations) + + # Get local poses back + retrieved_translations, retrieved_orientations = view.get_local_poses() + + # Verify they match + torch.testing.assert_close(retrieved_translations, new_translations, atol=1e-5, rtol=0) + try: + torch.testing.assert_close(retrieved_orientations, new_orientations, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(retrieved_orientations, -new_orientations, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_set_local_poses_only_translations(device): + """Test setting only local translations.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create parent and children with specific orientations + sim_utils.create_prim("/World/Parent", "Xform", translation=(0.0, 0.0, 0.0), stage=stage) + initial_quat = (0.7071068, 0.0, 0.0, 0.7071068) + + for i in range(3): + sim_utils.create_prim( + f"/World/Parent/Child_{i}", "Xform", translation=(0.0, 0.0, 0.0), orientation=initial_quat, stage=stage + ) + + # Create view + view = XformPrimView("/World/Parent/Child_.*", device=device) + + # Get initial orientations + _, initial_orientations = view.get_local_poses() + + # Set only translations + new_translations = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]], device=device) + view.set_local_poses(translations=new_translations, orientations=None) + + # Get poses back + retrieved_translations, retrieved_orientations = view.get_local_poses() + + # Translations should be updated + torch.testing.assert_close(retrieved_translations, new_translations, atol=1e-5, rtol=0) + + # Orientations should be unchanged + try: + torch.testing.assert_close(retrieved_orientations, initial_orientations, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(retrieved_orientations, -initial_orientations, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_set_scales(device): + """Test setting scales in XformPrimView.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims + num_prims = 5 + for i in range(num_prims): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", scale=(1.0, 1.0, 1.0), stage=stage) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Set new scales + new_scales = torch.tensor( + [[2.0, 2.0, 2.0], [1.0, 2.0, 3.0], [0.5, 0.5, 0.5], [3.0, 1.0, 2.0], [1.5, 1.5, 1.5]], device=device + ) + + view.set_scales(new_scales) + + # Get scales back + retrieved_scales = view.get_scales() + + # Verify they match + torch.testing.assert_close(retrieved_scales, new_scales, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_set_visibility(device): + """Test toggling visibility multiple times.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims + num_prims = 3 + for i in range(num_prims): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", stage=stage) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Initial state: all visible + visibility = view.get_visibility() + assert torch.all(visibility), "All should be visible initially" + + # Make all invisible + view.set_visibility(torch.zeros(num_prims, dtype=torch.bool, device=device)) + visibility = view.get_visibility() + assert not torch.any(visibility), "All should be invisible" + + # Make all visible again + view.set_visibility(torch.ones(num_prims, dtype=torch.bool, device=device)) + visibility = view.get_visibility() + assert torch.all(visibility), "All should be visible again" + + # Toggle individual prims + view.set_visibility(torch.tensor([False], dtype=torch.bool, device=device), indices=[1]) + visibility = view.get_visibility() + assert visibility[0] and not visibility[1] and visibility[2], "Only middle prim should be invisible" + + +""" +Tests - Index Handling. +""" + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("index_type", ["list", "torch_tensor", "slice_none"]) +@pytest.mark.parametrize("method", ["world_poses", "local_poses", "scales", "visibility"]) +def test_index_types_get_methods(device, index_type, method): + """Test that getter methods work with different index types.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims based on method type + num_prims = 10 + if method == "local_poses": + # Create parent and children for local poses + sim_utils.create_prim("/World/Parent", "Xform", translation=(10.0, 0.0, 0.0), stage=stage) + for i in range(num_prims): + sim_utils.create_prim( + f"/World/Parent/Child_{i}", "Xform", translation=(float(i), float(i) * 0.5, 0.0), stage=stage + ) + view = XformPrimView("/World/Parent/Child_.*", device=device) + elif method == "scales": + # Create prims with different scales + for i in range(num_prims): + scale = (1.0 + i * 0.5, 1.0 + i * 0.3, 1.0 + i * 0.2) + sim_utils.create_prim(f"/World/Object_{i}", "Xform", scale=scale, stage=stage) + view = XformPrimView("/World/Object_.*", device=device) + else: # world_poses + # Create prims with different positions + for i in range(num_prims): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", translation=(float(i), 0.0, 0.0), stage=stage) + view = XformPrimView("/World/Object_.*", device=device) + + # Get all data as reference + if method == "world_poses": + all_data1, all_data2 = view.get_world_poses() + elif method == "local_poses": + all_data1, all_data2 = view.get_local_poses() + elif method == "scales": + all_data1 = view.get_scales() + all_data2 = None + else: # visibility + all_data1 = view.get_visibility() + all_data2 = None + + # Prepare indices + target_indices_base = [2, 5, 7] + indices, target_indices = _prepare_indices(index_type, target_indices_base, num_prims, device) + + # Get subset + if method == "world_poses": + subset_data1, subset_data2 = view.get_world_poses(indices=indices) # type: ignore[arg-type] + elif method == "local_poses": + subset_data1, subset_data2 = view.get_local_poses(indices=indices) # type: ignore[arg-type] + elif method == "scales": + subset_data1 = view.get_scales(indices=indices) # type: ignore[arg-type] + subset_data2 = None + else: # visibility + subset_data1 = view.get_visibility(indices=indices) # type: ignore[arg-type] + subset_data2 = None + + # Verify shapes + expected_count = len(target_indices) + if method == "visibility": + assert subset_data1.shape == (expected_count,) + else: + assert subset_data1.shape == (expected_count, 3) + if subset_data2 is not None: + assert subset_data2.shape == (expected_count, 4) + + # Verify values + target_indices_tensor = torch.tensor(target_indices, dtype=torch.int64, device=device) + torch.testing.assert_close(subset_data1, all_data1[target_indices_tensor], atol=1e-5, rtol=0) + if subset_data2 is not None and all_data2 is not None: + torch.testing.assert_close(subset_data2, all_data2[target_indices_tensor], atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("index_type", ["list", "torch_tensor", "slice_none"]) +@pytest.mark.parametrize("method", ["world_poses", "local_poses", "scales", "visibility"]) +def test_index_types_set_methods(device, index_type, method): + """Test that setter methods work with different index types.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims based on method type + num_prims = 10 + if method == "local_poses": + # Create parent and children for local poses + sim_utils.create_prim("/World/Parent", "Xform", translation=(5.0, 5.0, 0.0), stage=stage) + for i in range(num_prims): + sim_utils.create_prim(f"/World/Parent/Child_{i}", "Xform", translation=(float(i), 0.0, 0.0), stage=stage) + view = XformPrimView("/World/Parent/Child_.*", device=device) + else: # world_poses or scales + for i in range(num_prims): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", translation=(0.0, 0.0, 0.0), stage=stage) + view = XformPrimView("/World/Object_.*", device=device) + + # Get initial data + if method == "world_poses": + initial_data1, initial_data2 = view.get_world_poses() + elif method == "local_poses": + initial_data1, initial_data2 = view.get_local_poses() + elif method == "scales": + initial_data1 = view.get_scales() + initial_data2 = None + else: # visibility + initial_data1 = view.get_visibility() + initial_data2 = None + + # Prepare indices + target_indices_base = [2, 5, 7] + indices, target_indices = _prepare_indices(index_type, target_indices_base, num_prims, device) + + # Prepare new data + num_to_set = len(target_indices) + if method in ["world_poses", "local_poses"]: + new_data1 = torch.randn(num_to_set, 3, device=device) * 10.0 + new_data2 = torch.tensor([[1.0, 0.0, 0.0, 0.0]] * num_to_set, dtype=torch.float32, device=device) + elif method == "scales": + new_data1 = torch.rand(num_to_set, 3, device=device) * 2.0 + 0.5 + new_data2 = None + else: # visibility + # Set to False to test change (default is True) + new_data1 = torch.zeros(num_to_set, dtype=torch.bool, device=device) + new_data2 = None + + # Set data + if method == "world_poses": + view.set_world_poses(positions=new_data1, orientations=new_data2, indices=indices) # type: ignore[arg-type] + elif method == "local_poses": + view.set_local_poses(translations=new_data1, orientations=new_data2, indices=indices) # type: ignore[arg-type] + elif method == "scales": + view.set_scales(scales=new_data1, indices=indices) # type: ignore[arg-type] + else: # visibility + view.set_visibility(visibility=new_data1, indices=indices) # type: ignore[arg-type] + + # Get all data after update + if method == "world_poses": + updated_data1, updated_data2 = view.get_world_poses() + elif method == "local_poses": + updated_data1, updated_data2 = view.get_local_poses() + elif method == "scales": + updated_data1 = view.get_scales() + updated_data2 = None + else: # visibility + updated_data1 = view.get_visibility() + updated_data2 = None + + # Verify that specified indices were updated + for i, target_idx in enumerate(target_indices): + torch.testing.assert_close(updated_data1[target_idx], new_data1[i], atol=1e-5, rtol=0) + if new_data2 is not None and updated_data2 is not None: + try: + torch.testing.assert_close(updated_data2[target_idx], new_data2[i], atol=1e-5, rtol=0) + except AssertionError: + # Account for quaternion sign ambiguity + torch.testing.assert_close(updated_data2[target_idx], -new_data2[i], atol=1e-5, rtol=0) + + # Verify that other indices were NOT updated (only for non-slice(None) cases) + if index_type != "slice_none": + for i in range(num_prims): + if i not in target_indices: + torch.testing.assert_close(updated_data1[i], initial_data1[i], atol=1e-5, rtol=0) + if initial_data2 is not None and updated_data2 is not None: + try: + torch.testing.assert_close(updated_data2[i], initial_data2[i], atol=1e-5, rtol=0) + except AssertionError: + # Account for quaternion sign ambiguity + torch.testing.assert_close(updated_data2[i], -initial_data2[i], atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_indices_single_element(device): + """Test with a single index.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims + num_prims = 5 + for i in range(num_prims): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", translation=(float(i), 0.0, 0.0), stage=stage) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Test with single index + indices = [3] + positions, orientations = view.get_world_poses(indices=indices) + + # Verify shapes + assert positions.shape == (1, 3) + assert orientations.shape == (1, 4) + + # Set pose for single index + new_position = torch.tensor([[100.0, 200.0, 300.0]], device=device) + view.set_world_poses(positions=new_position, indices=indices) + + # Verify it was set + retrieved_positions, _ = view.get_world_poses(indices=indices) + torch.testing.assert_close(retrieved_positions, new_position, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_indices_out_of_order(device): + """Test with indices provided in non-sequential order.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims + num_prims = 10 + for i in range(num_prims): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", translation=(0.0, 0.0, 0.0), stage=stage) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Use out-of-order indices + indices = [7, 2, 9, 0, 5] + new_positions = torch.tensor( + [[7.0, 0.0, 0.0], [2.0, 0.0, 0.0], [9.0, 0.0, 0.0], [0.0, 0.0, 0.0], [5.0, 0.0, 0.0]], device=device + ) + + # Set poses with out-of-order indices + view.set_world_poses(positions=new_positions, indices=indices) + + # Get all poses + all_positions, _ = view.get_world_poses() + + # Verify each index got the correct value + expected_x_values = [0.0, 0.0, 2.0, 0.0, 0.0, 5.0, 0.0, 7.0, 0.0, 9.0] + for i in range(num_prims): + assert abs(all_positions[i, 0].item() - expected_x_values[i]) < 1e-5 + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_indices_with_only_positions_or_orientations(device): + """Test indices work correctly when setting only positions or only orientations.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims + num_prims = 5 + for i in range(num_prims): + sim_utils.create_prim( + f"/World/Object_{i}", "Xform", translation=(0.0, 0.0, 0.0), orientation=(1.0, 0.0, 0.0, 0.0), stage=stage + ) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Get initial poses + initial_positions, initial_orientations = view.get_world_poses() + + # Set only positions for specific indices + indices = [1, 3] + new_positions = torch.tensor([[10.0, 0.0, 0.0], [30.0, 0.0, 0.0]], device=device) + view.set_world_poses(positions=new_positions, orientations=None, indices=indices) + + # Get updated poses + updated_positions, updated_orientations = view.get_world_poses() + + # Verify positions updated for indices 1 and 3, others unchanged + torch.testing.assert_close(updated_positions[1], new_positions[0], atol=1e-5, rtol=0) + torch.testing.assert_close(updated_positions[3], new_positions[1], atol=1e-5, rtol=0) + torch.testing.assert_close(updated_positions[0], initial_positions[0], atol=1e-5, rtol=0) + + # Verify all orientations unchanged + try: + torch.testing.assert_close(updated_orientations, initial_orientations, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(updated_orientations, -initial_orientations, atol=1e-5, rtol=0) + + # Now set only orientations for different indices + indices2 = [0, 4] + new_orientations = torch.tensor([[0.7071068, 0.0, 0.0, 0.7071068], [0.7071068, 0.7071068, 0.0, 0.0]], device=device) + view.set_world_poses(positions=None, orientations=new_orientations, indices=indices2) + + # Get final poses + final_positions, final_orientations = view.get_world_poses() + + # Verify positions unchanged from previous step + torch.testing.assert_close(final_positions, updated_positions, atol=1e-5, rtol=0) + + # Verify orientations updated for indices 0 and 4 + try: + torch.testing.assert_close(final_orientations[0], new_orientations[0], atol=1e-5, rtol=0) + torch.testing.assert_close(final_orientations[4], new_orientations[1], atol=1e-5, rtol=0) + except AssertionError: + # Account for quaternion sign ambiguity + torch.testing.assert_close(final_orientations[0], -new_orientations[0], atol=1e-5, rtol=0) + torch.testing.assert_close(final_orientations[4], -new_orientations[1], atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_index_type_none_equivalent_to_all(device): + """Test that indices=None is equivalent to getting/setting all prims.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create prims + num_prims = 6 + for i in range(num_prims): + sim_utils.create_prim(f"/World/Object_{i}", "Xform", translation=(float(i), 0.0, 0.0), stage=stage) + + # Create view + view = XformPrimView("/World/Object_.*", device=device) + + # Get poses with indices=None + pos_none, quat_none = view.get_world_poses(indices=None) + + # Get poses with no argument (default) + pos_default, quat_default = view.get_world_poses() + + # Get poses with slice(None) + pos_slice, quat_slice = view.get_world_poses(indices=slice(None)) # type: ignore[arg-type] + + # All should be equivalent + torch.testing.assert_close(pos_none, pos_default, atol=1e-10, rtol=0) + torch.testing.assert_close(quat_none, quat_default, atol=1e-10, rtol=0) + torch.testing.assert_close(pos_none, pos_slice, atol=1e-10, rtol=0) + torch.testing.assert_close(quat_none, quat_slice, atol=1e-10, rtol=0) + + # Test the same for set operations + new_positions = torch.randn(num_prims, 3, device=device) * 10.0 + new_orientations = torch.tensor([[1.0, 0.0, 0.0, 0.0]] * num_prims, dtype=torch.float32, device=device) + + # Set with indices=None + view.set_world_poses(positions=new_positions, orientations=new_orientations, indices=None) + pos_after_none, quat_after_none = view.get_world_poses() + + # Reset + view.set_world_poses(positions=torch.zeros(num_prims, 3, device=device), indices=None) + + # Set with slice(None) + view.set_world_poses(positions=new_positions, orientations=new_orientations, indices=slice(None)) # type: ignore[arg-type] + pos_after_slice, quat_after_slice = view.get_world_poses() + + # Should be equivalent + torch.testing.assert_close(pos_after_none, pos_after_slice, atol=1e-5, rtol=0) + torch.testing.assert_close(quat_after_none, quat_after_slice, atol=1e-5, rtol=0) + + +""" +Tests - Integration. +""" + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_with_franka_robots(device): + """Test XformPrimView with real Franka robot USD assets.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Load Franka robot assets + franka_usd_path = f"{ISAAC_NUCLEUS_DIR}/Robots/FrankaRobotics/FrankaPanda/franka.usd" + + # Add two Franka robots to the stage + sim_utils.create_prim("/World/Franka_1", "Xform", usd_path=franka_usd_path, stage=stage) + sim_utils.create_prim("/World/Franka_2", "Xform", usd_path=franka_usd_path, stage=stage) + + # Create view for both Frankas + frankas_view = XformPrimView("/World/Franka_.*", device=device) + + # Verify count + assert frankas_view.count == 2 + + # Get initial world poses (should be at origin) + initial_positions, initial_orientations = frankas_view.get_world_poses() + + # Verify initial positions are at origin + expected_initial_positions = torch.zeros(2, 3, device=device) + torch.testing.assert_close(initial_positions, expected_initial_positions, atol=1e-5, rtol=0) + + # Verify initial orientations are identity + expected_initial_orientations = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], device=device) + try: + torch.testing.assert_close(initial_orientations, expected_initial_orientations, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(initial_orientations, -expected_initial_orientations, atol=1e-5, rtol=0) + + # Set new world poses + new_positions = torch.tensor([[10.0, 10.0, 0.0], [-40.0, -40.0, 0.0]], device=device) + # 90° rotation around Z axis for first, -90° for second + new_orientations = torch.tensor( + [[0.7071068, 0.0, 0.0, 0.7071068], [0.7071068, 0.0, 0.0, -0.7071068]], device=device + ) + + frankas_view.set_world_poses(positions=new_positions, orientations=new_orientations) + + # Get poses back and verify + retrieved_positions, retrieved_orientations = frankas_view.get_world_poses() + + torch.testing.assert_close(retrieved_positions, new_positions, atol=1e-5, rtol=0) + try: + torch.testing.assert_close(retrieved_orientations, new_orientations, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(retrieved_orientations, -new_orientations, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_with_nested_targets(device): + """Test with nested frame/target structure similar to Isaac Sim tests.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create frames and targets + for i in range(1, 4): + sim_utils.create_prim(f"/World/Frame_{i}", "Xform", stage=stage) + sim_utils.create_prim(f"/World/Frame_{i}/Target", "Xform", stage=stage) + + # Create views + frames_view = XformPrimView("/World/Frame_.*", device=device) + targets_view = XformPrimView("/World/Frame_.*/Target", device=device) + + assert frames_view.count == 3 + assert targets_view.count == 3 + + # Set local poses for frames + frame_translations = torch.tensor([[0.0, 0.0, 0.0], [0.0, 10.0, 5.0], [0.0, 3.0, 5.0]], device=device) + frames_view.set_local_poses(translations=frame_translations) + + # Set local poses for targets + target_translations = torch.tensor([[0.0, 20.0, 10.0], [0.0, 30.0, 20.0], [0.0, 50.0, 10.0]], device=device) + targets_view.set_local_poses(translations=target_translations) + + # Get world poses of targets + world_positions, _ = targets_view.get_world_poses() + + # Expected world positions are frame_translation + target_translation + expected_positions = torch.tensor([[0.0, 20.0, 10.0], [0.0, 40.0, 25.0], [0.0, 53.0, 15.0]], device=device) + + torch.testing.assert_close(world_positions, expected_positions, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_visibility_with_hierarchy(device): + """Test visibility with parent-child hierarchy and inheritance.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + stage = sim_utils.get_current_stage() + + # Create parent and children + sim_utils.create_prim("/World/Parent", "Xform", stage=stage) + + num_children = 4 + for i in range(num_children): + sim_utils.create_prim(f"/World/Parent/Child_{i}", "Xform", stage=stage) + + # Create views for both parent and children + parent_view = XformPrimView("/World/Parent", device=device) + children_view = XformPrimView("/World/Parent/Child_.*", device=device) + + # Verify parent and all children are visible initially + parent_visibility = parent_view.get_visibility() + children_visibility = children_view.get_visibility() + assert parent_visibility[0], "Parent should be visible initially" + assert torch.all(children_visibility), "All children should be visible initially" + + # Make some children invisible directly + new_visibility = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) + children_view.set_visibility(new_visibility) + + # Verify the visibility changes + retrieved_visibility = children_view.get_visibility() + torch.testing.assert_close(retrieved_visibility, new_visibility) + + # Make all children visible again + children_view.set_visibility(torch.ones(num_children, dtype=torch.bool, device=device)) + all_visible = children_view.get_visibility() + assert torch.all(all_visible), "All children should be visible again" + + # Now test parent visibility inheritance: + # Make parent invisible + parent_view.set_visibility(torch.tensor([False], dtype=torch.bool, device=device)) + + # Verify parent is invisible + parent_visibility = parent_view.get_visibility() + assert not parent_visibility[0], "Parent should be invisible" + + # Verify children are also invisible (due to parent being invisible) + children_visibility = children_view.get_visibility() + assert not torch.any(children_visibility), "All children should be invisible when parent is invisible" + + # Make parent visible again + parent_view.set_visibility(torch.tensor([True], dtype=torch.bool, device=device)) + + # Verify parent is visible + parent_visibility = parent_view.get_visibility() + assert parent_visibility[0], "Parent should be visible again" + + # Verify children are also visible again + children_visibility = children_view.get_visibility() + assert torch.all(children_visibility), "All children should be visible again when parent is visible" + + +""" +Tests - Comparison with Isaac Sim Implementation. +""" + + +def test_compare_get_world_poses_with_isaacsim(): + """Compare get_world_poses with Isaac Sim's implementation.""" + stage = sim_utils.get_current_stage() + + # Check if Isaac Sim is available + if _IsaacSimXformPrimView is None: + pytest.skip("Isaac Sim is not available") + + # Create prims with various poses + num_prims = 10 + for i in range(num_prims): + pos = (i * 2.0, i * 0.5, i * 1.5) + # Vary orientations + if i % 3 == 0: + quat = (1.0, 0.0, 0.0, 0.0) # Identity + elif i % 3 == 1: + quat = (0.7071068, 0.0, 0.0, 0.7071068) # 90 deg around Z + else: + quat = (0.7071068, 0.7071068, 0.0, 0.0) # 90 deg around X + sim_utils.create_prim(f"/World/Env_{i}/Object", "Xform", translation=pos, orientation=quat, stage=stage) + + pattern = "/World/Env_.*/Object" + + # Create both views + isaaclab_view = XformPrimView(pattern, device="cpu") + isaacsim_view = _IsaacSimXformPrimView(pattern, reset_xform_properties=False) + + # Get world poses from both + isaaclab_pos, isaaclab_quat = isaaclab_view.get_world_poses() + isaacsim_pos, isaacsim_quat = isaacsim_view.get_world_poses() + + # Convert Isaac Sim results to torch tensors if needed + if not isinstance(isaacsim_pos, torch.Tensor): + isaacsim_pos = torch.tensor(isaacsim_pos, dtype=torch.float32) + if not isinstance(isaacsim_quat, torch.Tensor): + isaacsim_quat = torch.tensor(isaacsim_quat, dtype=torch.float32) + + # Compare results + torch.testing.assert_close(isaaclab_pos, isaacsim_pos, atol=1e-5, rtol=0) + + # Compare quaternions (account for sign ambiguity) + try: + torch.testing.assert_close(isaaclab_quat, isaacsim_quat, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(isaaclab_quat, -isaacsim_quat, atol=1e-5, rtol=0) + + +def test_compare_set_world_poses_with_isaacsim(): + """Compare set_world_poses with Isaac Sim's implementation.""" + stage = sim_utils.get_current_stage() + + # Check if Isaac Sim is available + if _IsaacSimXformPrimView is None: + pytest.skip("Isaac Sim is not available") + + # Create prims + num_prims = 8 + for i in range(num_prims): + sim_utils.create_prim(f"/World/Env_{i}/Object", "Xform", translation=(0.0, 0.0, 0.0), stage=stage) + + pattern = "/World/Env_.*/Object" + + # Create both views + isaaclab_view = XformPrimView(pattern, device="cpu") + isaacsim_view = _IsaacSimXformPrimView(pattern, reset_xform_properties=False) + + # Generate new poses + new_positions = torch.randn(num_prims, 3) * 10.0 + new_orientations = torch.tensor([[1.0, 0.0, 0.0, 0.0]] * num_prims, dtype=torch.float32) + + # Set poses using both implementations + isaaclab_view.set_world_poses(new_positions.clone(), new_orientations.clone()) + isaacsim_view.set_world_poses(new_positions.clone(), new_orientations.clone()) + + # Get poses back from both + isaaclab_pos, isaaclab_quat = isaaclab_view.get_world_poses() + isaacsim_pos, isaacsim_quat = isaacsim_view.get_world_poses() + + # Convert Isaac Sim results to torch tensors if needed + if not isinstance(isaacsim_pos, torch.Tensor): + isaacsim_pos = torch.tensor(isaacsim_pos, dtype=torch.float32) + if not isinstance(isaacsim_quat, torch.Tensor): + isaacsim_quat = torch.tensor(isaacsim_quat, dtype=torch.float32) + + # Compare results - both implementations should produce the same world poses + torch.testing.assert_close(isaaclab_pos, isaacsim_pos, atol=1e-4, rtol=0) + try: + torch.testing.assert_close(isaaclab_quat, isaacsim_quat, atol=1e-4, rtol=0) + except AssertionError: + torch.testing.assert_close(isaaclab_quat, -isaacsim_quat, atol=1e-4, rtol=0) + + +def test_compare_get_local_poses_with_isaacsim(): + """Compare get_local_poses with Isaac Sim's implementation.""" + stage = sim_utils.get_current_stage() + + # Check if Isaac Sim is available + if _IsaacSimXformPrimView is None: + pytest.skip("Isaac Sim is not available") + + # Create hierarchical prims + num_prims = 5 + for i in range(num_prims): + # Create parent + sim_utils.create_prim(f"/World/Env_{i}", "Xform", translation=(i * 5.0, 0.0, 0.0), stage=stage) + # Create child with local pose + local_pos = (1.0, float(i), 0.0) + local_quat = (1.0, 0.0, 0.0, 0.0) if i % 2 == 0 else (0.7071068, 0.0, 0.0, 0.7071068) + sim_utils.create_prim( + f"/World/Env_{i}/Object", "Xform", translation=local_pos, orientation=local_quat, stage=stage + ) + + pattern = "/World/Env_.*/Object" + + # Create both views + isaaclab_view = XformPrimView(pattern, device="cpu") + isaacsim_view = _IsaacSimXformPrimView(pattern, reset_xform_properties=False) + + # Get local poses from both + isaaclab_trans, isaaclab_quat = isaaclab_view.get_local_poses() + isaacsim_trans, isaacsim_quat = isaacsim_view.get_local_poses() + + # Convert Isaac Sim results to torch tensors if needed + if not isinstance(isaacsim_trans, torch.Tensor): + isaacsim_trans = torch.tensor(isaacsim_trans, dtype=torch.float32) + if not isinstance(isaacsim_quat, torch.Tensor): + isaacsim_quat = torch.tensor(isaacsim_quat, dtype=torch.float32) + + # Compare results + torch.testing.assert_close(isaaclab_trans, isaacsim_trans, atol=1e-5, rtol=0) + try: + torch.testing.assert_close(isaaclab_quat, isaacsim_quat, atol=1e-5, rtol=0) + except AssertionError: + torch.testing.assert_close(isaaclab_quat, -isaacsim_quat, atol=1e-5, rtol=0) + + +def test_compare_set_local_poses_with_isaacsim(): + """Compare set_local_poses with Isaac Sim's implementation.""" + stage = sim_utils.get_current_stage() + + # Check if Isaac Sim is available + if _IsaacSimXformPrimView is None: + pytest.skip("Isaac Sim is not available") + + # Create hierarchical prims + num_prims = 6 + for i in range(num_prims): + sim_utils.create_prim(f"/World/Env_{i}", "Xform", translation=(i * 3.0, 0.0, 0.0), stage=stage) + sim_utils.create_prim(f"/World/Env_{i}/Object", "Xform", translation=(0.0, 0.0, 0.0), stage=stage) + + pattern = "/World/Env_.*/Object" + + # Create both views + isaaclab_view = XformPrimView(pattern, device="cpu") + isaacsim_view = _IsaacSimXformPrimView(pattern, reset_xform_properties=False) + + # Generate new local poses + new_translations = torch.randn(num_prims, 3) * 5.0 + new_orientations = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], [0.7071068, 0.0, 0.0, 0.7071068]] * (num_prims // 2), dtype=torch.float32 + ) + + # Set local poses using both implementations + isaaclab_view.set_local_poses(new_translations.clone(), new_orientations.clone()) + isaacsim_view.set_local_poses(new_translations.clone(), new_orientations.clone()) + + # Get local poses back from both + isaaclab_trans, isaaclab_quat = isaaclab_view.get_local_poses() + isaacsim_trans, isaacsim_quat = isaacsim_view.get_local_poses() + + # Convert Isaac Sim results to torch tensors if needed + if not isinstance(isaacsim_trans, torch.Tensor): + isaacsim_trans = torch.tensor(isaacsim_trans, dtype=torch.float32) + if not isinstance(isaacsim_quat, torch.Tensor): + isaacsim_quat = torch.tensor(isaacsim_quat, dtype=torch.float32) + + # Compare results + torch.testing.assert_close(isaaclab_trans, isaacsim_trans, atol=1e-4, rtol=0) + try: + torch.testing.assert_close(isaaclab_quat, isaacsim_quat, atol=1e-4, rtol=0) + except AssertionError: + torch.testing.assert_close(isaaclab_quat, -isaacsim_quat, atol=1e-4, rtol=0)