Skip to content

Commit

Permalink
Improving PDDL Docs (facebookresearch#1784)
Browse files Browse the repository at this point in the history
* Added better PDDL docs

* Fixed is_true evaluation

* change type of target for _is_obj_state_true

* avoid circular import

---------

Co-authored-by: aclegg3 <[email protected]>
  • Loading branch information
ASzot and aclegg3 authored Feb 6, 2024
1 parent 5d2a6be commit 73c78e4
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class LogicalQuantifierType(Enum):


class LogicalExpr:
"""
Refers to combinations of PDDL expressions or subexpressions.
"""

def __init__(
self,
expr_type: LogicalExprType,
Expand Down
69 changes: 53 additions & 16 deletions habitat-lab/habitat/tasks/rearrange/multi_task/pddl_sim_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@
rearrange_logger,
)

# TODO: Deprecate these and instead represent them as articulated object entity type.
CAB_TYPE = "cab_type"
FRIDGE_TYPE = "fridge_type"


class ArtSampler:
"""
Desired simulator state for a articulated object. Expresses a range of
allowable joint values.
"""

def __init__(
self, value: float, cmp: str, override_thresh: Optional[float] = None
):
Expand All @@ -46,7 +52,7 @@ def is_satisfied(self, cur_value: float, thresh: float) -> bool:
elif self.cmp == "close":
return abs(cur_value - self.value) < thresh
else:
raise ValueError(f"Unrecognized cmp {self.cmp}")
raise ValueError(f"Unrecognized comparison {self.cmp}")

def sample(self) -> float:
return self.value
Expand Down Expand Up @@ -224,6 +230,10 @@ def set_state(


class PddlSimState:
"""
The "building block" for predicates. This checks if a particular simulator state is satisfied.
"""

def __init__(
self,
art_states: Dict[PddlEntity, ArtSampler],
Expand Down Expand Up @@ -336,23 +346,27 @@ def is_true(
Throws exception if the arguments are not compatible.
"""

# Check object states.
for entity, target in self._obj_states.items():
return all(
_is_obj_state_true(entity, target, sim_info)
for entity, target in self._obj_states.items()
)
# Check object states are true.
if not all(
_is_obj_state_true(entity, target, sim_info)
for entity, target in self._obj_states.items()
):
return False

for art_entity, set_art in self._art_states.items():
return all(
_is_art_state_true(art_entity, set_art, sim_info)
for art_entity, set_art in self._art_states.items()
)
# Check articulated object states are true.
if not all(
_is_art_state_true(art_entity, set_art, sim_info)
for art_entity, set_art in self._art_states.items()
):
return False

return all(
# Check robot states are true.
if not all(
robot_state.is_true(sim_info, robot_entity)
for robot_entity, robot_state in self._robot_states.items()
)
):
return False
return True

def set_state(self, sim_info: PddlSimInfo) -> None:
"""
Expand Down Expand Up @@ -437,7 +451,9 @@ def _is_object_inside(
return global_bb.contains(entity_pos)


def _is_obj_state_true(entity, target, sim_info) -> bool:
def _is_obj_state_true(
entity: PddlEntity, target: PddlEntity, sim_info: PddlSimInfo
) -> bool:
entity_pos = sim_info.get_entity_pos(entity)

if sim_info.check_type_matches(
Expand Down Expand Up @@ -476,7 +492,14 @@ def _is_obj_state_true(entity, target, sim_info) -> bool:
return True


def _is_art_state_true(art_entity, set_art, sim_info) -> bool:
def _is_art_state_true(
art_entity: PddlEntity, set_art: ArtSampler, sim_info: PddlSimInfo
) -> bool:
"""
Checks if an articulated object entity matches a condition specified by
`set_art`.
"""

if not sim_info.check_type_matches(
art_entity,
SimulatorObjectType.ARTICULATED_RECEPTACLE_ENTITY.value,
Expand All @@ -498,6 +521,10 @@ def _is_art_state_true(art_entity, set_art, sim_info) -> bool:
def _place_obj_on_goal(
target: PddlEntity, sim_info: PddlSimInfo
) -> mn.Matrix4:
"""
Place an object at a goal position.
"""

sim = sim_info.sim
targ_idx = cast(
int,
Expand All @@ -511,6 +538,10 @@ def _place_obj_on_goal(
def _place_obj_on_obj(
entity: PddlEntity, target: PddlEntity, sim_info: PddlSimInfo
) -> mn.Matrix4:
"""
This is intended to implement placing an object on top of another object.
"""

raise NotImplementedError()


Expand All @@ -530,6 +561,12 @@ def _place_obj_on_recep(target: PddlEntity, sim_info) -> mn.Matrix4:
def _set_obj_state(
entity: PddlEntity, target: PddlEntity, sim_info: PddlSimInfo
) -> None:
"""
Sets an object state to match the state specified by `target`. The context
of this will vary on the type of the source and target entity (like if we
are placing an object on a receptacle).
"""

sim = sim_info.sim

# The source object must be movable.
Expand Down
62 changes: 59 additions & 3 deletions habitat-lab/habitat/tasks/rearrange/multi_task/rearrange_pddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import magnum as mn
import numpy as np
Expand All @@ -17,8 +17,16 @@
from habitat.tasks.rearrange.rearrange_sim import RearrangeSim
from habitat.tasks.rearrange.rearrange_task import RearrangeTask

# Trick to avoid circular import
if TYPE_CHECKING:
from habitat.tasks.rearrange.multi_task.pddl_predicate import Predicate


class SimulatorObjectType(Enum):
"""
Predefined entity types for which default predicate behavior is defined.
"""

MOVABLE_ENTITY = "movable_entity_type"
STATIC_RECEPTACLE_ENTITY = "static_receptacle_entity_type"
ARTICULATED_RECEPTACLE_ENTITY = "art_receptacle_entity_type"
Expand Down Expand Up @@ -46,6 +54,11 @@ def parse_func(x: str) -> Tuple[str, List[str]]:


class ExprType:
"""
Hierarchical type in the PDDL system. The user can define custom types and
the types in `SimulatorObjectType` are automatically defined.
"""

def __init__(self, name: str, parent: Optional["ExprType"]):
assert isinstance(name, str)
assert parent is None or isinstance(parent, ExprType)
Expand All @@ -72,6 +85,10 @@ def __repr__(self):

@dataclass(frozen=True)
class PddlEntity:
"""
Abstract PDDL entity. This is linked to simulator via `PddlSimInfo`.
"""

name: str
expr_type: ExprType

Expand Down Expand Up @@ -124,6 +141,27 @@ def ensure_entity_lists_match(

@dataclass
class PddlSimInfo:
"""
Manages the mapping between the abstract PDDL and the underlying simulator
entities. This also provides some helper methods for accessing PDDL entity
simulator properties like object position (which could vary per entity type).
:property obj_ids: Mapping from the habitat instance handle to simulator ID.
:property target_ids: Mapping from target instance handle to simulator ID.
ONLY relevant for geometric goal. In the future we can probably remove this
distinction.
:property art_handles: The simulator articulated object asset handles.
:property obj_thresh: Setting that configures an object threshold in
predicate state evaluation.
:property receptacles: Simulator receptacle regions. See `receptacles` in `RearrangeSim`.
:property filter_colliding_states: Setting used for placing robot predicate state.
:property num_spawn_attempts: Setting used for placing robot predicate state.
:property pred_truth_cache: Used by the task to avoid evaluating the same
predicate multiple times.
"""

obj_ids: Dict[str, int]
target_ids: Dict[str, int]
art_handles: Dict[str, int]
Expand All @@ -139,7 +177,7 @@ class PddlSimInfo:
robot_at_thresh: float
expr_types: Dict[str, ExprType]
predicates: Dict[str, Any]
all_entities: Dict[str, Any]
all_entities: Dict[str, PddlEntity]
receptacles: Dict[str, mn.Range3D]

num_spawn_attempts: int
Expand All @@ -149,15 +187,28 @@ class PddlSimInfo:
pred_truth_cache: Optional[Dict[str, bool]] = None

def reset_pred_truth_cache(self):
"""
Task that uses the `pred_truth_cache` is responsible for calling
this.
"""

self.pred_truth_cache = {}

def get_predicate(self, pred_name: str):
def get_predicate(self, pred_name: str) -> "Predicate":
"""
Look up predicate by name.
"""

return self.predicates[pred_name]

def check_type_matches(self, entity: PddlEntity, match_name: str) -> bool:
return entity.expr_type.is_subtype_of(self.expr_types[match_name])

def get_entity_pos(self, entity: PddlEntity) -> np.ndarray:
"""
Gets a simulator 3D point for an entity.
"""

ename = entity.name
if self.check_type_matches(
entity, SimulatorObjectType.ROBOT_ENTITY.value
Expand Down Expand Up @@ -196,6 +247,11 @@ def get_entity_pos(self, entity: PddlEntity) -> np.ndarray:
def search_for_entity(
self, entity: PddlEntity
) -> Union[int, str, MarkerInfo, mn.Range3D]:
"""
Returns underlying simulator information associated with a PDDL entity.
Helper to match the PDDL entity to something from the simulator.
"""

ename = entity.name

if self.check_type_matches(
Expand Down

0 comments on commit 73c78e4

Please sign in to comment.