Skip to content

Commit

Permalink
"Ontop" utility (#1835)
Browse files Browse the repository at this point in the history
* add ontop function and test

* refactor utils to use python case
  • Loading branch information
aclegg3 authored Mar 8, 2024
1 parent f5567b7 commit 8e2d957
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 35 deletions.
140 changes: 105 additions & 35 deletions habitat-lab/habitat/sims/habitat_simulator/sim_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,55 +519,55 @@ def get_global_keypoints_from_bb(


def get_rigid_object_global_keypoints(
objectA: habitat_sim.physics.ManagedRigidObject,
object_a: habitat_sim.physics.ManagedRigidObject,
) -> List[mn.Vector3]:
"""
Get a list of rigid object keypoints in global space.
0th point is the bounding box center, others are bounding box corners.
:param objectA: The ManagedRigidObject from which to extract keypoints.
:param object_a: The ManagedRigidObject from which to extract keypoints.
:return: A set of global 3D keypoints for the object.
"""

bb = objectA.root_scene_node.cumulative_bb
return get_global_keypoints_from_bb(bb, objectA.transformation)
bb = object_a.root_scene_node.cumulative_bb
return get_global_keypoints_from_bb(bb, object_a.transformation)


def get_articulated_object_global_keypoints(
objectA: habitat_sim.physics.ManagedArticulatedObject,
object_a: habitat_sim.physics.ManagedArticulatedObject,
ao_aabbs: Dict[int, mn.Range3D] = None,
) -> List[mn.Vector3]:
"""
Get global bb keypoints for an ArticulatedObject.
:param objectA: The ManagedArticulatedObject from which to extract keypoints.
:param object_a: The ManagedArticulatedObject from which to extract keypoints.
:param ao_aabbs: A pre-computed map from ArticulatedObject object_ids to their local bounding boxes. If not provided, recomputed as necessary. Must contain the subjects of the query.
:return: A set of global 3D keypoints for the object.
"""

ao_bb = None
if ao_aabbs is None:
ao_bb = get_ao_root_bb(objectA)
ao_bb = get_ao_root_bb(object_a)
else:
ao_bb = ao_aabbs[objectA.object_id]
ao_bb = ao_aabbs[object_a.object_id]

return get_global_keypoints_from_bb(ao_bb, objectA.transformation)
return get_global_keypoints_from_bb(ao_bb, object_a.transformation)


def get_articulated_link_global_keypoints(
objectA: habitat_sim.physics.ManagedArticulatedObject, link_index: int
object_a: habitat_sim.physics.ManagedArticulatedObject, link_index: int
) -> List[mn.Vector3]:
"""
Get global bb keypoints for an ArticulatedLink.
:param objectA: The parent ManagedArticulatedObject for the link.
:param object_a: The parent ManagedArticulatedObject for the link.
:param link_index: The local index of the link within the parent ArticulatedObject. Not the object_id of the link.
:return: A set of global 3D keypoints for the link.
"""
link_node = objectA.get_link_scene_node(link_index)
link_node = object_a.get_link_scene_node(link_index)

return get_global_keypoints_from_bb(
link_node.cumulative_bb, link_node.absolute_transformation()
Expand Down Expand Up @@ -608,14 +608,14 @@ def get_global_keypoints_from_object_id(

def object_keypoint_cast(
sim: habitat_sim.Simulator,
objectA: habitat_sim.physics.ManagedRigidObject,
object_a: habitat_sim.physics.ManagedRigidObject,
direction: mn.Vector3 = None,
) -> List[habitat_sim.physics.RaycastResults]:
"""
Computes object global keypoints, casts rays from each in the specified direction and returns the resulting RaycastResults.
:param sim: The Simulator instance.
:param objectA: The ManagedRigidObject from which to extract keypoints and raycast.
:param object_a: The ManagedRigidObject from which to extract keypoints and raycast.
:param direction: Optionally provide a unit length global direction vector for the raycast. If None, default to -Y.
:return: A list of RaycastResults, one from each object keypoint.
Expand All @@ -625,7 +625,7 @@ def object_keypoint_cast(
# default to downward raycast
direction = mn.Vector3(0, -1, 0)

global_keypoints = get_rigid_object_global_keypoints(objectA)
global_keypoints = get_rigid_object_global_keypoints(object_a)
return [
sim.cast_ray(habitat_sim.geo.Ray(keypoint, direction))
for keypoint in global_keypoints
Expand All @@ -639,39 +639,39 @@ def object_keypoint_cast(

def above(
sim: habitat_sim.Simulator,
objectA: Union[
object_a: Union[
habitat_sim.physics.ManagedRigidObject,
habitat_sim.physics.ManagedArticulatedObject,
],
) -> List[int]:
"""
Get a list of all objects that a particular objectA is 'above'.
Get a list of all objects that a particular object_a is 'above'.
Concretely, 'above' is defined as: a downward raycast of any object keypoint hits the object below.
:param sim: The Simulator instance.
:param objectA: The ManagedRigidObject for which to query the 'above' set.
:param object_a: The ManagedRigidObject for which to query the 'above' set.
:return: a list of object ids.
"""

# get object ids of all objects below this one
above_object_ids = [
hit.object_id
for keypoint_raycast_result in object_keypoint_cast(sim, objectA)
for keypoint_raycast_result in object_keypoint_cast(sim, object_a)
for hit in keypoint_raycast_result.hits
]
above_object_ids = list(set(above_object_ids))

# remove self from the list if present
if objectA.object_id in above_object_ids:
above_object_ids.remove(objectA.object_id)
if object_a.object_id in above_object_ids:
above_object_ids.remove(object_a.object_id)

return above_object_ids


def within(
sim: habitat_sim.Simulator,
objectA: Union[
object_a: Union[
habitat_sim.physics.ManagedRigidObject,
habitat_sim.physics.ManagedArticulatedObject,
],
Expand All @@ -680,20 +680,20 @@ def within(
center_ensures_containment: bool = True,
) -> List[int]:
"""
Get a list of all objects that a particular objectA is 'within'.
Get a list of all objects that a particular object_a is 'within'.
Concretely, 'within' is defined as: a threshold number of opposing keypoint raycasts hit the same object.
This function computes raycasts along all global axes from all keypoints and checks opposing rays for collision with the same object.
:param sim: The Simulator instance.
:param objectA: The ManagedRigidObject for which to query the 'within' set.
:param object_a: The ManagedRigidObject for which to query the 'within' set.
:param max_distance: The maximum ray distance to check in each opposing direction (this is half the "wingspan" of the check). Makes the raycast more efficienct and realistically containing objects will have a limited size.
:param keypoint_vote_threshold: The minimum number of keypoints which must indicate containment to qualify objectA as "within" another object.
:param center_ensures_containment: If True, positive test of objectA's center keypoint alone qualifies objectA as "within" another object.
:param keypoint_vote_threshold: The minimum number of keypoints which must indicate containment to qualify object_a as "within" another object.
:param center_ensures_containment: If True, positive test of object_a's center keypoint alone qualifies object_a as "within" another object.
:return: a list of object_id integers.
"""

global_keypoints = get_rigid_object_global_keypoints(objectA)
global_keypoints = get_rigid_object_global_keypoints(object_a)

# build axes vectors
pos_axes = [mn.Vector3.x_axis(), mn.Vector3.y_axis(), mn.Vector3.z_axis()]
Expand Down Expand Up @@ -748,15 +748,85 @@ def within(
containment_ids = list(set(containment_ids))

# remove self from the list if present
if objectA.object_id in containment_ids:
containment_ids.remove(objectA.object_id)
if object_a.object_id in containment_ids:
containment_ids.remove(object_a.object_id)

return containment_ids


def ontop(
sim: habitat_sim.Simulator,
object_a: Union[
habitat_sim.physics.ManagedRigidObject,
habitat_sim.physics.ManagedArticulatedObject,
int,
],
do_collision_detection: bool,
vertical_normal_error_threshold: float = 0.75,
) -> List[int]:
"""
Get a list of all object ids or objects that are "ontop" of a particular object_a.
Concretely, 'ontop' is defined as: contact points between object_a and objectB have vertical normals "upward" relative to object_a.
This function uses collision points to determine which objects are resting on or contacting the surface of object_a.
:param sim: The Simulator instance.
:param object_a: The ManagedRigidObject or object id for which to query the 'ontop' set.
:param do_collision_detection: If True, a fresh discrete collision detection is run before the contact point query. Pass False to skip if a recent sim step or pre-process has run a collision detection pass on the current state.
:param vertical_normal_error_threshold: The allowed error in normal alignment for a contact point to be considered "vertical" for this check. Functionally, if dot(contact normal, Y) <= threshold, the contact is ignored.
:return: a list of integer object_ids for the set of objects "ontop" of object_a.
"""

link_id = None
if isinstance(object_a, int):
subject_object = get_obj_from_id(sim, object_a)
if subject_object is None:
raise AssertionError(
f"The passed object_id {object_a} is invalid."
)
if subject_object.object_id != object_a:
# object_a is a link
link_id = subject_object.link_object_ids[object_a]
object_a = subject_object

if do_collision_detection:
sim.perform_discrete_collision_detection()

yup = mn.Vector3(0.0, 1.0, 0.0)

ontop_object_ids = []
for cp in sim.get_physics_contact_points():
contacting_obj_id = None
obj_is_b = False
if cp.object_id_a == object_a.object_id and (
link_id is None or link_id == cp.link_id_a
):
contacting_obj_id = cp.object_id_b
elif cp.object_id_b == object_a.object_id and (
link_id is None or link_id == cp.link_id_b
):
contacting_obj_id = cp.object_id_a
obj_is_b = True
if contacting_obj_id is not None:
contact_normal = (
cp.contact_normal_on_b_in_ws
if obj_is_b
else -cp.contact_normal_on_b_in_ws
)
if (
mn.math.dot(contact_normal, yup)
> vertical_normal_error_threshold
):
ontop_object_ids.append(contacting_obj_id)

ontop_object_ids = list(set(ontop_object_ids))

return ontop_object_ids


def object_in_region(
sim: habitat_sim.Simulator,
objectA: Union[
object_a: Union[
habitat_sim.physics.ManagedRigidObject,
habitat_sim.physics.ManagedArticulatedObject,
],
Expand All @@ -770,7 +840,7 @@ def object_in_region(
Check if an object is within a region by checking region containment of keypoints.
:param sim: The Simulator instance.
:param objectA: The object instance.
:param object_a: The object instance.
:param region: The SemanticRegion to check.
:param containment_threshold: threshold ratio of keypoints which need to be in a region to count as containment.
:param center_only: If True, only use the BB center keypoint, all or nothing.
Expand All @@ -783,7 +853,7 @@ def object_in_region(

key_points = get_global_keypoints_from_object_id(
sim,
object_id=objectA.object_id,
object_id=object_a.object_id,
ao_link_map=ao_link_map,
ao_aabbs=ao_aabbs,
)
Expand All @@ -799,7 +869,7 @@ def object_in_region(

def get_object_regions(
sim: habitat_sim.Simulator,
objectA: Union[
object_a: Union[
habitat_sim.physics.ManagedRigidObject,
habitat_sim.physics.ManagedArticulatedObject,
],
Expand All @@ -810,7 +880,7 @@ def get_object_regions(
Get a sorted list of regions containing an object using bounding box keypoints.
:param sim: The Simulator instance.
:param objectA: The object instance.
:param object_a: The object instance.
:param ao_link_map: A pre-computed map from link object ids to their parent ArticulatedObject's object id.
:param ao_aabbs: A pre-computed map from ArticulatedObject object_ids to their local bounding boxes. If not provided, recomputed as necessary.
Expand All @@ -819,7 +889,7 @@ def get_object_regions(

key_points = get_global_keypoints_from_object_id(
sim,
object_id=objectA.object_id,
object_id=object_a.object_id,
ao_link_map=ao_link_map,
ao_aabbs=ao_aabbs,
)
Expand Down
Loading

0 comments on commit 8e2d957

Please sign in to comment.