diff --git a/docs/source/api/lab/isaaclab.sim.utils.rst b/docs/source/api/lab/isaaclab.sim.utils.rst index 9d59df77bcc..f27e574efb9 100644 --- a/docs/source/api/lab/isaaclab.sim.utils.rst +++ b/docs/source/api/lab/isaaclab.sim.utils.rst @@ -10,6 +10,7 @@ stage queries prims + transforms semantics legacy @@ -34,6 +35,13 @@ Prims :members: :show-inheritance: +Transforms +---------- + +.. automodule:: isaaclab.sim.utils.transforms + :members: + :show-inheritance: + Semantics --------- diff --git a/source/isaaclab/config/extension.toml b/source/isaaclab/config/extension.toml index 66246612534..81047f7cadb 100644 --- a/source/isaaclab/config/extension.toml +++ b/source/isaaclab/config/extension.toml @@ -1,7 +1,7 @@ [package] # Note: Semantic Versioning is used: https://semver.org/ -version = "0.51.1" +version = "0.52.0" # Description title = "Isaac Lab framework for Robot Learning" diff --git a/source/isaaclab/docs/CHANGELOG.rst b/source/isaaclab/docs/CHANGELOG.rst index ae136334a93..09aa6d708e7 100644 --- a/source/isaaclab/docs/CHANGELOG.rst +++ b/source/isaaclab/docs/CHANGELOG.rst @@ -1,6 +1,23 @@ Changelog --------- +0.52.0 (2026-01-02) +~~~~~~~~~~~~~~~~~~~ + +Added +^^^^^ + +* Added :mod:`~isaaclab.sim.utils.transforms` module to handle USD Xform operations. +* Added passing of ``stage`` to the :func:`~isaaclab.sim.utils.prims.create_prim` function + inside spawning functions to allow for the creation of prims in a specific stage. + +Changed +^^^^^^^ + +* Changed :func:`~isaaclab.sim.utils.prims.create_prim` function to use the :mod:`~isaaclab.sim.utils.transforms` + module for USD Xform operations. It removes the usage of Isaac Sim's XformPrim class. + + 0.51.2 (2025-12-30) ~~~~~~~~~~~~~~~~~~~ diff --git a/source/isaaclab/isaaclab/sim/spawners/from_files/from_files.py b/source/isaaclab/isaaclab/sim/spawners/from_files/from_files.py index d2583d8c9e3..17d69d7da11 100644 --- a/source/isaaclab/isaaclab/sim/spawners/from_files/from_files.py +++ b/source/isaaclab/isaaclab/sim/spawners/from_files/from_files.py @@ -198,7 +198,7 @@ def spawn_ground_plane( # Spawn Ground-plane if not stage.GetPrimAtPath(prim_path).IsValid(): - create_prim(prim_path, usd_path=cfg.usd_path, translation=translation, orientation=orientation) + create_prim(prim_path, usd_path=cfg.usd_path, translation=translation, orientation=orientation, stage=stage) else: raise ValueError(f"A prim already exists at path: '{prim_path}'.") @@ -327,6 +327,7 @@ def _spawn_from_usd_file( translation=translation, orientation=orientation, scale=cfg.scale, + stage=stage, ) else: logger.warning(f"A prim already exists at prim path: '{prim_path}'.") diff --git a/source/isaaclab/isaaclab/sim/spawners/lights/lights.py b/source/isaaclab/isaaclab/sim/spawners/lights/lights.py index 51a57c61ba4..9b0106c6ecd 100644 --- a/source/isaaclab/isaaclab/sim/spawners/lights/lights.py +++ b/source/isaaclab/isaaclab/sim/spawners/lights/lights.py @@ -50,7 +50,9 @@ def spawn_light( if stage.GetPrimAtPath(prim_path).IsValid(): raise ValueError(f"A prim already exists at path: '{prim_path}'.") # create the prim - prim = create_prim(prim_path, prim_type=cfg.prim_type, translation=translation, orientation=orientation) + prim = create_prim( + prim_path, prim_type=cfg.prim_type, translation=translation, orientation=orientation, stage=stage + ) # convert to dict cfg = cfg.to_dict() diff --git a/source/isaaclab/isaaclab/sim/spawners/shapes/shapes.py b/source/isaaclab/isaaclab/sim/spawners/shapes/shapes.py index 6ae610c1bc2..b701b904ea2 100644 --- a/source/isaaclab/isaaclab/sim/spawners/shapes/shapes.py +++ b/source/isaaclab/isaaclab/sim/spawners/shapes/shapes.py @@ -282,7 +282,7 @@ def _spawn_geom_from_prim_type( # spawn geometry if it doesn't exist. if not stage.GetPrimAtPath(prim_path).IsValid(): - create_prim(prim_path, prim_type="Xform", translation=translation, orientation=orientation) + create_prim(prim_path, prim_type="Xform", translation=translation, orientation=orientation, stage=stage) else: raise ValueError(f"A prim already exists at path: '{prim_path}'.") @@ -291,7 +291,7 @@ def _spawn_geom_from_prim_type( mesh_prim_path = geom_prim_path + "/mesh" # create the geometry prim - create_prim(mesh_prim_path, prim_type, scale=scale, attributes=attributes) + create_prim(mesh_prim_path, prim_type, scale=scale, attributes=attributes, stage=stage) # apply collision properties if cfg.collision_props is not None: schemas.define_collision_properties(mesh_prim_path, cfg.collision_props) diff --git a/source/isaaclab/isaaclab/sim/spawners/wrappers/wrappers.py b/source/isaaclab/isaaclab/sim/spawners/wrappers/wrappers.py index 2e5141096f1..64d0c4f4ab9 100644 --- a/source/isaaclab/isaaclab/sim/spawners/wrappers/wrappers.py +++ b/source/isaaclab/isaaclab/sim/spawners/wrappers/wrappers.py @@ -67,7 +67,7 @@ def spawn_multi_asset( # find a free prim path to hold all the template prims template_prim_path = sim_utils.get_next_free_prim_path("/World/Template", stage=stage) - sim_utils.create_prim(template_prim_path, "Scope") + sim_utils.create_prim(template_prim_path, "Scope", stage=stage) # spawn everything first in a "Dataset" prim proto_prim_paths = list() @@ -116,7 +116,7 @@ def spawn_multi_asset( Sdf.CopySpec(env_spec.layer, Sdf.Path(proto_path), env_spec.layer, Sdf.Path(prim_path)) # delete the dataset prim after spawning - sim_utils.delete_prim(template_prim_path) + sim_utils.delete_prim(template_prim_path, stage=stage) # set carb setting to indicate Isaac Lab's environments that different prims have been spawned # at varying prim paths. In this case, PhysX parser shouldn't optimize the stage parsing. diff --git a/source/isaaclab/isaaclab/sim/utils/__init__.py b/source/isaaclab/isaaclab/sim/utils/__init__.py index 3402602fcf5..3a85ae44c2f 100644 --- a/source/isaaclab/isaaclab/sim/utils/__init__.py +++ b/source/isaaclab/isaaclab/sim/utils/__init__.py @@ -10,3 +10,4 @@ from .queries import * # noqa: F401, F403 from .semantics import * # noqa: F401, F403 from .stage import * # noqa: F401, F403 +from .transforms import * # noqa: F401, F403 diff --git a/source/isaaclab/isaaclab/sim/utils/prims.py b/source/isaaclab/isaaclab/sim/utils/prims.py index 4f0e640b81f..020aa5d4d6f 100644 --- a/source/isaaclab/isaaclab/sim/utils/prims.py +++ b/source/isaaclab/isaaclab/sim/utils/prims.py @@ -11,6 +11,7 @@ import inspect import logging import re +import torch from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any @@ -20,7 +21,7 @@ import usdrt # noqa: F401 from isaacsim.core.cloner import Cloner from omni.usd.commands import DeletePrimsCommand, MovePrimCommand -from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics, UsdShade +from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics, UsdShade, UsdUtils from isaaclab.utils.string import to_camel_case from isaaclab.utils.version import get_isaac_sim_version @@ -28,6 +29,7 @@ from .queries import find_matching_prim_paths from .semantics import add_labels from .stage import attach_stage_to_usd_context, get_current_stage, get_current_stage_id +from .transforms import convert_world_pose_to_local, standardize_xform_ops if TYPE_CHECKING: from isaaclab.sim.spawners.spawner_cfg import SpawnerCfg @@ -44,10 +46,10 @@ def create_prim( prim_path: str, prim_type: str = "Xform", - position: Sequence[float] | None = None, - translation: Sequence[float] | None = None, - orientation: Sequence[float] | None = None, - scale: Sequence[float] | None = None, + position: Any | None = None, + translation: Any | None = None, + orientation: Any | None = None, + scale: Any | None = None, usd_path: str | None = None, semantic_label: str | None = None, semantic_type: str = "class", @@ -57,40 +59,85 @@ def create_prim( """Creates a prim in the provided USD stage. The method applies the specified transforms, the semantic label and sets the specified attributes. + The transform can be specified either in world space (using ``position``) or local space (using + ``translation``). + + The function determines the coordinate system of the transform based on the provided arguments. + + * If ``position`` is provided, it is assumed the orientation is provided in the world frame as well. + * If ``translation`` is provided, it is assumed the orientation is provided in the local frame as well. + + The scale is always applied in the local frame. + + The function handles various sequence types (list, tuple, numpy array, torch tensor) + and converts them to properly-typed tuples for operations on the prim. + + .. note:: + Transform operations are standardized to the USD convention: translate, orient (quaternion), + and scale, in that order. See :func:`standardize_xform_ops` for more details. Args: - prim_path: The path of the new prim. - prim_type: Prim type name - position: prim position (applied last) - translation: prim translation (applied last) - orientation: prim rotation as quaternion - scale: scaling factor in x, y, z. - usd_path: Path to the USD that this prim will reference. - semantic_label: Semantic label. - semantic_type: set to "class" unless otherwise specified. - attributes: Key-value pairs of prim attributes to set. - stage: The stage to create the prim in. Defaults to None, in which case the current stage is used. + prim_path: + The path of the new prim. + prim_type: + Prim type name. Defaults to "Xform", in which case a simple Xform prim is created. + position: + Prim position in world space as (x, y, z). If the prim has a parent, this is + automatically converted to local space relative to the parent. Cannot be used with + ``translation``. Defaults to None, in which case no position is applied. + translation: + Prim translation in local space as (x, y, z). This is applied directly without + any coordinate transformation. Cannot be used with ``position``. Defaults to None, + in which case no translation is applied. + orientation: + Prim rotation as a quaternion (w, x, y, z). When used with ``position``, the + orientation is also converted from world space to local space. When used with ``translation``, + it is applied directly as local orientation. Defaults to None. + scale: + Scaling factor in x, y, z. Applied in local space. Defaults to None, + in which case a uniform scale of 1.0 is applied. + usd_path: + Path to the USD file that this prim will reference. Defaults to None. + semantic_label: + Semantic label to apply to the prim. Defaults to None, in which case no label is added. + semantic_type: + Semantic type for the label. Defaults to "class". + attributes: + Key-value pairs of prim attributes to set. Defaults to None, in which case no attributes are set. + stage: + The stage to create the prim in. Defaults to None, in which case the current stage is used. Returns: The created USD prim. Raises: - ValueError: If there is already a prim at the provided prim_path. + ValueError: If there is already a prim at the provided prim path. + ValueError: If both position and translation are provided. Example: >>> import isaaclab.sim as sim_utils >>> - >>> # create a cube (/World/Cube) of size 2 centered at (1.0, 0.5, 0.0) + >>> # Create a cube at world position (1.0, 0.5, 0.0) >>> sim_utils.create_prim( - ... prim_path="/World/Cube", + ... prim_path="/World/Parent/Cube", ... prim_type="Cube", ... position=(1.0, 0.5, 0.0), ... attributes={"size": 2.0} ... ) - Usd.Prim() + Usd.Prim() + >>> + >>> # Create a sphere with local translation relative to its parent + >>> sim_utils.create_prim( + ... prim_path="/World/Parent/Sphere", + ... prim_type="Sphere", + ... translation=(0.5, 0.0, 0.0), + ... scale=(2.0, 2.0, 2.0) + ... ) + Usd.Prim() """ - # Note: Imported here to prevent cyclic dependency in the module. - from isaacsim.core.prims import XFormPrim + # Ensure that user doesn't provide both position and translation + if position is not None and translation is not None: + raise ValueError("Cannot provide both position and translation. Please provide only one.") # obtain stage handle stage = get_current_stage() if stage is None else stage @@ -114,26 +161,20 @@ def create_prim( if semantic_label is not None: add_labels(prim, labels=[semantic_label], instance_name=semantic_type) - # apply the transformations - from isaacsim.core.api.simulation_context.simulation_context import SimulationContext - - if SimulationContext.instance() is None: - # FIXME: remove this, we should never even use backend utils especially not numpy ones - import isaacsim.core.utils.numpy as backend_utils + # convert input arguments to tuples + position = _to_tuple(position) if position is not None else None + translation = _to_tuple(translation) if translation is not None else None + orientation = _to_tuple(orientation) if orientation is not None else None + scale = _to_tuple(scale) if scale is not None else None - device = "cpu" - else: - backend_utils = SimulationContext.instance().backend_utils - device = SimulationContext.instance().device + # convert position and orientation to translation and orientation + # world --> local if position is not None: - position = backend_utils.expand_dims(backend_utils.convert(position, device), 0) - if translation is not None: - translation = backend_utils.expand_dims(backend_utils.convert(translation, device), 0) - if orientation is not None: - orientation = backend_utils.expand_dims(backend_utils.convert(orientation, device), 0) - if scale is not None: - scale = backend_utils.expand_dims(backend_utils.convert(scale, device), 0) - XFormPrim(prim_path, positions=position, translations=translation, orientations=orientation, scales=scale) + # this means that user provided pose in the world frame + translation, orientation = convert_world_pose_to_local(position, orientation, ref_prim=prim.GetParent()) + + # standardize the xform ops + standardize_xform_ops(prim, translation, orientation, scale) return prim @@ -156,6 +197,12 @@ def delete_prim(prim_path: str | Sequence[str], stage: Usd.Stage | None = None) prim_path = [prim_path] # get stage handle stage = get_current_stage() if stage is None else stage + # the prim command looks for the stage ID in the stage cache + # so we need to ensure the stage is cached + stage_cache = UsdUtils.StageCache.Get() + stage_id = stage_cache.GetId(stage).ToLongInt() + if stage_id < 0: + stage_id = stage_cache.Insert(stage).ToLongInt() # delete prims DeletePrimsCommand(prim_path, stage=stage).do() @@ -240,97 +287,6 @@ def make_uninstanceable(prim_path: str | Sdf.Path, stage: Usd.Stage | None = Non all_prims += child_prim.GetFilteredChildren(Usd.TraverseInstanceProxies()) -def resolve_prim_pose( - prim: Usd.Prim, ref_prim: Usd.Prim | None = None -) -> tuple[tuple[float, float, float], tuple[float, float, float, float]]: - """Resolve the pose of a prim with respect to another prim. - - Note: - This function ignores scale and skew by orthonormalizing the transformation - matrix at the final step. However, if any ancestor prim in the hierarchy - has non-uniform scale, that scale will still affect the resulting position - and orientation of the prim (because it's baked into the transform before - scale removal). - - In other words: scale **is not removed hierarchically**. If you need - completely scale-free poses, you must walk the transform chain and strip - scale at each level. Please open an issue if you need this functionality. - - Args: - prim: The USD prim to resolve the pose for. - ref_prim: The USD prim to compute the pose with respect to. - Defaults to None, in which case the world frame is used. - - Returns: - A tuple containing the position (as a 3D vector) and the quaternion orientation - in the (w, x, y, z) format. - - Raises: - ValueError: If the prim or ref prim is not valid. - """ - # check if prim is valid - if not prim.IsValid(): - raise ValueError(f"Prim at path '{prim.GetPath().pathString}' is not valid.") - # get prim xform - xform = UsdGeom.Xformable(prim) - prim_tf = xform.ComputeLocalToWorldTransform(Usd.TimeCode.Default()) - # sanitize quaternion - # this is needed, otherwise the quaternion might be non-normalized - prim_tf = prim_tf.GetOrthonormalized() - - if ref_prim is not None: - # check if ref prim is valid - if not ref_prim.IsValid(): - raise ValueError(f"Ref prim at path '{ref_prim.GetPath().pathString}' is not valid.") - # get ref prim xform - ref_xform = UsdGeom.Xformable(ref_prim) - ref_tf = ref_xform.ComputeLocalToWorldTransform(Usd.TimeCode.Default()) - # make sure ref tf is orthonormal - ref_tf = ref_tf.GetOrthonormalized() - # compute relative transform to get prim in ref frame - prim_tf = prim_tf * ref_tf.GetInverse() - - # extract position and orientation - prim_pos = [*prim_tf.ExtractTranslation()] - prim_quat = [prim_tf.ExtractRotationQuat().real, *prim_tf.ExtractRotationQuat().imaginary] - return tuple(prim_pos), tuple(prim_quat) - - -def resolve_prim_scale(prim: Usd.Prim) -> tuple[float, float, float]: - """Resolve the scale of a prim in the world frame. - - At an attribute level, a USD prim's scale is a scaling transformation applied to the prim with - respect to its parent prim. This function resolves the scale of the prim in the world frame, - by computing the local to world transform of the prim. This is equivalent to traversing up - the prim hierarchy and accounting for the rotations and scales of the prims. - - For instance, if a prim has a scale of (1, 2, 3) and it is a child of a prim with a scale of (4, 5, 6), - then the scale of the prim in the world frame is (4, 10, 18). - - Args: - prim: The USD prim to resolve the scale for. - - Returns: - The scale of the prim in the x, y, and z directions in the world frame. - - Raises: - ValueError: If the prim is not valid. - """ - # check if prim is valid - if not prim.IsValid(): - raise ValueError(f"Prim at path '{prim.GetPath().pathString}' is not valid.") - # compute local to world transform - xform = UsdGeom.Xformable(prim) - world_transform = xform.ComputeLocalToWorldTransform(Usd.TimeCode.Default()) - # extract scale - return tuple([*(v.GetLength() for v in world_transform.ExtractRotationMatrix())]) - - -""" -Attribute - Setters. -""" - - def set_prim_visibility(prim: Usd.Prim, visible: bool) -> None: """Sets the visibility of the prim in the opened stage. @@ -921,6 +877,9 @@ def get_usd_references(prim_path: str, stage: Usd.Stage | None = None) -> list[s Returns: A list of USD reference paths. + + Raises: + ValueError: If the prim at the specified path is not valid. """ # get stage handle stage = get_current_stage() if stage is None else stage @@ -931,7 +890,8 @@ def get_usd_references(prim_path: str, stage: Usd.Stage | None = None) -> list[s # get USD references references = [] for prim_spec in prim.GetPrimStack(): - references.extend(prim_spec.referenceList.prependedItems.assetPath) + for ref in prim_spec.referenceList.prependedItems: + references.append(str(ref.assetPath)) return references @@ -1008,3 +968,59 @@ class TableVariants: f"Setting variant selection '{variant_selection}' for variant set '{variant_set_name}' on" f" prim '{prim_path}'." ) + + +""" +Internal Helpers. +""" + + +def _to_tuple(value: Any) -> tuple[float, ...]: + """Convert various sequence types to a Python tuple of floats. + + This function provides robust conversion from different array-like types (list, tuple, numpy array, + torch tensor) to Python tuples. It handles edge cases like malformed sequences, CUDA tensors, + and arrays with singleton dimensions. + + Args: + value: A sequence-like object containing floats. Supported types include: + - Python list or tuple + - NumPy array (any device) + - PyTorch tensor (CPU or CUDA) + - Mixed sequences with numpy/torch scalar items and float values + + Returns: + A one-dimensional tuple of floats. + + Raises: + ValueError: If the input value is not one-dimensional after squeezing singleton dimensions. + + Example: + >>> import torch + >>> import numpy as np + >>> + >>> _to_tuple([1.0, 2.0, 3.0]) + (1.0, 2.0, 3.0) + >>> _to_tuple(torch.tensor([[1.0, 2.0]])) # Squeezes first dimension + (1.0, 2.0) + >>> _to_tuple(np.array([1.0, 2.0, 3.0])) + (1.0, 2.0, 3.0) + >>> _to_tuple((1.0, 2.0, 3.0)) + (1.0, 2.0, 3.0) + + """ + # Normalize to tensor if value is a plain sequence (list with mixed types, etc.) + # This handles cases like [np.float32(1.0), 2.0, torch.tensor(3.0)] + if not hasattr(value, "tolist"): + value = torch.tensor(value, device="cpu", dtype=torch.float) + + # Remove leading singleton dimension if present (e.g., shape (1, 3) -> (3,)) + # This is common when batched operations produce single-item batches + if value.ndim != 1: + value = value.squeeze() + # Validate that the result is one-dimensional + if value.ndim != 1: + raise ValueError(f"Input value is not one dimensional: {value.shape}") + + # Convert to tuple - works for both numpy arrays and torch tensors + return tuple(value.tolist()) diff --git a/source/isaaclab/isaaclab/sim/utils/transforms.py b/source/isaaclab/isaaclab/sim/utils/transforms.py new file mode 100644 index 00000000000..b7da662fd77 --- /dev/null +++ b/source/isaaclab/isaaclab/sim/utils/transforms.py @@ -0,0 +1,450 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Utilities for working with USD transform (xform) operations. + +This module provides utilities for manipulating USD transform operations (xform ops) on prims. +Transform operations in USD define how geometry is positioned, oriented, and scaled in 3D space. + +The utilities in this module help standardize transform stacks, clear operations, and manipulate +transforms in a consistent way across different USD assets. +""" + +from __future__ import annotations + +import logging + +from pxr import Gf, Sdf, Usd, UsdGeom + +# import logger +logger = logging.getLogger(__name__) + +_INVALID_XFORM_OPS = [ + "xformOp:rotateX", + "xformOp:rotateXZY", + "xformOp:rotateY", + "xformOp:rotateYXZ", + "xformOp:rotateYZX", + "xformOp:rotateZ", + "xformOp:rotateZYX", + "xformOp:rotateZXY", + "xformOp:rotateXYZ", + "xformOp:transform", +] +"""List of invalid xform ops that should be removed.""" + + +def standardize_xform_ops( + prim: Usd.Prim, + translation: tuple[float, ...] | None = None, + orientation: tuple[float, ...] | None = None, + scale: tuple[float, ...] | None = None, +) -> bool: + """Standardize the transform operation stack on a USD prim to a canonical form. + + This function converts a prim's transform stack to use the standard USD transform operation + order: [translate, orient, scale]. The function performs the following operations: + + 1. Validates that the prim is Xformable + 2. Captures the current local transform (translation, rotation, scale) + 3. Resolves and bakes unit scale conversions (xformOp:scale:unitsResolve) + 4. Creates or reuses standard transform operations (translate, orient, scale) + 5. Sets the transform operation order to [translate, orient, scale] + 6. Applies the preserved or user-specified transform values + + The entire modification is performed within an ``Sdf.ChangeBlock`` for optimal performance + when processing multiple prims. + + .. note:: + **Standard Transform Order:** The function enforces the USD best practice order: + ``xformOp:translate``, ``xformOp:orient``, ``xformOp:scale``. This order is + compatible with most USD tools and workflows, and uses quaternions for rotation + (avoiding gimbal lock issues). + + .. note:: + **Pose Preservation:** By default, the function preserves the prim's local transform + (relative to its parent). The world-space position of the prim remains unchanged + unless explicit ``translation``, ``orientation``, or ``scale`` values are provided. + + .. warning:: + **Animation Data Loss:** This function only preserves transform values at the default + time code (``Usd.TimeCode.Default()``). Any animation or time-sampled transform data + will be lost. Use this function during asset import or preparation, not on animated prims. + + .. warning:: + **Unit Scale Resolution:** If the prim has a ``xformOp:scale:unitsResolve`` attribute + (common in imported assets with unit mismatches), it will be baked into the scale + and removed. For example, a scale of (1, 1, 1) with unitsResolve of (100, 100, 100) + becomes a final scale of (100, 100, 100). + + Args: + prim: The USD prim to standardize. Must be a valid prim that supports the + UsdGeom.Xformable schema (e.g., Xform, Mesh, Cube, etc.). Material and + Shader prims are not Xformable and will return False. + translation: Optional translation vector (x, y, z) in local space. If provided, + overrides the prim's current translation. If None, preserves the current + local translation. Defaults to None. + orientation: Optional orientation quaternion (w, x, y, z) in local space. If provided, + overrides the prim's current orientation. If None, preserves the current + local orientation. Defaults to None. + scale: Optional scale vector (x, y, z). If provided, overrides the prim's current scale. + If None, preserves the current scale (after unit resolution) or uses (1, 1, 1) + if no scale exists. Defaults to None. + + Returns: + bool: True if the transform operations were successfully standardized. False if the + prim is not Xformable (e.g., Material, Shader prims). The function will log an + error message when returning False. + + Raises: + ValueError: If the prim is not valid (i.e., does not exist or is an invalid prim). + + Example: + >>> import isaaclab.sim as sim_utils + >>> + >>> # Standardize a prim with non-standard transform operations + >>> prim = stage.GetPrimAtPath("/World/ImportedAsset") + >>> result = sim_utils.standardize_xform_ops(prim) + >>> if result: + ... print("Transform stack standardized successfully") + >>> # The prim now uses: [translate, orient, scale] in that order + >>> + >>> # Standardize and set new transform values + >>> sim_utils.standardize_xform_ops( + ... prim, + ... translation=(1.0, 2.0, 3.0), + ... orientation=(1.0, 0.0, 0.0, 0.0), # identity rotation (w, x, y, z) + ... scale=(2.0, 2.0, 2.0) + ... ) + >>> + >>> # Batch processing for performance + >>> prims_to_standardize = [stage.GetPrimAtPath(p) for p in prim_paths] + >>> for prim in prims_to_standardize: + ... sim_utils.standardize_xform_ops(prim) # Each call uses Sdf.ChangeBlock + """ + # Validate prim + if not prim.IsValid(): + raise ValueError(f"Prim at path '{prim.GetPath()}' is not valid.") + + # Check if prim is an Xformable + if not prim.IsA(UsdGeom.Xformable): + logger.error(f"Prim at path '{prim.GetPath()}' is not an Xformable.") + return False + + # Create xformable interface + xformable = UsdGeom.Xformable(prim) + # Get current property names + prop_names = prim.GetPropertyNames() + + # Obtain current local transformations + tf = Gf.Transform(xformable.GetLocalTransformation()) + xform_pos = Gf.Vec3d(tf.GetTranslation()) + xform_quat = Gf.Quatd(tf.GetRotation().GetQuat()) + xform_scale = Gf.Vec3d(tf.GetScale()) + + if translation is not None: + xform_pos = Gf.Vec3d(*translation) + if orientation is not None: + xform_quat = Gf.Quatd(*orientation) + + # Handle scale resolution + if scale is not None: + # User provided scale + xform_scale = Gf.Vec3d(scale) + elif "xformOp:scale" in prop_names: + # Handle unit resolution for scale if present + # This occurs when assets are imported with different unit scales + # Reference: Omniverse Metrics Assembler + if "xformOp:scale:unitsResolve" in prop_names: + units_resolve = prim.GetAttribute("xformOp:scale:unitsResolve").Get() + for i in range(3): + xform_scale[i] = xform_scale[i] * units_resolve[i] + else: + # No scale exists, use default uniform scale + xform_scale = Gf.Vec3d(1.0, 1.0, 1.0) + + # Verify if xform stack is reset + has_reset = xformable.GetResetXformStack() + # Batch the operations + with Sdf.ChangeBlock(): + # Clear the existing transform operation order + for prop_name in prop_names: + if prop_name in _INVALID_XFORM_OPS: + prim.RemoveProperty(prop_name) + + # Remove unitsResolve attribute if present (already handled in scale resolution above) + if "xformOp:scale:unitsResolve" in prop_names: + prim.RemoveProperty("xformOp:scale:unitsResolve") + + # Set up or retrieve scale operation + xform_op_scale = UsdGeom.XformOp(prim.GetAttribute("xformOp:scale")) + if not xform_op_scale: + xform_op_scale = xformable.AddXformOp(UsdGeom.XformOp.TypeScale, UsdGeom.XformOp.PrecisionDouble, "") + + # Set up or retrieve translate operation + xform_op_translate = UsdGeom.XformOp(prim.GetAttribute("xformOp:translate")) + if not xform_op_translate: + xform_op_translate = xformable.AddXformOp( + UsdGeom.XformOp.TypeTranslate, UsdGeom.XformOp.PrecisionDouble, "" + ) + + # Set up or retrieve orient (quaternion rotation) operation + xform_op_orient = UsdGeom.XformOp(prim.GetAttribute("xformOp:orient")) + if not xform_op_orient: + xform_op_orient = xformable.AddXformOp(UsdGeom.XformOp.TypeOrient, UsdGeom.XformOp.PrecisionDouble, "") + + # Handle different floating point precisions + # Existing Xform operations might have floating or double precision. + # We need to cast the data to the correct type to avoid setting the wrong type. + xform_ops = [xform_op_translate, xform_op_orient, xform_op_scale] + xform_values = [xform_pos, xform_quat, xform_scale] + for xform_op, value in zip(xform_ops, xform_values): + # Get current value to determine precision type + current_value = xform_op.Get() + # Cast to existing type to preserve precision (float/double) + xform_op.Set(type(current_value)(value) if current_value is not None else value) + + # Set the transform operation order: translate -> orient -> scale + # This is the standard USD convention and ensures consistent behavior + xformable.SetXformOpOrder([xform_op_translate, xform_op_orient, xform_op_scale], has_reset) + + return True + + +def validate_standard_xform_ops(prim: Usd.Prim) -> bool: + """Validate if the transform operations on a prim are standardized. + + This function checks if the transform operations on a prim are standardized to the canonical form: + [translate, orient, scale]. + + Args: + prim: The USD prim to validate. + """ + # check if prim is valid + if not prim.IsValid(): + logger.error(f"Prim at path '{prim.GetPath().pathString}' is not valid.") + return False + # check if prim is an xformable + if not prim.IsA(UsdGeom.Xformable): + logger.error(f"Prim at path '{prim.GetPath().pathString}' is not an xformable.") + return False + # get the xformable interface + xformable = UsdGeom.Xformable(prim) + # get the xform operation order + xform_op_order = xformable.GetOrderedXformOps() + xform_op_order = [op.GetOpName() for op in xform_op_order] + # check if the xform operation order is the canonical form + if xform_op_order != ["xformOp:translate", "xformOp:orient", "xformOp:scale"]: + msg = f"Xform operation order for prim at path '{prim.GetPath().pathString}' is not the canonical form." + msg += f" Received order: {xform_op_order}" + msg += " Expected order: ['xformOp:translate', 'xformOp:orient', 'xformOp:scale']" + logger.error(msg) + return False + return True + + +def resolve_prim_pose( + prim: Usd.Prim, ref_prim: Usd.Prim | None = None +) -> tuple[tuple[float, float, float], tuple[float, float, float, float]]: + """Resolve the pose of a prim with respect to another prim. + + Note: + This function ignores scale and skew by orthonormalizing the transformation + matrix at the final step. However, if any ancestor prim in the hierarchy + has non-uniform scale, that scale will still affect the resulting position + and orientation of the prim (because it's baked into the transform before + scale removal). + + In other words: scale **is not removed hierarchically**. If you need + completely scale-free poses, you must walk the transform chain and strip + scale at each level. Please open an issue if you need this functionality. + + Args: + prim: The USD prim to resolve the pose for. + ref_prim: The USD prim to compute the pose with respect to. + Defaults to None, in which case the world frame is used. + + Returns: + A tuple containing the position (as a 3D vector) and the quaternion orientation + in the (w, x, y, z) format. + + Raises: + ValueError: If the prim or ref prim is not valid. + + Example: + >>> import isaaclab.sim as sim_utils + >>> from pxr import Usd, UsdGeom + >>> + >>> # Get prim + >>> stage = sim_utils.get_current_stage() + >>> prim = stage.GetPrimAtPath("/World/ImportedAsset") + >>> + >>> # Resolve pose + >>> pos, quat = sim_utils.resolve_prim_pose(prim) + >>> print(f"Position: {pos}") + >>> print(f"Orientation: {quat}") + >>> + >>> # Resolve pose with respect to another prim + >>> ref_prim = stage.GetPrimAtPath("/World/Reference") + >>> pos, quat = sim_utils.resolve_prim_pose(prim, ref_prim) + >>> print(f"Position: {pos}") + >>> print(f"Orientation: {quat}") + """ + # check if prim is valid + if not prim.IsValid(): + raise ValueError(f"Prim at path '{prim.GetPath().pathString}' is not valid.") + # get prim xform + xform = UsdGeom.Xformable(prim) + prim_tf = xform.ComputeLocalToWorldTransform(Usd.TimeCode.Default()) + # sanitize quaternion + # this is needed, otherwise the quaternion might be non-normalized + prim_tf.Orthonormalize() + + if ref_prim is not None: + # if reference prim is the root, we can skip the computation + if ref_prim.GetPath() != Sdf.Path.absoluteRootPath: + # get ref prim xform + ref_xform = UsdGeom.Xformable(ref_prim) + ref_tf = ref_xform.ComputeLocalToWorldTransform(Usd.TimeCode.Default()) + # make sure ref tf is orthonormal + ref_tf.Orthonormalize() + # compute relative transform to get prim in ref frame + prim_tf = prim_tf * ref_tf.GetInverse() + + # extract position and orientation + prim_pos = [*prim_tf.ExtractTranslation()] + prim_quat = [prim_tf.ExtractRotationQuat().real, *prim_tf.ExtractRotationQuat().imaginary] + return tuple(prim_pos), tuple(prim_quat) + + +def resolve_prim_scale(prim: Usd.Prim) -> tuple[float, float, float]: + """Resolve the scale of a prim in the world frame. + + At an attribute level, a USD prim's scale is a scaling transformation applied to the prim with + respect to its parent prim. This function resolves the scale of the prim in the world frame, + by computing the local to world transform of the prim. This is equivalent to traversing up + the prim hierarchy and accounting for the rotations and scales of the prims. + + For instance, if a prim has a scale of (1, 2, 3) and it is a child of a prim with a scale of (4, 5, 6), + then the scale of the prim in the world frame is (4, 10, 18). + + Args: + prim: The USD prim to resolve the scale for. + + Returns: + The scale of the prim in the x, y, and z directions in the world frame. + + Raises: + ValueError: If the prim is not valid. + + Example: + >>> import isaaclab.sim as sim_utils + >>> from pxr import Usd, UsdGeom + >>> + >>> # Get prim + >>> stage = sim_utils.get_current_stage() + >>> prim = stage.GetPrimAtPath("/World/ImportedAsset") + >>> + >>> # Resolve scale + >>> scale = sim_utils.resolve_prim_scale(prim) + >>> print(f"Scale: {scale}") + """ + # check if prim is valid + if not prim.IsValid(): + raise ValueError(f"Prim at path '{prim.GetPath().pathString}' is not valid.") + # compute local to world transform + xform = UsdGeom.Xformable(prim) + world_transform = xform.ComputeLocalToWorldTransform(Usd.TimeCode.Default()) + # extract scale + return tuple([*(v.GetLength() for v in world_transform.ExtractRotationMatrix())]) + + +def convert_world_pose_to_local( + position: tuple[float, ...], + orientation: tuple[float, ...] | None, + ref_prim: Usd.Prim, +) -> tuple[tuple[float, float, float], tuple[float, float, float, float] | None]: + """Convert a world-space pose to local-space pose relative to a reference prim. + + This function takes a position and orientation in world space and converts them to local space + relative to the given reference prim. This is useful when creating or positioning prims where you + know the desired world position but need to set local transform attributes relative to another prim. + + The conversion uses the standard USD transformation math: + ``local_transform = world_transform * inverse(ref_world_transform)`` + + .. note:: + If the reference prim is the root prim ("/"), the position and orientation are returned + unchanged, as they are already effectively in local/world space. + + Args: + position: The world-space position as (x, y, z). + orientation: The world-space orientation as quaternion (w, x, y, z). If None, only position is converted + and None is returned for orientation. + ref_prim: The reference USD prim to compute the local transform relative to. If this is + the root prim ("/"), the world pose is returned unchanged. + + Returns: + A tuple of (local_translation, local_orientation) where: + + - local_translation is a tuple of (x, y, z) in local space relative to ref_prim + - local_orientation is a tuple of (w, x, y, z) in local space relative to ref_prim, or None if no orientation was provided + + Raises: + ValueError: If the reference prim is not a valid USD prim. + + Example: + >>> import isaaclab.sim as sim_utils + >>> from pxr import Usd, UsdGeom + >>> + >>> # Get reference prim + >>> stage = sim_utils.get_current_stage() + >>> ref_prim = stage.GetPrimAtPath("/World/Reference") + >>> + >>> # Convert world pose to local (relative to ref_prim) + >>> world_pos = (10.0, 5.0, 0.0) + >>> world_quat = (1.0, 0.0, 0.0, 0.0) # identity rotation + >>> local_pos, local_quat = sim_utils.convert_world_pose_to_local( + ... world_pos, world_quat, ref_prim + ... ) + >>> print(f"Local position: {local_pos}") + >>> print(f"Local orientation: {local_quat}") + """ + # Check if prim is valid + if not ref_prim.IsValid(): + raise ValueError(f"Reference prim at path '{ref_prim.GetPath().pathString}' is not valid.") + + # If reference prim is the root, return world pose as-is + if ref_prim.GetPath() == Sdf.Path.absoluteRootPath: + return position, orientation # type: ignore + + # Check if reference prim is a valid xformable + ref_xformable = UsdGeom.Xformable(ref_prim) + # Get reference prim's world transform + ref_world_tf = ref_xformable.ComputeLocalToWorldTransform(Usd.TimeCode.Default()) + + # Create world transform for the desired position and orientation + desired_world_tf = Gf.Matrix4d() + desired_world_tf.SetTranslateOnly(Gf.Vec3d(*position)) + + if orientation is not None: + # Set rotation from quaternion (w, x, y, z) + quat = Gf.Quatd(*orientation) + desired_world_tf.SetRotateOnly(quat) + + # Convert world transform to local: local = world * inv(ref_world) + ref_world_tf_inv = ref_world_tf.GetInverse() + local_tf = desired_world_tf * ref_world_tf_inv + + # Extract local translation and orientation + local_transform = Gf.Transform(local_tf) + local_translation = tuple(local_transform.GetTranslation()) + + local_orientation = None + if orientation is not None: + quat_result = local_transform.GetRotation().GetQuat() + local_orientation = (quat_result.GetReal(), *quat_result.GetImaginary()) + + return local_translation, local_orientation diff --git a/source/isaaclab/test/sim/test_utils_prims.py b/source/isaaclab/test/sim/test_utils_prims.py index 0cc169081b2..0dcb71b55fc 100644 --- a/source/isaaclab/test/sim/test_utils_prims.py +++ b/source/isaaclab/test/sim/test_utils_prims.py @@ -21,7 +21,7 @@ from pxr import Gf, Sdf, Usd, UsdGeom import isaaclab.sim as sim_utils -import isaaclab.utils.math as math_utils +from isaaclab.sim.utils.prims import _to_tuple # type: ignore[reportPrivateUsage] from isaaclab.utils.assets import ISAACLAB_NUCLEUS_DIR @@ -118,6 +118,137 @@ def test_create_prim(): assert op_names == ["xformOp:translate", "xformOp:orient", "xformOp:scale"] +@pytest.mark.parametrize( + "input_type", + ["list", "tuple", "numpy", "torch_cpu", "torch_cuda"], + ids=["list", "tuple", "numpy", "torch_cpu", "torch_cuda"], +) +def test_create_prim_with_different_input_types(input_type: str): + """Test create_prim() with different input types (list, tuple, numpy array, torch tensor).""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Define test values + translation_vals = [1.0, 2.0, 3.0] + orientation_vals = [1.0, 0.0, 0.0, 0.0] # w, x, y, z + scale_vals = [2.0, 3.0, 4.0] + + # Convert to the specified input type + if input_type == "list": + translation = translation_vals + orientation = orientation_vals + scale = scale_vals + elif input_type == "tuple": + translation = tuple(translation_vals) + orientation = tuple(orientation_vals) + scale = tuple(scale_vals) + elif input_type == "numpy": + translation = np.array(translation_vals) + orientation = np.array(orientation_vals) + scale = np.array(scale_vals) + elif input_type == "torch_cpu": + translation = torch.tensor(translation_vals) + orientation = torch.tensor(orientation_vals) + scale = torch.tensor(scale_vals) + elif input_type == "torch_cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + translation = torch.tensor(translation_vals, device="cuda") + orientation = torch.tensor(orientation_vals, device="cuda") + scale = torch.tensor(scale_vals, device="cuda") + + # Create prim with translation (local space) + prim = sim_utils.create_prim( + f"/World/Test/Xform_{input_type}", + "Xform", + stage=stage, + translation=translation, + orientation=orientation, + scale=scale, + ) + + # Verify prim was created correctly + assert prim.IsValid() + assert prim.GetPrimPath() == f"/World/Test/Xform_{input_type}" + + # Verify transform values + assert prim.GetAttribute("xformOp:translate").Get() == Gf.Vec3d(*translation_vals) + assert_quat_close(prim.GetAttribute("xformOp:orient").Get(), Gf.Quatd(*orientation_vals)) + assert prim.GetAttribute("xformOp:scale").Get() == Gf.Vec3d(*scale_vals) + + # Verify xform operation order + op_names = [op.GetOpName() for op in UsdGeom.Xformable(prim).GetOrderedXformOps()] + assert op_names == ["xformOp:translate", "xformOp:orient", "xformOp:scale"] + + +@pytest.mark.parametrize( + "input_type", + ["list", "tuple", "numpy", "torch_cpu", "torch_cuda"], + ids=["list", "tuple", "numpy", "torch_cpu", "torch_cuda"], +) +def test_create_prim_with_world_position_different_types(input_type: str): + """Test create_prim() with world position using different input types.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a parent prim + _ = sim_utils.create_prim( + "/World/Parent", + "Xform", + stage=stage, + translation=(5.0, 10.0, 15.0), + orientation=(1.0, 0.0, 0.0, 0.0), + ) + + # Define world position and orientation values + world_pos_vals = [10.0, 20.0, 30.0] + world_orient_vals = [0.7071068, 0.0, 0.7071068, 0.0] # 90 deg around Y + + # Convert to the specified input type + if input_type == "list": + world_pos = world_pos_vals + world_orient = world_orient_vals + elif input_type == "tuple": + world_pos = tuple(world_pos_vals) + world_orient = tuple(world_orient_vals) + elif input_type == "numpy": + world_pos = np.array(world_pos_vals) + world_orient = np.array(world_orient_vals) + elif input_type == "torch_cpu": + world_pos = torch.tensor(world_pos_vals) + world_orient = torch.tensor(world_orient_vals) + elif input_type == "torch_cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + world_pos = torch.tensor(world_pos_vals, device="cuda") + world_orient = torch.tensor(world_orient_vals, device="cuda") + + # Create child prim with world position + child = sim_utils.create_prim( + f"/World/Parent/Child_{input_type}", + "Xform", + stage=stage, + position=world_pos, # Using position (world space) + orientation=world_orient, + ) + + # Verify prim was created + assert child.IsValid() + + # Verify world pose matches what we specified + world_pose = sim_utils.resolve_prim_pose(child) + pos_result, quat_result = world_pose + + # Check position (should be close to world_pos_vals) + for i in range(3): + assert math.isclose(pos_result[i], world_pos_vals[i], abs_tol=1e-4) + + # Check orientation (quaternions may have sign flipped) + quat_match = all(math.isclose(quat_result[i], world_orient_vals[i], abs_tol=1e-4) for i in range(4)) + quat_match_neg = all(math.isclose(quat_result[i], -world_orient_vals[i], abs_tol=1e-4) for i in range(4)) + assert quat_match or quat_match_neg + + def test_delete_prim(): """Test delete_prim() function.""" # obtain stage handle @@ -189,175 +320,32 @@ def test_move_prim(): """ -USD Prim properties and attributes. +USD references and variants. """ -def test_resolve_prim_pose(): - """Test resolve_prim_pose() function.""" - # number of objects - num_objects = 20 - # sample random scales for x, y, z - rand_scales = np.random.uniform(0.5, 1.5, size=(num_objects, 3, 3)) - rand_widths = np.random.uniform(0.1, 10.0, size=(num_objects,)) - # sample random positions - rand_positions = np.random.uniform(-100, 100, size=(num_objects, 3, 3)) - # sample random rotations - rand_quats = np.random.randn(num_objects, 3, 4) - rand_quats /= np.linalg.norm(rand_quats, axis=2, keepdims=True) - - # create objects - for i in range(num_objects): - # simple cubes - cube_prim = sim_utils.create_prim( - f"/World/Cubes/instance_{i:02d}", - "Cube", - translation=rand_positions[i, 0], - orientation=rand_quats[i, 0], - scale=rand_scales[i, 0], - attributes={"size": rand_widths[i]}, - ) - # xform hierarchy - xform_prim = sim_utils.create_prim( - f"/World/Xform/instance_{i:02d}", - "Xform", - translation=rand_positions[i, 1], - orientation=rand_quats[i, 1], - scale=rand_scales[i, 1], - ) - geometry_prim = sim_utils.create_prim( - f"/World/Xform/instance_{i:02d}/geometry", - "Sphere", - translation=rand_positions[i, 2], - orientation=rand_quats[i, 2], - scale=rand_scales[i, 2], - attributes={"radius": rand_widths[i]}, - ) - dummy_prim = sim_utils.create_prim( - f"/World/Xform/instance_{i:02d}/dummy", - "Sphere", - ) - - # cube prim w.r.t. world frame - pos, quat = sim_utils.resolve_prim_pose(cube_prim) - pos, quat = np.array(pos), np.array(quat) - quat = quat if np.sign(rand_quats[i, 0, 0]) == np.sign(quat[0]) else -quat - np.testing.assert_allclose(pos, rand_positions[i, 0], atol=1e-3) - np.testing.assert_allclose(quat, rand_quats[i, 0], atol=1e-3) - # xform prim w.r.t. world frame - pos, quat = sim_utils.resolve_prim_pose(xform_prim) - pos, quat = np.array(pos), np.array(quat) - quat = quat if np.sign(rand_quats[i, 1, 0]) == np.sign(quat[0]) else -quat - np.testing.assert_allclose(pos, rand_positions[i, 1], atol=1e-3) - np.testing.assert_allclose(quat, rand_quats[i, 1], atol=1e-3) - # dummy prim w.r.t. world frame - pos, quat = sim_utils.resolve_prim_pose(dummy_prim) - pos, quat = np.array(pos), np.array(quat) - quat = quat if np.sign(rand_quats[i, 1, 0]) == np.sign(quat[0]) else -quat - np.testing.assert_allclose(pos, rand_positions[i, 1], atol=1e-3) - np.testing.assert_allclose(quat, rand_quats[i, 1], atol=1e-3) - - # geometry prim w.r.t. xform prim - pos, quat = sim_utils.resolve_prim_pose(geometry_prim, ref_prim=xform_prim) - pos, quat = np.array(pos), np.array(quat) - quat = quat if np.sign(rand_quats[i, 2, 0]) == np.sign(quat[0]) else -quat - np.testing.assert_allclose(pos, rand_positions[i, 2] * rand_scales[i, 1], atol=1e-3) - # TODO: Enabling scale causes the test to fail because the current implementation of - # resolve_prim_pose does not correctly handle non-identity scales on Xform prims. This is a known - # limitation. Until this is fixed, the test is disabled here to ensure the test passes. - # np.testing.assert_allclose(quat, rand_quats[i, 2], atol=1e-3) - - # dummy prim w.r.t. xform prim - pos, quat = sim_utils.resolve_prim_pose(dummy_prim, ref_prim=xform_prim) - pos, quat = np.array(pos), np.array(quat) - np.testing.assert_allclose(pos, np.zeros(3), atol=1e-3) - np.testing.assert_allclose(quat, np.array([1, 0, 0, 0]), atol=1e-3) - # xform prim w.r.t. cube prim - pos, quat = sim_utils.resolve_prim_pose(xform_prim, ref_prim=cube_prim) - pos, quat = np.array(pos), np.array(quat) - # -- compute ground truth values - gt_pos, gt_quat = math_utils.subtract_frame_transforms( - torch.from_numpy(rand_positions[i, 0]).unsqueeze(0), - torch.from_numpy(rand_quats[i, 0]).unsqueeze(0), - torch.from_numpy(rand_positions[i, 1]).unsqueeze(0), - torch.from_numpy(rand_quats[i, 1]).unsqueeze(0), - ) - gt_pos, gt_quat = gt_pos.squeeze(0).numpy(), gt_quat.squeeze(0).numpy() - quat = quat if np.sign(gt_quat[0]) == np.sign(quat[0]) else -quat - np.testing.assert_allclose(pos, gt_pos, atol=1e-3) - np.testing.assert_allclose(quat, gt_quat, atol=1e-3) - - -def test_resolve_prim_scale(): - """Test resolve_prim_scale() function. - - To simplify the test, we assume that the effective scale at a prim - is the product of the scales of the prims in the hierarchy: - - scale = scale_of_xform * scale_of_geometry_prim - - This is only true when rotations are identity or the transforms are - orthogonal and uniformly scaled. Otherwise, scale is not composable - like that in local component-wise fashion. - """ - # number of objects - num_objects = 20 - # sample random scales for x, y, z - rand_scales = np.random.uniform(0.5, 1.5, size=(num_objects, 3, 3)) - rand_widths = np.random.uniform(0.1, 10.0, size=(num_objects,)) - # sample random positions - rand_positions = np.random.uniform(-100, 100, size=(num_objects, 3, 3)) - - # create objects - for i in range(num_objects): - # simple cubes - cube_prim = sim_utils.create_prim( - f"/World/Cubes/instance_{i:02d}", - "Cube", - translation=rand_positions[i, 0], - scale=rand_scales[i, 0], - attributes={"size": rand_widths[i]}, - ) - # xform hierarchy - xform_prim = sim_utils.create_prim( - f"/World/Xform/instance_{i:02d}", - "Xform", - translation=rand_positions[i, 1], - scale=rand_scales[i, 1], - ) - geometry_prim = sim_utils.create_prim( - f"/World/Xform/instance_{i:02d}/geometry", - "Sphere", - translation=rand_positions[i, 2], - scale=rand_scales[i, 2], - attributes={"radius": rand_widths[i]}, - ) - dummy_prim = sim_utils.create_prim( - f"/World/Xform/instance_{i:02d}/dummy", - "Sphere", - ) - - # cube prim - scale = sim_utils.resolve_prim_scale(cube_prim) - scale = np.array(scale) - np.testing.assert_allclose(scale, rand_scales[i, 0], atol=1e-5) - # xform prim - scale = sim_utils.resolve_prim_scale(xform_prim) - scale = np.array(scale) - np.testing.assert_allclose(scale, rand_scales[i, 1], atol=1e-5) - # geometry prim - scale = sim_utils.resolve_prim_scale(geometry_prim) - scale = np.array(scale) - np.testing.assert_allclose(scale, rand_scales[i, 1] * rand_scales[i, 2], atol=1e-5) - # dummy prim - scale = sim_utils.resolve_prim_scale(dummy_prim) - scale = np.array(scale) - np.testing.assert_allclose(scale, rand_scales[i, 1], atol=1e-5) +def test_get_usd_references(): + """Test get_usd_references() function.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + # Create a prim without USD reference + sim_utils.create_prim("/World/NoReference", "Xform", stage=stage) + # Check that it has no references + refs = sim_utils.get_usd_references("/World/NoReference", stage=stage) + assert len(refs) == 0 -""" -USD references and variants. -""" + # Create a prim with a USD reference + franka_usd = f"{ISAACLAB_NUCLEUS_DIR}/Robots/FrankaEmika/panda_instanceable.usd" + sim_utils.create_prim("/World/WithReference", usd_path=franka_usd, stage=stage) + # Check that it has the expected reference + refs = sim_utils.get_usd_references("/World/WithReference", stage=stage) + assert len(refs) == 1 + assert refs == [franka_usd] + + # Test with invalid prim path + with pytest.raises(ValueError, match="not valid"): + sim_utils.get_usd_references("/World/NonExistent", stage=stage) def test_select_usd_variants(): @@ -377,3 +365,78 @@ def test_select_usd_variants(): # Check if the variant selection is correct assert variant_set.GetVariantSelection() == "red" + + +""" +Internal Helpers. +""" + + +def test_to_tuple_basic(): + """Test _to_tuple() with basic input types.""" + # Test with list + result = _to_tuple([1.0, 2.0, 3.0]) + assert result == (1.0, 2.0, 3.0) + assert isinstance(result, tuple) + + # Test with tuple + result = _to_tuple((1.0, 2.0, 3.0)) + assert result == (1.0, 2.0, 3.0) + + # Test with numpy array + result = _to_tuple(np.array([1.0, 2.0, 3.0])) + assert result == (1.0, 2.0, 3.0) + + # Test with torch tensor (CPU) + result = _to_tuple(torch.tensor([1.0, 2.0, 3.0])) + assert result == (1.0, 2.0, 3.0) + + # Test squeezing first dimension (batch size 1) + result = _to_tuple(torch.tensor([[1.0, 2.0]])) + assert result == (1.0, 2.0) + + result = _to_tuple(np.array([[1.0, 2.0, 3.0]])) + assert result == (1.0, 2.0, 3.0) + + +def test_to_tuple_raises_error(): + """Test _to_tuple() raises an error for N-dimensional arrays.""" + + with pytest.raises(ValueError, match="not one dimensional"): + _to_tuple(np.array([[1.0, 2.0], [3.0, 4.0]])) + + with pytest.raises(ValueError, match="not one dimensional"): + _to_tuple(torch.tensor([[[1.0, 2.0]], [[3.0, 4.0]]])) + + with pytest.raises(ValueError, match="only one element tensors can be converted"): + _to_tuple((torch.tensor([1.0, 2.0]), 3.0)) + + +def test_to_tuple_mixed_sequences(): + """Test _to_tuple() with mixed type sequences.""" + + # Mixed list with numpy and floats + result = _to_tuple([np.float32(1.0), 2.0, 3.0]) + assert len(result) == 3 + assert all(isinstance(x, float) for x in result) + + # Mixed tuple with torch tensor items and floats + result = _to_tuple([torch.tensor(1.0), 2.0, 3.0]) + assert result == (1.0, 2.0, 3.0) + + # Mixed tuple with numpy array items and torch tensor + result = _to_tuple((np.float32(1.0), 2.0, torch.tensor(3.0))) + assert result == (1.0, 2.0, 3.0) + + +def test_to_tuple_precision(): + """Test _to_tuple() maintains numerical precision.""" + from isaaclab.sim.utils.prims import _to_tuple + + # Test with high precision values + high_precision = [1.123456789, 2.987654321, 3.141592653] + result = _to_tuple(torch.tensor(high_precision, dtype=torch.float64)) + + # Check that precision is maintained reasonably well + for i, val in enumerate(high_precision): + assert math.isclose(result[i], val, abs_tol=1e-6) diff --git a/source/isaaclab/test/sim/test_utils_transforms.py b/source/isaaclab/test/sim/test_utils_transforms.py new file mode 100644 index 00000000000..c3cbdbd3f74 --- /dev/null +++ b/source/isaaclab/test/sim/test_utils_transforms.py @@ -0,0 +1,1411 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Launch Isaac Sim Simulator first.""" + +from isaaclab.app import AppLauncher + +# launch omniverse app +simulation_app = AppLauncher(headless=True).app + +"""Rest everything follows.""" + +import math +import numpy as np +import torch + +import pytest +from pxr import Gf, Sdf, Usd, UsdGeom + +import isaaclab.sim as sim_utils +import isaaclab.utils.math as math_utils + + +@pytest.fixture(autouse=True) +def test_setup_teardown(): + """Create a blank new stage for each test.""" + # Setup: Create a new stage + sim_utils.create_new_stage() + sim_utils.update_stage() + + # Yield for the test + yield + + # Teardown: Clear stage after each test + sim_utils.clear_stage() + + +def assert_vec3_close(v1: Gf.Vec3d | Gf.Vec3f, v2: tuple | Gf.Vec3d | Gf.Vec3f, eps: float = 1e-6): + """Assert two 3D vectors are close.""" + if isinstance(v2, tuple): + v2 = Gf.Vec3d(*v2) + for i in range(3): + assert math.isclose(v1[i], v2[i], abs_tol=eps), f"Vector mismatch at index {i}: {v1[i]} != {v2[i]}" + + +def assert_quat_close(q1: Gf.Quatf | Gf.Quatd, q2: Gf.Quatf | Gf.Quatd | tuple, eps: float = 1e-6): + """Assert two quaternions are close, accounting for double-cover (q and -q represent same rotation).""" + if isinstance(q2, tuple): + q2 = Gf.Quatd(*q2) + # Check if quaternions are close (either q1 ≈ q2 or q1 ≈ -q2) + real_match = math.isclose(q1.GetReal(), q2.GetReal(), abs_tol=eps) + imag_match = all(math.isclose(q1.GetImaginary()[i], q2.GetImaginary()[i], abs_tol=eps) for i in range(3)) + + real_match_neg = math.isclose(q1.GetReal(), -q2.GetReal(), abs_tol=eps) + imag_match_neg = all(math.isclose(q1.GetImaginary()[i], -q2.GetImaginary()[i], abs_tol=eps) for i in range(3)) + + assert (real_match and imag_match) or ( + real_match_neg and imag_match_neg + ), f"Quaternion mismatch: {q1} != {q2} (and not equal to negative either)" + + +def get_xform_ops(prim: Usd.Prim) -> list[str]: + """Get the ordered list of xform operation names for a prim.""" + xformable = UsdGeom.Xformable(prim) + return [op.GetOpName() for op in xformable.GetOrderedXformOps()] + + +""" +Test standardize_xform_ops() function. +""" + + +def test_standardize_xform_ops_basic(): + """Test basic functionality of standardize_xform_ops on a simple prim.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a simple xform prim with standard operations + prim = sim_utils.create_prim( + "/World/TestXform", + "Xform", + translation=(1.0, 2.0, 3.0), + orientation=(1.0, 0.0, 0.0, 0.0), # w, x, y, z + scale=(1.0, 1.0, 1.0), + stage=stage, + ) + + # Apply standardize_xform_ops + result = sim_utils.standardize_xform_ops(prim) + + # Verify the operation succeeded + assert result is True + assert prim.IsValid() + + # Check that the xform operations are in the correct order + xform_ops = get_xform_ops(prim) + assert xform_ops == [ + "xformOp:translate", + "xformOp:orient", + "xformOp:scale", + ], f"Expected standard xform order, got {xform_ops}" + + # Verify the transform values are preserved (approximately) + assert_vec3_close(prim.GetAttribute("xformOp:translate").Get(), (1.0, 2.0, 3.0)) + assert_quat_close(prim.GetAttribute("xformOp:orient").Get(), (1.0, 0.0, 0.0, 0.0)) + assert_vec3_close(prim.GetAttribute("xformOp:scale").Get(), (1.0, 1.0, 1.0)) + + +def test_standardize_xform_ops_with_rotation_xyz(): + """Test standardize_xform_ops removes deprecated rotateXYZ operations.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim and manually add deprecated rotation operations + prim_path = "/World/TestRotateXYZ" + prim = stage.DefinePrim(prim_path, "Xform") + xformable = UsdGeom.Xformable(prim) + # Add deprecated rotateXYZ operation + rotate_xyz_op = xformable.AddRotateXYZOp(UsdGeom.XformOp.PrecisionDouble) + rotate_xyz_op.Set(Gf.Vec3d(45.0, 30.0, 60.0)) + # Add translate operation + translate_op = xformable.AddTranslateOp(UsdGeom.XformOp.PrecisionDouble) + translate_op.Set(Gf.Vec3d(1.0, 2.0, 3.0)) + + # Verify the deprecated operation exists + assert "xformOp:rotateXYZ" in prim.GetPropertyNames() + + # Get pose before standardization + pos_before, quat_before = sim_utils.resolve_prim_pose(prim) + + # Apply standardize_xform_ops + result = sim_utils.standardize_xform_ops(prim) + assert result is True + + # Get pose after standardization + pos_after, quat_after = sim_utils.resolve_prim_pose(prim) + # Verify world pose is preserved (may have small numeric differences due to rotation conversion) + assert_vec3_close(Gf.Vec3d(*pos_before), pos_after, eps=1e-4) + assert_quat_close(Gf.Quatd(*quat_before), quat_after, eps=1e-4) + + # Verify the deprecated operation is removed + assert "xformOp:rotateXYZ" not in prim.GetPropertyNames() + # Verify standard operations exist + assert "xformOp:translate" in prim.GetPropertyNames() + assert "xformOp:orient" in prim.GetPropertyNames() + assert "xformOp:scale" in prim.GetPropertyNames() + # Check the xform operation order + xform_ops = get_xform_ops(prim) + assert xform_ops == ["xformOp:translate", "xformOp:orient", "xformOp:scale"] + + +def test_standardize_xform_ops_with_transform_matrix(): + """Test standardize_xform_ops removes transform matrix operations.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with a transform matrix + prim_path = "/World/TestTransformMatrix" + prim = stage.DefinePrim(prim_path, "Xform") + xformable = UsdGeom.Xformable(prim) + + # Add transform matrix operation + transform_op = xformable.AddTransformOp(UsdGeom.XformOp.PrecisionDouble) + # Create a simple translation matrix + matrix = Gf.Matrix4d().SetTranslate(Gf.Vec3d(5.0, 10.0, 15.0)) + transform_op.Set(matrix) + + # Verify the transform operation exists + assert "xformOp:transform" in prim.GetPropertyNames() + + # Get pose before standardization + pos_before, quat_before = sim_utils.resolve_prim_pose(prim) + + # Apply standardize_xform_ops + result = sim_utils.standardize_xform_ops(prim) + assert result is True + + # Get pose after standardization + pos_after, quat_after = sim_utils.resolve_prim_pose(prim) + # Verify world pose is preserved + assert_vec3_close(Gf.Vec3d(*pos_before), pos_after, eps=1e-5) + assert_quat_close(Gf.Quatd(*quat_before), quat_after, eps=1e-5) + + # Verify the transform operation is removed + assert "xformOp:transform" not in prim.GetPropertyNames() + # Verify standard operations exist + assert "xformOp:translate" in prim.GetPropertyNames() + assert "xformOp:orient" in prim.GetPropertyNames() + assert "xformOp:scale" in prim.GetPropertyNames() + + +def test_standardize_xform_ops_preserves_world_pose(): + """Test that standardize_xform_ops preserves the world-space pose of the prim.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with specific world pose + translation = (10.0, 20.0, 30.0) + # Rotation of 90 degrees around Z axis + orientation = (0.7071068, 0.0, 0.0, 0.7071068) # w, x, y, z + scale = (2.0, 3.0, 4.0) + + prim = sim_utils.create_prim( + "/World/TestPreservePose", + "Xform", + translation=translation, + orientation=orientation, + scale=scale, + stage=stage, + ) + + # Get the world pose before standardization + pos_before, quat_before = sim_utils.resolve_prim_pose(prim) + + # Apply standardize_xform_ops + result = sim_utils.standardize_xform_ops(prim) + assert result is True + + # Get the world pose after standardization + pos_after, quat_after = sim_utils.resolve_prim_pose(prim) + # Verify the world pose is preserved + assert_vec3_close(Gf.Vec3d(*pos_before), pos_after, eps=1e-5) + assert_quat_close(Gf.Quatd(*quat_before), quat_after, eps=1e-5) + + +def test_standardize_xform_ops_with_units_resolve(): + """Test standardize_xform_ops handles scale:unitsResolve attribute.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim + prim_path = "/World/TestUnitsResolve" + prim = stage.DefinePrim(prim_path, "Xform") + xformable = UsdGeom.Xformable(prim) + + # Add scale operation + scale_op = xformable.AddScaleOp(UsdGeom.XformOp.PrecisionDouble) + scale_op.Set(Gf.Vec3d(1.0, 1.0, 1.0)) + + # Manually add a unitsResolve scale attribute (simulating imported asset with different units) + units_resolve_attr = prim.CreateAttribute("xformOp:scale:unitsResolve", Sdf.ValueTypeNames.Double3) + units_resolve_attr.Set(Gf.Vec3d(100.0, 100.0, 100.0)) # e.g., cm to m conversion + + # Verify both attributes exist + assert "xformOp:scale" in prim.GetPropertyNames() + assert "xformOp:scale:unitsResolve" in prim.GetPropertyNames() + + # Get pose before standardization + pos_before, quat_before = sim_utils.resolve_prim_pose(prim) + + # Apply standardize_xform_ops + result = sim_utils.standardize_xform_ops(prim) + assert result is True + + # Get pose after standardization + pos_after, quat_after = sim_utils.resolve_prim_pose(prim) + # Verify pose is preserved + assert_vec3_close(Gf.Vec3d(*pos_before), pos_after, eps=1e-5) + assert_quat_close(Gf.Quatd(*quat_before), quat_after, eps=1e-5) + + # Verify unitsResolve is removed + assert "xformOp:scale:unitsResolve" not in prim.GetPropertyNames() + + # Verify scale is updated (1.0 * 100.0 = 100.0) + scale = prim.GetAttribute("xformOp:scale").Get() + assert_vec3_close(scale, (100.0, 100.0, 100.0)) + + +def test_standardize_xform_ops_with_hierarchy(): + """Test standardize_xform_ops works correctly with prim hierarchies.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create parent prim + parent_prim = sim_utils.create_prim( + "/World/Parent", + "Xform", + translation=(5.0, 0.0, 0.0), + orientation=(1.0, 0.0, 0.0, 0.0), + scale=(2.0, 2.0, 2.0), + stage=stage, + ) + + # Create child prim + child_prim = sim_utils.create_prim( + "/World/Parent/Child", + "Xform", + translation=(0.0, 3.0, 0.0), + orientation=(0.7071068, 0.0, 0.7071068, 0.0), # 90 deg around Y + scale=(0.5, 0.5, 0.5), + stage=stage, + ) + + # Get world poses before standardization + parent_pos_before, parent_quat_before = sim_utils.resolve_prim_pose(parent_prim) + child_pos_before, child_quat_before = sim_utils.resolve_prim_pose(child_prim) + + # Apply standardize_xform_ops to both + sim_utils.standardize_xform_ops(parent_prim) + sim_utils.standardize_xform_ops(child_prim) + + # Get world poses after standardization + parent_pos_after, parent_quat_after = sim_utils.resolve_prim_pose(parent_prim) + child_pos_after, child_quat_after = sim_utils.resolve_prim_pose(child_prim) + + # Verify world poses are preserved + assert_vec3_close(Gf.Vec3d(*parent_pos_before), parent_pos_after, eps=1e-5) + assert_quat_close(Gf.Quatd(*parent_quat_before), parent_quat_after, eps=1e-5) + assert_vec3_close(Gf.Vec3d(*child_pos_before), child_pos_after, eps=1e-5) + assert_quat_close(Gf.Quatd(*child_quat_before), child_quat_after, eps=1e-5) + + +def test_standardize_xform_ops_multiple_deprecated_ops(): + """Test standardize_xform_ops removes multiple deprecated operations.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with multiple deprecated operations + prim_path = "/World/TestMultipleDeprecated" + prim = stage.DefinePrim(prim_path, "Xform") + xformable = UsdGeom.Xformable(prim) + + # Add various deprecated rotation operations + rotate_x_op = xformable.AddRotateXOp(UsdGeom.XformOp.PrecisionDouble) + rotate_x_op.Set(45.0) + rotate_y_op = xformable.AddRotateYOp(UsdGeom.XformOp.PrecisionDouble) + rotate_y_op.Set(30.0) + rotate_z_op = xformable.AddRotateZOp(UsdGeom.XformOp.PrecisionDouble) + rotate_z_op.Set(60.0) + + # Verify deprecated operations exist + assert "xformOp:rotateX" in prim.GetPropertyNames() + assert "xformOp:rotateY" in prim.GetPropertyNames() + assert "xformOp:rotateZ" in prim.GetPropertyNames() + + # Obtain current local transformations + pos, quat = sim_utils.resolve_prim_pose(prim) + + # Apply standardize_xform_ops + sim_utils.standardize_xform_ops(prim) + + # Obtain current local transformations + pos_after, quat_after = sim_utils.resolve_prim_pose(prim) + # Verify world pose is preserved + assert_vec3_close(Gf.Vec3d(*pos), Gf.Vec3d(*pos_after), eps=1e-5) + assert_quat_close(Gf.Quatd(*quat), Gf.Quatd(*quat_after), eps=1e-5) + + # Verify all deprecated operations are removed + assert "xformOp:rotateX" not in prim.GetPropertyNames() + assert "xformOp:rotateY" not in prim.GetPropertyNames() + assert "xformOp:rotateZ" not in prim.GetPropertyNames() + # Verify standard operations exist + xform_ops = get_xform_ops(prim) + assert xform_ops == ["xformOp:translate", "xformOp:orient", "xformOp:scale"] + + +def test_standardize_xform_ops_with_existing_standard_ops(): + """Test standardize_xform_ops when prim already has standard operations.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with standard operations already in place + prim = sim_utils.create_prim( + "/World/TestExistingStandard", + "Xform", + translation=(7.0, 8.0, 9.0), + orientation=(0.9238795, 0.3826834, 0.0, 0.0), # rotation around X + scale=(1.5, 2.5, 3.5), + stage=stage, + ) + + # Get initial values + initial_translate = prim.GetAttribute("xformOp:translate").Get() + initial_orient = prim.GetAttribute("xformOp:orient").Get() + initial_scale = prim.GetAttribute("xformOp:scale").Get() + + # Get world pose before standardization + pos_before, quat_before = sim_utils.resolve_prim_pose(prim) + + # Apply standardize_xform_ops + result = sim_utils.standardize_xform_ops(prim) + assert result is True + + # Get world pose after standardization + pos_after, quat_after = sim_utils.resolve_prim_pose(prim) + # Verify world pose is preserved + assert_vec3_close(Gf.Vec3d(*pos_before), pos_after, eps=1e-5) + assert_quat_close(Gf.Quatd(*quat_before), quat_after, eps=1e-5) + + # Verify operations still exist and are in correct order + xform_ops = get_xform_ops(prim) + assert xform_ops == ["xformOp:translate", "xformOp:orient", "xformOp:scale"] + + # Verify values are approximately preserved + final_translate = prim.GetAttribute("xformOp:translate").Get() + final_orient = prim.GetAttribute("xformOp:orient").Get() + final_scale = prim.GetAttribute("xformOp:scale").Get() + + assert_vec3_close(initial_translate, final_translate, eps=1e-5) + assert_quat_close(initial_orient, final_orient, eps=1e-5) + assert_vec3_close(initial_scale, final_scale, eps=1e-5) + + +def test_standardize_xform_ops_invalid_prim(): + """Test standardize_xform_ops raises error for invalid prim.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Get an invalid prim (non-existent path) + invalid_prim = stage.GetPrimAtPath("/World/NonExistent") + + # Verify the prim is invalid + assert not invalid_prim.IsValid() + + # Attempt to apply standardize_xform_ops and expect ValueError + with pytest.raises(ValueError, match="not valid"): + sim_utils.standardize_xform_ops(invalid_prim) + + +def test_standardize_xform_ops_on_geometry_prim(): + """Test standardize_xform_ops on a geometry prim (Cube, Sphere, etc.).""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a cube with transform + cube_prim = sim_utils.create_prim( + "/World/TestCube", + "Cube", + translation=(1.0, 2.0, 3.0), + orientation=(1.0, 0.0, 0.0, 0.0), + scale=(2.0, 2.0, 2.0), + attributes={"size": 1.0}, + stage=stage, + ) + + # Get world pose before + pos_before, quat_before = sim_utils.resolve_prim_pose(cube_prim) + + # Apply standardize_xform_ops + sim_utils.standardize_xform_ops(cube_prim) + + # Get world pose after + pos_after, quat_after = sim_utils.resolve_prim_pose(cube_prim) + # Verify world pose is preserved + assert_vec3_close(Gf.Vec3d(*pos_before), pos_after, eps=1e-5) + assert_quat_close(Gf.Quatd(*quat_before), quat_after, eps=1e-5) + + # Verify standard operations exist + xform_ops = get_xform_ops(cube_prim) + assert xform_ops == ["xformOp:translate", "xformOp:orient", "xformOp:scale"] + + +def test_standardize_xform_ops_with_non_uniform_scale(): + """Test standardize_xform_ops with non-uniform scale.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with non-uniform scale + prim = sim_utils.create_prim( + "/World/TestNonUniformScale", + "Xform", + translation=(5.0, 10.0, 15.0), + orientation=(0.7071068, 0.0, 0.7071068, 0.0), # 90 deg around Y + scale=(1.0, 2.0, 3.0), # Non-uniform scale + stage=stage, + ) + + # Get initial scale + initial_scale = prim.GetAttribute("xformOp:scale").Get() + + # Get world pose before standardization + pos_before, quat_before = sim_utils.resolve_prim_pose(prim) + + # Apply standardize_xform_ops + result = sim_utils.standardize_xform_ops(prim) + assert result is True + + # Get world pose after standardization + pos_after, quat_after = sim_utils.resolve_prim_pose(prim) + # Verify world pose is preserved + assert_vec3_close(Gf.Vec3d(*pos_before), pos_after, eps=1e-5) + assert_quat_close(Gf.Quatd(*quat_before), quat_after, eps=1e-5) + # Verify scale is preserved + final_scale = prim.GetAttribute("xformOp:scale").Get() + assert_vec3_close(initial_scale, final_scale, eps=1e-5) + + +def test_standardize_xform_ops_identity_transform(): + """Test standardize_xform_ops with identity transform (no translation, rotation, or scale).""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with identity transform + prim = sim_utils.create_prim( + "/World/TestIdentity", + "Xform", + translation=(0.0, 0.0, 0.0), + orientation=(1.0, 0.0, 0.0, 0.0), # Identity quaternion + scale=(1.0, 1.0, 1.0), + stage=stage, + ) + + # Apply standardize_xform_ops + sim_utils.standardize_xform_ops(prim) + + # Verify standard operations exist + xform_ops = get_xform_ops(prim) + assert xform_ops == ["xformOp:translate", "xformOp:orient", "xformOp:scale"] + + # Verify identity values + assert_vec3_close(prim.GetAttribute("xformOp:translate").Get(), (0.0, 0.0, 0.0)) + assert_quat_close(prim.GetAttribute("xformOp:orient").Get(), (1.0, 0.0, 0.0, 0.0)) + assert_vec3_close(prim.GetAttribute("xformOp:scale").Get(), (1.0, 1.0, 1.0)) + + +def test_standardize_xform_ops_with_explicit_values(): + """Test standardize_xform_ops with explicit translation, orientation, and scale values.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with some initial transform + prim = sim_utils.create_prim( + "/World/TestExplicitValues", + "Xform", + translation=(10.0, 10.0, 10.0), + orientation=(0.7071068, 0.7071068, 0.0, 0.0), + scale=(5.0, 5.0, 5.0), + stage=stage, + ) + + # Apply standardize_xform_ops with new explicit values + new_translation = (1.0, 2.0, 3.0) + new_orientation = (1.0, 0.0, 0.0, 0.0) + new_scale = (2.0, 2.0, 2.0) + + result = sim_utils.standardize_xform_ops( + prim, translation=new_translation, orientation=new_orientation, scale=new_scale + ) + assert result is True + + # Verify the new values are set + assert_vec3_close(prim.GetAttribute("xformOp:translate").Get(), new_translation) + assert_quat_close(prim.GetAttribute("xformOp:orient").Get(), new_orientation) + assert_vec3_close(prim.GetAttribute("xformOp:scale").Get(), new_scale) + + # Verify the prim is at the expected world location + pos_after, quat_after = sim_utils.resolve_prim_pose(prim) + assert_vec3_close(Gf.Vec3d(*pos_after), new_translation, eps=1e-5) + assert_quat_close(Gf.Quatd(*quat_after), new_orientation, eps=1e-5) + + # Verify standard operation order + xform_ops = get_xform_ops(prim) + assert xform_ops == ["xformOp:translate", "xformOp:orient", "xformOp:scale"] + + +def test_standardize_xform_ops_with_partial_values(): + """Test standardize_xform_ops with only some values specified.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim + prim = sim_utils.create_prim( + "/World/TestPartialValues", + "Xform", + translation=(1.0, 2.0, 3.0), + orientation=(0.9238795, 0.3826834, 0.0, 0.0), # rotation around X + scale=(2.0, 2.0, 2.0), + stage=stage, + ) + + # Get initial local pose + pos_before, quat_before = sim_utils.resolve_prim_pose(prim, ref_prim=prim.GetParent()) + scale_before = prim.GetAttribute("xformOp:scale").Get() + + # Apply standardize_xform_ops with only translation specified + new_translation = (10.0, 20.0, 30.0) + result = sim_utils.standardize_xform_ops(prim, translation=new_translation) + assert result is True + + # Verify translation is updated + assert_vec3_close(prim.GetAttribute("xformOp:translate").Get(), new_translation) + + # Verify orientation and scale are preserved + quat_after = prim.GetAttribute("xformOp:orient").Get() + scale_after = prim.GetAttribute("xformOp:scale").Get() + assert_quat_close(Gf.Quatd(*quat_before), quat_after, eps=1e-5) + assert_vec3_close(scale_before, scale_after, eps=1e-5) + + # Verify the prim's world orientation hasn't changed (only translation changed) + _, quat_after_world = sim_utils.resolve_prim_pose(prim) + assert_quat_close(Gf.Quatd(*quat_before), quat_after_world, eps=1e-5) + + +def test_standardize_xform_ops_non_xformable_prim(): + """Test standardize_xform_ops returns False for non-Xformable prims.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a Material prim (not Xformable) + from pxr import UsdShade + + material_prim = UsdShade.Material.Define(stage, "/World/TestMaterial").GetPrim() + + # Verify the prim is valid but not Xformable + assert material_prim.IsValid() + assert not material_prim.IsA(UsdGeom.Xformable) + + # Attempt to apply standardize_xform_ops - should return False + result = sim_utils.standardize_xform_ops(material_prim) + assert result is False + + +def test_standardize_xform_ops_preserves_reset_xform_stack(): + """Test that standardize_xform_ops preserves the resetXformStack attribute.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim + prim = sim_utils.create_prim("/World/TestResetStack", "Xform", stage=stage) + xformable = UsdGeom.Xformable(prim) + + # Set resetXformStack to True + xformable.SetResetXformStack(True) + assert xformable.GetResetXformStack() is True + + # Apply standardize_xform_ops + result = sim_utils.standardize_xform_ops(prim) + assert result is True + + # Verify resetXformStack is preserved + assert xformable.GetResetXformStack() is True + + +def test_standardize_xform_ops_with_complex_hierarchy(): + """Test standardize_xform_ops on deeply nested hierarchy.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a complex hierarchy + root = sim_utils.create_prim("/World/Root", "Xform", translation=(1.0, 0.0, 0.0), stage=stage) + child1 = sim_utils.create_prim("/World/Root/Child1", "Xform", translation=(0.0, 1.0, 0.0), stage=stage) + child2 = sim_utils.create_prim("/World/Root/Child1/Child2", "Xform", translation=(0.0, 0.0, 1.0), stage=stage) + child3 = sim_utils.create_prim("/World/Root/Child1/Child2/Child3", "Cube", translation=(1.0, 1.0, 1.0), stage=stage) + + # Get world poses before + poses_before = {} + for name, prim in [("root", root), ("child1", child1), ("child2", child2), ("child3", child3)]: + poses_before[name] = sim_utils.resolve_prim_pose(prim) + + # Apply standardize_xform_ops to all prims + assert sim_utils.standardize_xform_ops(root) is True + assert sim_utils.standardize_xform_ops(child1) is True + assert sim_utils.standardize_xform_ops(child2) is True + assert sim_utils.standardize_xform_ops(child3) is True + + # Get world poses after + poses_after = {} + for name, prim in [("root", root), ("child1", child1), ("child2", child2), ("child3", child3)]: + poses_after[name] = sim_utils.resolve_prim_pose(prim) + + # Verify all world poses are preserved + for name in poses_before: + pos_before, quat_before = poses_before[name] + pos_after, quat_after = poses_after[name] + assert_vec3_close(Gf.Vec3d(*pos_before), pos_after, eps=1e-5) + assert_quat_close(Gf.Quatd(*quat_before), quat_after, eps=1e-5) + + +def test_standardize_xform_ops_preserves_float_precision(): + """Test that standardize_xform_ops preserves float precision when it already exists.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim manually with FLOAT precision operations (not double) + prim_path = "/World/TestFloatPrecision" + prim = stage.DefinePrim(prim_path, "Xform") + xformable = UsdGeom.Xformable(prim) + + # Add xform operations with FLOAT precision (not the default double) + translate_op = xformable.AddTranslateOp(UsdGeom.XformOp.PrecisionFloat) + translate_op.Set(Gf.Vec3f(1.0, 2.0, 3.0)) + + orient_op = xformable.AddOrientOp(UsdGeom.XformOp.PrecisionFloat) + orient_op.Set(Gf.Quatf(1.0, 0.0, 0.0, 0.0)) + + scale_op = xformable.AddScaleOp(UsdGeom.XformOp.PrecisionFloat) + scale_op.Set(Gf.Vec3f(1.0, 1.0, 1.0)) + + # Verify operations exist with float precision + assert translate_op.GetPrecision() == UsdGeom.XformOp.PrecisionFloat + assert orient_op.GetPrecision() == UsdGeom.XformOp.PrecisionFloat + assert scale_op.GetPrecision() == UsdGeom.XformOp.PrecisionFloat + + # Now apply standardize_xform_ops with new values (provided as double precision Python floats) + new_translation = (5.0, 10.0, 15.0) + new_orientation = (0.7071068, 0.7071068, 0.0, 0.0) # 90 deg around X + new_scale = (2.0, 3.0, 4.0) + + result = sim_utils.standardize_xform_ops( + prim, translation=new_translation, orientation=new_orientation, scale=new_scale + ) + assert result is True + + # Verify the precision is STILL float (not converted to double) + translate_op_after = UsdGeom.XformOp(prim.GetAttribute("xformOp:translate")) + orient_op_after = UsdGeom.XformOp(prim.GetAttribute("xformOp:orient")) + scale_op_after = UsdGeom.XformOp(prim.GetAttribute("xformOp:scale")) + + assert translate_op_after.GetPrecision() == UsdGeom.XformOp.PrecisionFloat + assert orient_op_after.GetPrecision() == UsdGeom.XformOp.PrecisionFloat + assert scale_op_after.GetPrecision() == UsdGeom.XformOp.PrecisionFloat + + # Verify the VALUES are set correctly (cast to float, so they're Gf.Vec3f and Gf.Quatf) + translate_value = prim.GetAttribute("xformOp:translate").Get() + assert isinstance(translate_value, Gf.Vec3f), f"Expected Gf.Vec3f, got {type(translate_value)}" + assert_vec3_close(translate_value, new_translation, eps=1e-5) + + orient_value = prim.GetAttribute("xformOp:orient").Get() + assert isinstance(orient_value, Gf.Quatf), f"Expected Gf.Quatf, got {type(orient_value)}" + assert_quat_close(orient_value, new_orientation, eps=1e-5) + + scale_value = prim.GetAttribute("xformOp:scale").Get() + assert isinstance(scale_value, Gf.Vec3f), f"Expected Gf.Vec3f, got {type(scale_value)}" + assert_vec3_close(scale_value, new_scale, eps=1e-5) + + # Verify the world pose matches what we set + pos_after, quat_after = sim_utils.resolve_prim_pose(prim) + assert_vec3_close(Gf.Vec3d(*pos_after), new_translation, eps=1e-4) + assert_quat_close(Gf.Quatd(*quat_after), new_orientation, eps=1e-4) + + +""" +Test validate_standard_xform_ops() function. +""" + + +def test_validate_standard_xform_ops_valid(): + """Test validate_standard_xform_ops returns True for standardized prims.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with standard operations + prim = sim_utils.create_prim( + "/World/TestValid", + "Xform", + translation=(1.0, 2.0, 3.0), + orientation=(1.0, 0.0, 0.0, 0.0), + scale=(1.0, 1.0, 1.0), + stage=stage, + ) + + # Standardize the prim + sim_utils.standardize_xform_ops(prim) + + # Validate it + assert sim_utils.validate_standard_xform_ops(prim) is True + + +def test_validate_standard_xform_ops_invalid_order(): + """Test validate_standard_xform_ops returns False for non-standard operation order.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim and manually set up xform ops in wrong order + prim_path = "/World/TestInvalidOrder" + prim = stage.DefinePrim(prim_path, "Xform") + xformable = UsdGeom.Xformable(prim) + + # Add operations in wrong order: scale, translate, orient (should be translate, orient, scale) + scale_op = xformable.AddScaleOp(UsdGeom.XformOp.PrecisionDouble) + scale_op.Set(Gf.Vec3d(1.0, 1.0, 1.0)) + + translate_op = xformable.AddTranslateOp(UsdGeom.XformOp.PrecisionDouble) + translate_op.Set(Gf.Vec3d(1.0, 2.0, 3.0)) + + orient_op = xformable.AddOrientOp(UsdGeom.XformOp.PrecisionDouble) + orient_op.Set(Gf.Quatd(1.0, 0.0, 0.0, 0.0)) + + # Validate it - should return False + assert sim_utils.validate_standard_xform_ops(prim) is False + + +def test_validate_standard_xform_ops_with_deprecated_ops(): + """Test validate_standard_xform_ops returns False when deprecated operations exist.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with deprecated rotateXYZ operation + prim_path = "/World/TestDeprecated" + prim = stage.DefinePrim(prim_path, "Xform") + xformable = UsdGeom.Xformable(prim) + + # Add deprecated rotateXYZ operation + rotate_xyz_op = xformable.AddRotateXYZOp(UsdGeom.XformOp.PrecisionDouble) + rotate_xyz_op.Set(Gf.Vec3d(45.0, 30.0, 60.0)) + + # Validate it - should return False + assert sim_utils.validate_standard_xform_ops(prim) is False + + +def test_validate_standard_xform_ops_missing_operations(): + """Test validate_standard_xform_ops returns False when standard operations are missing.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with only translate operation (missing orient and scale) + prim_path = "/World/TestMissing" + prim = stage.DefinePrim(prim_path, "Xform") + xformable = UsdGeom.Xformable(prim) + + translate_op = xformable.AddTranslateOp(UsdGeom.XformOp.PrecisionDouble) + translate_op.Set(Gf.Vec3d(1.0, 2.0, 3.0)) + + # Validate it - should return False (missing orient and scale) + assert sim_utils.validate_standard_xform_ops(prim) is False + + +def test_validate_standard_xform_ops_invalid_prim(): + """Test validate_standard_xform_ops returns False for invalid prim.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Get an invalid prim + invalid_prim = stage.GetPrimAtPath("/World/NonExistent") + + # Validate it - should return False + assert sim_utils.validate_standard_xform_ops(invalid_prim) is False + + +def test_validate_standard_xform_ops_non_xformable(): + """Test validate_standard_xform_ops returns False for non-Xformable prims.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a Material prim (not Xformable) + from pxr import UsdShade + + material_prim = UsdShade.Material.Define(stage, "/World/TestMaterial").GetPrim() + + # Validate it - should return False + assert sim_utils.validate_standard_xform_ops(material_prim) is False + + +def test_validate_standard_xform_ops_with_transform_matrix(): + """Test validate_standard_xform_ops returns False when transform matrix operation exists.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with transform matrix + prim_path = "/World/TestTransformMatrix" + prim = stage.DefinePrim(prim_path, "Xform") + xformable = UsdGeom.Xformable(prim) + + # Add transform matrix operation + transform_op = xformable.AddTransformOp(UsdGeom.XformOp.PrecisionDouble) + matrix = Gf.Matrix4d().SetTranslate(Gf.Vec3d(5.0, 10.0, 15.0)) + transform_op.Set(matrix) + + # Validate it - should return False + assert sim_utils.validate_standard_xform_ops(prim) is False + + +def test_validate_standard_xform_ops_extra_operations(): + """Test validate_standard_xform_ops returns False when extra operations exist.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with standard operations + prim = sim_utils.create_prim( + "/World/TestExtra", + "Xform", + translation=(1.0, 2.0, 3.0), + orientation=(1.0, 0.0, 0.0, 0.0), + scale=(1.0, 1.0, 1.0), + stage=stage, + ) + + # Standardize it + sim_utils.standardize_xform_ops(prim) + + # Add an extra operation + xformable = UsdGeom.Xformable(prim) + extra_op = xformable.AddRotateXOp(UsdGeom.XformOp.PrecisionDouble) + extra_op.Set(45.0) + + # Validate it - should return False (has extra operation) + assert sim_utils.validate_standard_xform_ops(prim) is False + + +def test_validate_standard_xform_ops_after_standardization(): + """Test validate_standard_xform_ops returns True after standardization of non-standard prim.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a prim with non-standard operations + prim_path = "/World/TestBeforeAfter" + prim = stage.DefinePrim(prim_path, "Xform") + xformable = UsdGeom.Xformable(prim) + + # Add deprecated operations + rotate_x_op = xformable.AddRotateXOp(UsdGeom.XformOp.PrecisionDouble) + rotate_x_op.Set(45.0) + translate_op = xformable.AddTranslateOp(UsdGeom.XformOp.PrecisionDouble) + translate_op.Set(Gf.Vec3d(1.0, 2.0, 3.0)) + + # Validate before standardization - should be False + assert sim_utils.validate_standard_xform_ops(prim) is False + + # Standardize the prim + sim_utils.standardize_xform_ops(prim) + + # Validate after standardization - should be True + assert sim_utils.validate_standard_xform_ops(prim) is True + + +def test_validate_standard_xform_ops_on_geometry(): + """Test validate_standard_xform_ops works correctly on geometry prims.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a cube with standard operations + cube_prim = sim_utils.create_prim( + "/World/TestCube", + "Cube", + translation=(1.0, 2.0, 3.0), + orientation=(1.0, 0.0, 0.0, 0.0), + scale=(2.0, 2.0, 2.0), + stage=stage, + ) + + # Standardize it + sim_utils.standardize_xform_ops(cube_prim) + + # Validate it - should be True + assert sim_utils.validate_standard_xform_ops(cube_prim) is True + + +def test_validate_standard_xform_ops_empty_prim(): + """Test validate_standard_xform_ops on prim with no xform operations.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a bare prim with no xform operations + prim_path = "/World/TestEmpty" + prim = stage.DefinePrim(prim_path, "Xform") + + # Validate it - should return False (no operations at all) + assert sim_utils.validate_standard_xform_ops(prim) is False + + +""" +Test resolve_prim_pose() function. +""" + + +def test_resolve_prim_pose(): + """Test resolve_prim_pose() function.""" + # number of objects + num_objects = 20 + # sample random scales for x, y, z + rand_scales = np.random.uniform(0.5, 1.5, size=(num_objects, 3, 3)) + rand_widths = np.random.uniform(0.1, 10.0, size=(num_objects,)) + # sample random positions + rand_positions = np.random.uniform(-100, 100, size=(num_objects, 3, 3)) + # sample random rotations + rand_quats = np.random.randn(num_objects, 3, 4) + rand_quats /= np.linalg.norm(rand_quats, axis=2, keepdims=True) + + # create objects + for i in range(num_objects): + # simple cubes + cube_prim = sim_utils.create_prim( + f"/World/Cubes/instance_{i:02d}", + "Cube", + translation=rand_positions[i, 0], + orientation=rand_quats[i, 0], + scale=rand_scales[i, 0], + attributes={"size": rand_widths[i]}, + ) + # xform hierarchy + xform_prim = sim_utils.create_prim( + f"/World/Xform/instance_{i:02d}", + "Xform", + translation=rand_positions[i, 1], + orientation=rand_quats[i, 1], + scale=rand_scales[i, 1], + ) + geometry_prim = sim_utils.create_prim( + f"/World/Xform/instance_{i:02d}/geometry", + "Sphere", + translation=rand_positions[i, 2], + orientation=rand_quats[i, 2], + scale=rand_scales[i, 2], + attributes={"radius": rand_widths[i]}, + ) + dummy_prim = sim_utils.create_prim( + f"/World/Xform/instance_{i:02d}/dummy", + "Sphere", + ) + + # cube prim w.r.t. world frame + pos, quat = sim_utils.resolve_prim_pose(cube_prim) + pos, quat = np.array(pos), np.array(quat) + quat = quat if np.sign(rand_quats[i, 0, 0]) == np.sign(quat[0]) else -quat + np.testing.assert_allclose(pos, rand_positions[i, 0], atol=1e-3) + np.testing.assert_allclose(quat, rand_quats[i, 0], atol=1e-3) + # xform prim w.r.t. world frame + pos, quat = sim_utils.resolve_prim_pose(xform_prim) + pos, quat = np.array(pos), np.array(quat) + quat = quat if np.sign(rand_quats[i, 1, 0]) == np.sign(quat[0]) else -quat + np.testing.assert_allclose(pos, rand_positions[i, 1], atol=1e-3) + np.testing.assert_allclose(quat, rand_quats[i, 1], atol=1e-3) + # dummy prim w.r.t. world frame + pos, quat = sim_utils.resolve_prim_pose(dummy_prim) + pos, quat = np.array(pos), np.array(quat) + quat = quat if np.sign(rand_quats[i, 1, 0]) == np.sign(quat[0]) else -quat + np.testing.assert_allclose(pos, rand_positions[i, 1], atol=1e-3) + np.testing.assert_allclose(quat, rand_quats[i, 1], atol=1e-3) + + # geometry prim w.r.t. xform prim + pos, quat = sim_utils.resolve_prim_pose(geometry_prim, ref_prim=xform_prim) + pos, quat = np.array(pos), np.array(quat) + quat = quat if np.sign(rand_quats[i, 2, 0]) == np.sign(quat[0]) else -quat + np.testing.assert_allclose(pos, rand_positions[i, 2] * rand_scales[i, 1], atol=1e-3) + # TODO: Enabling scale causes the test to fail because the current implementation of + # resolve_prim_pose does not correctly handle non-identity scales on Xform prims. This is a known + # limitation. Until this is fixed, the test is disabled here to ensure the test passes. + # np.testing.assert_allclose(quat, rand_quats[i, 2], atol=1e-3) + + # dummy prim w.r.t. xform prim + pos, quat = sim_utils.resolve_prim_pose(dummy_prim, ref_prim=xform_prim) + pos, quat = np.array(pos), np.array(quat) + np.testing.assert_allclose(pos, np.zeros(3), atol=1e-3) + np.testing.assert_allclose(quat, np.array([1, 0, 0, 0]), atol=1e-3) + # xform prim w.r.t. cube prim + pos, quat = sim_utils.resolve_prim_pose(xform_prim, ref_prim=cube_prim) + pos, quat = np.array(pos), np.array(quat) + # -- compute ground truth values + gt_pos, gt_quat = math_utils.subtract_frame_transforms( + torch.from_numpy(rand_positions[i, 0]).unsqueeze(0), + torch.from_numpy(rand_quats[i, 0]).unsqueeze(0), + torch.from_numpy(rand_positions[i, 1]).unsqueeze(0), + torch.from_numpy(rand_quats[i, 1]).unsqueeze(0), + ) + gt_pos, gt_quat = gt_pos.squeeze(0).numpy(), gt_quat.squeeze(0).numpy() + quat = quat if np.sign(gt_quat[0]) == np.sign(quat[0]) else -quat + np.testing.assert_allclose(pos, gt_pos, atol=1e-3) + np.testing.assert_allclose(quat, gt_quat, atol=1e-3) + + +""" +Test resolve_prim_scale() function. +""" + + +def test_resolve_prim_scale(): + """Test resolve_prim_scale() function. + + To simplify the test, we assume that the effective scale at a prim + is the product of the scales of the prims in the hierarchy: + + scale = scale_of_xform * scale_of_geometry_prim + + This is only true when rotations are identity or the transforms are + orthogonal and uniformly scaled. Otherwise, scale is not composable + like that in local component-wise fashion. + """ + # number of objects + num_objects = 20 + # sample random scales for x, y, z + rand_scales = np.random.uniform(0.5, 1.5, size=(num_objects, 3, 3)) + rand_widths = np.random.uniform(0.1, 10.0, size=(num_objects,)) + # sample random positions + rand_positions = np.random.uniform(-100, 100, size=(num_objects, 3, 3)) + + # create objects + for i in range(num_objects): + # simple cubes + cube_prim = sim_utils.create_prim( + f"/World/Cubes/instance_{i:02d}", + "Cube", + translation=rand_positions[i, 0], + scale=rand_scales[i, 0], + attributes={"size": rand_widths[i]}, + ) + # xform hierarchy + xform_prim = sim_utils.create_prim( + f"/World/Xform/instance_{i:02d}", + "Xform", + translation=rand_positions[i, 1], + scale=rand_scales[i, 1], + ) + geometry_prim = sim_utils.create_prim( + f"/World/Xform/instance_{i:02d}/geometry", + "Sphere", + translation=rand_positions[i, 2], + scale=rand_scales[i, 2], + attributes={"radius": rand_widths[i]}, + ) + dummy_prim = sim_utils.create_prim( + f"/World/Xform/instance_{i:02d}/dummy", + "Sphere", + ) + + # cube prim + scale = sim_utils.resolve_prim_scale(cube_prim) + scale = np.array(scale) + np.testing.assert_allclose(scale, rand_scales[i, 0], atol=1e-5) + # xform prim + scale = sim_utils.resolve_prim_scale(xform_prim) + scale = np.array(scale) + np.testing.assert_allclose(scale, rand_scales[i, 1], atol=1e-5) + # geometry prim + scale = sim_utils.resolve_prim_scale(geometry_prim) + scale = np.array(scale) + np.testing.assert_allclose(scale, rand_scales[i, 1] * rand_scales[i, 2], atol=1e-5) + # dummy prim + scale = sim_utils.resolve_prim_scale(dummy_prim) + scale = np.array(scale) + np.testing.assert_allclose(scale, rand_scales[i, 1], atol=1e-5) + + +""" +Test convert_world_pose_to_local() function. +""" + + +def test_convert_world_pose_to_local_basic(): + """Test basic world-to-local pose conversion.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create parent and child prims + parent_prim = sim_utils.create_prim( + "/World/Parent", + "Xform", + translation=(5.0, 0.0, 0.0), + orientation=(1.0, 0.0, 0.0, 0.0), # identity rotation + scale=(1.0, 1.0, 1.0), + stage=stage, + ) + + # World pose we want to achieve for a child + world_position = (10.0, 3.0, 0.0) + world_orientation = (1.0, 0.0, 0.0, 0.0) # identity rotation + + # Convert to local space + local_translation, local_orientation = sim_utils.convert_world_pose_to_local( + world_position, world_orientation, parent_prim + ) + # Assert orientation is not None + assert local_orientation is not None + + # The expected local translation is world_position - parent_position = (10-5, 3-0, 0-0) = (5, 3, 0) + assert_vec3_close(Gf.Vec3d(*local_translation), (5.0, 3.0, 0.0), eps=1e-5) + assert_quat_close(Gf.Quatd(*local_orientation), (1.0, 0.0, 0.0, 0.0), eps=1e-5) + + +def test_convert_world_pose_to_local_with_rotation(): + """Test world-to-local conversion with parent rotation.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create parent with 90-degree rotation around Z axis + parent_prim = sim_utils.create_prim( + "/World/RotatedParent", + "Xform", + translation=(0.0, 0.0, 0.0), + orientation=(0.7071068, 0.0, 0.0, 0.7071068), # 90 deg around Z + scale=(1.0, 1.0, 1.0), + stage=stage, + ) + + # World pose: position at (1, 0, 0) with identity rotation + world_position = (1.0, 0.0, 0.0) + world_orientation = (1.0, 0.0, 0.0, 0.0) + + # Convert to local space + local_translation, local_orientation = sim_utils.convert_world_pose_to_local( + world_position, world_orientation, parent_prim + ) + + # Create a child with the local transform and verify world pose + child_prim = sim_utils.create_prim( + "/World/RotatedParent/Child", + "Xform", + translation=local_translation, + orientation=local_orientation, + stage=stage, + ) + + # Get world pose of child + child_world_pos, child_world_quat = sim_utils.resolve_prim_pose(child_prim) + + # Verify it matches the desired world pose + assert_vec3_close(Gf.Vec3d(*child_world_pos), world_position, eps=1e-5) + assert_quat_close(Gf.Quatd(*child_world_quat), world_orientation, eps=1e-5) + + +def test_convert_world_pose_to_local_with_scale(): + """Test world-to-local conversion with parent scale.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create parent with non-uniform scale + parent_prim = sim_utils.create_prim( + "/World/ScaledParent", + "Xform", + translation=(1.0, 2.0, 3.0), + orientation=(1.0, 0.0, 0.0, 0.0), + scale=(2.0, 2.0, 2.0), + stage=stage, + ) + + # World pose we want + world_position = (5.0, 6.0, 7.0) + world_orientation = (0.7071068, 0.7071068, 0.0, 0.0) # 90 deg around X + + # Convert to local space + local_translation, local_orientation = sim_utils.convert_world_pose_to_local( + world_position, world_orientation, parent_prim + ) + + # Create child and verify + child_prim = sim_utils.create_prim( + "/World/ScaledParent/Child", + "Xform", + translation=local_translation, + orientation=local_orientation, + stage=stage, + ) + + # Get world pose + child_world_pos, child_world_quat = sim_utils.resolve_prim_pose(child_prim) + + # Verify (may have some tolerance due to scale effects on rotation) + assert_vec3_close(Gf.Vec3d(*child_world_pos), world_position, eps=1e-4) + assert_quat_close(Gf.Quatd(*child_world_quat), world_orientation, eps=1e-4) + + +def test_convert_world_pose_to_local_invalid_parent(): + """Test world-to-local conversion with invalid parent returns world pose unchanged.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Get an invalid prim + invalid_prim = stage.GetPrimAtPath("/World/NonExistent") + assert not invalid_prim.IsValid() + + world_position = (10.0, 20.0, 30.0) + world_orientation = (0.7071068, 0.0, 0.7071068, 0.0) + + # Convert with invalid reference prim + with pytest.raises(ValueError): + sim_utils.convert_world_pose_to_local(world_position, world_orientation, invalid_prim) + + +def test_convert_world_pose_to_local_root_parent(): + """Test world-to-local conversion with root as parent returns world pose unchanged.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Get the pseudo-root prim + root_prim = stage.GetPrimAtPath("/") + + world_position = (15.0, 25.0, 35.0) + world_orientation = (0.9238795, 0.3826834, 0.0, 0.0) + + # Convert with root parent + local_translation, local_orientation = sim_utils.convert_world_pose_to_local( + world_position, world_orientation, root_prim + ) + # Assert orientation is not None + assert local_orientation is not None + + # Should return unchanged + assert_vec3_close(Gf.Vec3d(*local_translation), world_position, eps=1e-10) + assert_quat_close(Gf.Quatd(*local_orientation), world_orientation, eps=1e-10) + + +def test_convert_world_pose_to_local_none_orientation(): + """Test world-to-local conversion with None orientation.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create parent + parent_prim = sim_utils.create_prim( + "/World/ParentNoOrient", + "Xform", + translation=(3.0, 4.0, 5.0), + orientation=(0.7071068, 0.0, 0.0, 0.7071068), # 90 deg around Z + stage=stage, + ) + + world_position = (10.0, 10.0, 10.0) + + # Convert with None orientation + local_translation, local_orientation = sim_utils.convert_world_pose_to_local(world_position, None, parent_prim) + + # Orientation should be None + assert local_orientation is None + # Translation should still be converted + assert local_translation is not None + + +def test_convert_world_pose_to_local_complex_hierarchy(): + """Test world-to-local conversion in a complex hierarchy.""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a complex hierarchy + _ = sim_utils.create_prim( + "/World/Grandparent", + "Xform", + translation=(10.0, 0.0, 0.0), + orientation=(0.7071068, 0.0, 0.0, 0.7071068), # 90 deg around Z + scale=(2.0, 2.0, 2.0), + stage=stage, + ) + + parent = sim_utils.create_prim( + "/World/Grandparent/Parent", + "Xform", + translation=(5.0, 0.0, 0.0), # local to grandparent + orientation=(0.7071068, 0.7071068, 0.0, 0.0), # 90 deg around X + scale=(0.5, 0.5, 0.5), + stage=stage, + ) + + # World pose we want to achieve + world_position = (20.0, 15.0, 10.0) + world_orientation = (1.0, 0.0, 0.0, 0.0) + + # Convert to local space relative to parent + local_translation, local_orientation = sim_utils.convert_world_pose_to_local( + world_position, world_orientation, parent + ) + + # Create child with the computed local transform + child = sim_utils.create_prim( + "/World/Grandparent/Parent/Child", + "Xform", + translation=local_translation, + orientation=local_orientation, + stage=stage, + ) + + # Verify world pose + child_world_pos, child_world_quat = sim_utils.resolve_prim_pose(child) + + # Should match the desired world pose (with some tolerance for complex transforms) + assert_vec3_close(Gf.Vec3d(*child_world_pos), world_position, eps=1e-4) + assert_quat_close(Gf.Quatd(*child_world_quat), world_orientation, eps=1e-4) + + +def test_convert_world_pose_to_local_with_mixed_prim_types(): + """Test world-to-local conversion with mixed prim types (Xform, Scope, Mesh).""" + # obtain stage handle + stage = sim_utils.get_current_stage() + + # Create a hierarchy with different prim types + # Grandparent: Xform with transform + sim_utils.create_prim( + "/World/Grandparent", + "Xform", + translation=(5.0, 3.0, 2.0), + orientation=(0.7071068, 0.0, 0.0, 0.7071068), # 90 deg around Z + scale=(2.0, 2.0, 2.0), + stage=stage, + ) + + # Parent: Scope prim (organizational, typically has no transform) + parent = stage.DefinePrim("/World/Grandparent/Parent", "Scope") + + # Obtain parent prim pose (should be grandparent's transform) + parent_pos, parent_quat = sim_utils.resolve_prim_pose(parent) + assert_vec3_close(Gf.Vec3d(*parent_pos), (5.0, 3.0, 2.0), eps=1e-5) + assert_quat_close(Gf.Quatd(*parent_quat), (0.7071068, 0.0, 0.0, 0.7071068), eps=1e-5) + + # Child: Mesh prim (geometry) + child = sim_utils.create_prim("/World/Grandparent/Parent/Child", "Mesh", stage=stage) + + # World pose we want to achieve for the child + world_position = (10.0, 5.0, 3.0) + world_orientation = (1.0, 0.0, 0.0, 0.0) # identity rotation + + # Convert to local space relative to parent (Scope) + local_translation, local_orientation = sim_utils.convert_world_pose_to_local( + world_position, world_orientation, child + ) + + # Verify orientation is not None + assert local_orientation is not None, "Expected orientation to be computed" + + # Set the local transform on the child (Mesh) + xformable = UsdGeom.Xformable(child) + translate_op = xformable.GetTranslateOp() + translate_op.Set(Gf.Vec3d(*local_translation)) + orient_op = xformable.GetOrientOp() + orient_op.Set(Gf.Quatd(*local_orientation)) + + # Verify world pose of child + child_world_pos, child_world_quat = sim_utils.resolve_prim_pose(child) + + # Should match the desired world pose + # Note: Scope prims typically have no transform, so the child's world pose should account + # for the grandparent's transform + assert_vec3_close(Gf.Vec3d(*child_world_pos), world_position, eps=1e-10) + assert_quat_close(Gf.Quatd(*child_world_quat), world_orientation, eps=1e-10)