diff --git a/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_ViewMotion.py b/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_ViewMotion.py index ab7e48357..d5d8376a1 100644 --- a/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_ViewMotion.py +++ b/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_ViewMotion.py @@ -8,7 +8,8 @@ import isaacgym import argparse -from rofunc.config.utils import omegaconf_to_dict, get_config, load_view_motion_config +import rofunc as rf +from rofunc.config.utils import omegaconf_to_dict, get_config from rofunc.learning.RofuncRL.tasks import Tasks from rofunc.learning.RofuncRL.trainers import Trainers @@ -63,7 +64,7 @@ def inference(custom_args): # Available types of motion file path: # 1. test data provided by rofunc: `examples/data/amp/*.npy` # 2. custom motion file with absolute path - parser.add_argument("--motion_file", type=str, default="/home/ubuntu/Github/HOTU/hotu/data/hotu/010_amp.npy") + parser.add_argument("--motion_file", type=str, default=rf.oslab.get_rofunc_path('../examples/data/amp/amp_humanoid_backflip.npy')) custom_args = parser.parse_args() inference(custom_args) diff --git a/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidHOTU_RofuncRL.py b/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidHOTU_RofuncRL.py index 9abecd33e..ff58ce6b3 100644 --- a/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidHOTU_RofuncRL.py +++ b/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidHOTU_RofuncRL.py @@ -136,16 +136,16 @@ def inference(custom_args): parser.add_argument("--debug", type=str, default="False") parser.add_argument("--headless", type=str, default="True") parser.add_argument("--inference", action="store_false", help="turn to inference mode while adding this argument") - parser.add_argument("--ckpt_path", type=str, default="/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/learning_rl/IsaacGym_RofuncRL/saved_runs/RofuncRL_HOTUTrainer_HumanoidHOTUGetup_HOTUBruce_24-05-28_13-51-39-584325_body_amp5/checkpoints/best_ckpt.pth") + parser.add_argument("--ckpt_path", type=str, default="../examples/learning_rl/IsaacGym_RofuncRL/saved_runs/RofuncRL_HOTUTrainer_HumanoidHOTUGetup_HOTUBruce_24-05-28_13-51-39-584325_body_amp5/checkpoints/best_ckpt.pth") # HOTU - # parser.add_argument("--llc_ckpt_path", type=str, default="/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/learning_rl/IsaacGym_RofuncRL/saved_runs/RofuncRL_HOTUTrainer_HumanoidHOTUGetup_HOTUHumanoidWQbhandNew_24-05-26_21-16-24-361269_body_amp5/checkpoints/best_ckpt.pth") + # parser.add_argument("--llc_ckpt_path", type=str, default="../examples/learning_rl/IsaacGym_RofuncRL/saved_runs/RofuncRL_HOTUTrainer_HumanoidHOTUGetup_HOTUHumanoidWQbhandNew_24-05-26_21-16-24-361269_body_amp5/checkpoints/best_ckpt.pth") # ZJU - # parser.add_argument("--llc_ckpt_path", type=str, default="/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/learning_rl/IsaacGym_RofuncRL/saved_runs/RofuncRL_HOTUTrainer_HumanoidHOTUGetup_HOTUZJUHumanoidWQbhandNew_24-05-26_18-57-20-244370_body_amp5/checkpoints/best_ckpt.pth") + # parser.add_argument("--llc_ckpt_path", type=str, default="../examples/learning_rl/IsaacGym_RofuncRL/saved_runs/RofuncRL_HOTUTrainer_HumanoidHOTUGetup_HOTUZJUHumanoidWQbhandNew_24-05-26_18-57-20-244370_body_amp5/checkpoints/best_ckpt.pth") # H1 - # parser.add_argument("--llc_ckpt_path", type=str, default="/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/learning_rl/IsaacGym_RofuncRL/saved_runs/RofuncRL_HOTUTrainer_HumanoidHOTUGetup_HOTUH1WQbhandNew_24-05-27_16-59-15-598225_body_amp5/checkpoints/best_ckpt.pth") + # parser.add_argument("--llc_ckpt_path", type=str, default="../examples/learning_rl/IsaacGym_RofuncRL/saved_runs/RofuncRL_HOTUTrainer_HumanoidHOTUGetup_HOTUH1WQbhandNew_24-05-27_16-59-15-598225_body_amp5/checkpoints/best_ckpt.pth") # Bruce - parser.add_argument("--llc_ckpt_path", type=str, default="/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/learning_rl/IsaacGym_RofuncRL/saved_runs/RofuncRL_HOTUTrainer_HumanoidHOTUGetup_HOTUBruce_24-05-28_13-51-39-584325_body_amp5/checkpoints/best_ckpt.pth") + parser.add_argument("--llc_ckpt_path", type=str, default="../examples/learning_rl/IsaacGym_RofuncRL/saved_runs/RofuncRL_HOTUTrainer_HumanoidHOTUGetup_HOTUBruce_24-05-28_13-51-39-584325_body_amp5/checkpoints/best_ckpt.pth") custom_args = parser.parse_args() diff --git a/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidPhysHOI_RofuncRL.py b/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidPhysHOI_RofuncRL.py index 2eb5873c8..96c8d153f 100644 --- a/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidPhysHOI_RofuncRL.py +++ b/examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidPhysHOI_RofuncRL.py @@ -103,7 +103,7 @@ def inference(custom_args): parser.add_argument("--rl_device", type=int, default=1) parser.add_argument("--headless", type=str, default="True") parser.add_argument("--inference", action="store_true", help="turn to inference mode while adding this argument") - parser.add_argument("--ckpt_path", type=str, default="/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/learning_rl/IsaacGym_RofuncRL/runs/RofuncRL_PhysHOITrainer_HumanoidPhysHOI_24-04-23_18-21-03-579079/checkpoints/best_ckpt.pth") + parser.add_argument("--ckpt_path", type=str, default="../examples/learning_rl/IsaacGym_RofuncRL/runs/RofuncRL_PhysHOITrainer_HumanoidPhysHOI_24-04-23_18-21-03-579079/checkpoints/best_ckpt.pth") parser.add_argument("--debug", type=str, default="False") custom_args = parser.parse_args() diff --git a/examples/learning_rl/IsaacLab_RofuncRL/example_isaaclab_env.py b/examples/learning_rl/IsaacLab_RofuncRL/example_isaaclab_env.py new file mode 100644 index 000000000..e93e468a6 --- /dev/null +++ b/examples/learning_rl/IsaacLab_RofuncRL/example_isaaclab_env.py @@ -0,0 +1,105 @@ +""" +Ant (RofuncRL) +=========================== + +Ant RL using RofuncRL +""" + +import sys,os + +sys.path.append("/home/ubuntu/Github/Rofunc") + +import argparse + +from rofunc.config.utils import omegaconf_to_dict, get_config +from rofunc.learning.RofuncRL.tasks import Tasks +from rofunc.learning.RofuncRL.trainers import Trainers +from rofunc.learning.pre_trained_models.download import model_zoo +from rofunc.learning.utils.utils import set_seed +from rofunc.learning.RofuncRL.tasks.utils.env_loaders import load_isaaclab_env +from rofunc.learning.utils.env_wrappers import wrap_env + + +def train(custom_args): + # Config task and trainer parameters for Isaac Gym environments + custom_args.num_envs = 64 if custom_args.agent.upper() in ["SAC", "TD3"] else custom_args.num_envs + + args_overrides = ["task={}".format(custom_args.task), + "train={}{}RofuncRL".format(custom_args.task, custom_args.agent.upper()), + "device_id={}".format(custom_args.sim_device), + "rl_device=cuda:{}".format(custom_args.rl_device), + "headless={}".format(custom_args.headless), + "num_envs={}".format(custom_args.num_envs)] + cfg = get_config('./learning/rl', 'config', args=args_overrides) + cfg_dict = omegaconf_to_dict(cfg.task) + + set_seed(cfg.train.Trainer.seed) + + # Instantiate the Isaac Gym environment + env = load_isaaclab_env(task_name="Isaac-Cartpole-v0", headless=True, num_envs=custom_args.num_envs) + + # Instantiate the RL trainer + trainer = Trainers().trainer_map[custom_args.agent](cfg=cfg, + env=env, + device=cfg.rl_device, + env_name=custom_args.task) + + # Start training + trainer.train() + + +def inference(custom_args): + # Config task and trainer parameters for Isaac Gym environments + args_overrides = ["task={}".format(custom_args.task), + "train={}{}RofuncRL".format(custom_args.task, custom_args.agent.upper()), + "device_id={}".format(custom_args.sim_device), + "rl_device=cuda:{}".format(custom_args.rl_device), + "headless={}".format(False), + "num_envs={}".format(16)] + cfg = get_config('./learning/rl', 'config', args=args_overrides) + cfg_dict = omegaconf_to_dict(cfg.task) + + set_seed(cfg.train.Trainer.seed) + + # Instantiate the Isaac Gym environment + infer_env = Tasks().task_map[custom_args.task](cfg=cfg_dict, + rl_device=cfg.rl_device, + sim_device=f'cuda:{cfg.device_id}', + graphics_device_id=cfg.device_id, + headless=cfg.headless, + virtual_screen_capture=cfg.capture_video, # TODO: check + force_render=cfg.force_render) + + # Instantiate the RL trainer + trainer = Trainers().trainer_map[custom_args.agent](cfg=cfg, + env=infer_env, + device=cfg.rl_device, + env_name=custom_args.task, + inference=True) + # load checkpoint + if custom_args.ckpt_path is None: + custom_args.ckpt_path = model_zoo(name="AntRofuncRLPPO.pth") + trainer.agent.load_ckpt(custom_args.ckpt_path) + + # Start inference + trainer.inference() + + +if __name__ == '__main__': + gpu_id = 0 + + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="Cartpole") + parser.add_argument("--agent", type=str, default="ppo") # Available agents: ppo, sac, td3, a2c + parser.add_argument("--num_envs", type=int, default=4096) + parser.add_argument("--sim_device", type=int, default=0) + parser.add_argument("--rl_device", type=int, default=gpu_id) + parser.add_argument("--headless", type=str, default="True") + parser.add_argument("--inference", action="store_true", help="turn to inference mode while adding this argument") + parser.add_argument("--ckpt_path", type=str, default=None) + custom_args = parser.parse_args() + + if not custom_args.inference: + train(custom_args) + else: + inference(custom_args) diff --git a/examples/robolab/example_coordinate_transform.py b/examples/robolab/example_coordinate_transform.py index 9d141c2cc..a5ab2091b 100644 --- a/examples/robolab/example_coordinate_transform.py +++ b/examples/robolab/example_coordinate_transform.py @@ -7,6 +7,7 @@ import rofunc as rf +# Quaternion convert quat = [0.234, 0.23, 0.4, 1.3] mat = rf.robolab.convert_ori_format(quat, "quat", "mat") @@ -28,7 +29,7 @@ # [-0.2098, 0.4048, 0.8900, 1.0000], # [ 0.0000, 0.0000, 0.0000, 1.0000]]]) - +# Rotation matrix convert rot = [[0.7825, -0.4763, 0.4011], [0.5862, 0.7806, -0.2168], [-0.2098, 0.4048, 0.8900]] @@ -44,7 +45,7 @@ # [Rofunc:INFO] Quaternion: tensor([[0.1673, 0.1644, 0.2859, 0.9291]]) # [Rofunc:INFO] Euler angles: tensor([[0.4269, 0.2114, 0.6429]]) - +# Euler convert euler = [0.4268, 0.2114, 0.6430] quat = rf.robolab.convert_ori_format(euler, "euler", "quat") mat = rf.robolab.convert_ori_format(euler, "euler", "mat") @@ -57,3 +58,11 @@ # tensor([[[ 0.7825, -0.4763, 0.4011], # [ 0.5863, 0.7806, -0.2168], # [-0.2098, 0.4047, 0.8900]]]) + +# Quaternion multiplication +quat1 = [[-0.436865, 0.49775, 0.054428, 0.747283], [0, 0, 1, 0]] +quat2 = [0.707, 0, 0, 0.707] +quat3 = rf.robolab.quat_multiply(quat1, quat2) +rf.logger.beauty_print(f"Result: {rf.robolab.check_quat_tensor(quat3)}") +# [Rofunc:INFO] Result: tensor([[ 0.2195, 0.3904, -0.3135, 0.8373], +# [ 0.0000, 0.7071, 0.7071, 0.0000]]) diff --git a/examples/robolab/example_forward_dynamics.py b/examples/robolab/example_forward_dynamics.py new file mode 100644 index 000000000..8baef5cc7 --- /dev/null +++ b/examples/robolab/example_forward_dynamics.py @@ -0,0 +1,72 @@ +""" +FD from models +======================== + +Forward dynamics from URDF or MuJoCo XML files. +""" + +import pprint +import math + +import rofunc as rf + +rf.logger.beauty_print("########## Forward kinematics from URDF or MuJoCo XML files with RobotModel class ##########") +rf.logger.beauty_print("---------- Forward kinematics for Franka Panda using URDF file ----------") +model_path = "../../rofunc/simulator/assets/urdf/franka_description/robots/franka_panda.urdf" + +joint_value = [[0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0, 0.0, 0.0], + [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0, 0.0, 0.0]] +export_link = "panda_hand" + +# # Build the robot model with kinpy +# # Deprecated: kinpy is not supported anymore, just for checking the results!!!! Please use pytorch_kinematics instead. +# robot = rf.robolab.RobotModel(model_path, solve_engine="kinpy", verbose=False) +# # Show the robot chain and joint names, can also be done by verbose=True +# robot.show_chain() +# # Get the forward kinematics of export_link +# pos, rot, ret = robot.get_fk(joint_value, export_link) +# +# # Convert the orientation representation and print the results +# rot = rf.robolab.convert_quat_order(rot, "wxyz", "xyzw") +# rf.logger.beauty_print(f"Position of {export_link}: {pos}") +# rf.logger.beauty_print(f"Rotation of {export_link}: {rot}") +# pprint.pprint(ret, width=1) + +# Try the same thing with pytorch_kinematics +robot = rf.robolab.RobotModel(model_path, solve_engine="pytorch_kinematics", verbose=False) +pos, rot, ret = robot.get_fk(joint_value, export_link) +rf.logger.beauty_print(f"Position of {export_link}: {pos}") +rf.logger.beauty_print(f"Rotation of {export_link}: {rot}") +pprint.pprint(ret, width=1) + +# rf.logger.beauty_print("---------- Forward kinematics for Bruce Humanoid Robot using MJCF file ----------") +model_path = "../../rofunc/simulator/assets/mjcf/bruce/bruce.xml" +joint_value = [0.0 for _ in range(16)] + +export_link = "elbow_pitch_link_r" + +# # Build the robot model with pytorch_kinematics, kinpy is not supported for MJCF files +robot = rf.robolab.RobotModel(model_path, solve_engine="pytorch_kinematics", verbose=True) +# Get the forward kinematics of export_link +pos, rot, ret = robot.get_fk(joint_value, export_link) +# +# # Print the results +# rf.logger.beauty_print(f"Position of {export_link}: {pos}") +# rf.logger.beauty_print(f"Rotation of {export_link}: {rot}") +# pprint.pprint(ret, width=1) + + +model_path = "../../rofunc/simulator/assets/mjcf/hotu/hotu_humanoid.xml" +joint_value = [0.1 for _ in range(34)] + +export_link = "left_hand_link_2" + +# # Build the robot model with pytorch_kinematics, kinpy is not supported for MJCF files +robot = rf.robolab.RobotModel(model_path, solve_engine="pytorch_kinematics", verbose=True) +# Get the forward kinematics of export_link +pos, rot, ret = robot.get_fk(joint_value, export_link) + +# # Print the results +rf.logger.beauty_print(f"Position of {export_link}: {pos}") +rf.logger.beauty_print(f"Rotation of {export_link}: {rot}") +pprint.pprint(ret, width=1) diff --git a/examples/robolab/example_forward_kinematics.py b/examples/robolab/example_forward_kinematics.py index 7cbadefd7..a788ebf30 100644 --- a/examples/robolab/example_forward_kinematics.py +++ b/examples/robolab/example_forward_kinematics.py @@ -11,36 +11,36 @@ import rofunc as rf -# rf.logger.beauty_print("########## Forward kinematics from URDF or MuJoCo XML files with RobotModel class ##########") -# rf.logger.beauty_print("---------- Forward kinematics for Franka Panda using URDF file ----------") -# model_path = "../../rofunc/simulator/assets/urdf/franka_description/robots/franka_panda.urdf" -# -# joint_value = [[0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0, 0.0, 0.0], -# [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0, 0.0, 0.0]] -# export_link = "panda_hand" -# -# # # Build the robot model with kinpy -# # # Deprecated: kinpy is not supported anymore, just for checking the results!!!! Please use pytorch_kinematics instead. -# # robot = rf.robolab.RobotModel(model_path, solve_engine="kinpy", verbose=False) -# # # Show the robot chain and joint names, can also be done by verbose=True -# # robot.show_chain() -# # # Get the forward kinematics of export_link -# # pos, rot, ret = robot.get_fk(joint_value, export_link) -# # -# # # Convert the orientation representation and print the results -# # rot = rf.robolab.convert_quat_order(rot, "wxyz", "xyzw") -# # rf.logger.beauty_print(f"Position of {export_link}: {pos}") -# # rf.logger.beauty_print(f"Rotation of {export_link}: {rot}") -# # pprint.pprint(ret, width=1) -# -# # Try the same thing with pytorch_kinematics -# robot = rf.robolab.RobotModel(model_path, solve_engine="pytorch_kinematics", verbose=False) +rf.logger.beauty_print("########## Forward kinematics from URDF or MuJoCo XML files with RobotModel class ##########") +rf.logger.beauty_print("---------- Forward kinematics for Franka Panda using URDF file ----------") +model_path = "../../rofunc/simulator/assets/urdf/franka_description/robots/franka_panda.urdf" + +joint_value = [[0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0, 0.0, 0.0], + [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0, 0.0, 0.0]] +export_link = "panda_hand" + +# # Build the robot model with kinpy +# # Deprecated: kinpy is not supported anymore, just for checking the results!!!! Please use pytorch_kinematics instead. +# robot = rf.robolab.RobotModel(model_path, solve_engine="kinpy", verbose=False) +# # Show the robot chain and joint names, can also be done by verbose=True +# robot.show_chain() +# # Get the forward kinematics of export_link # pos, rot, ret = robot.get_fk(joint_value, export_link) +# +# # Convert the orientation representation and print the results +# rot = rf.robolab.convert_quat_order(rot, "wxyz", "xyzw") # rf.logger.beauty_print(f"Position of {export_link}: {pos}") # rf.logger.beauty_print(f"Rotation of {export_link}: {rot}") # pprint.pprint(ret, width=1) -# -# rf.logger.beauty_print("---------- Forward kinematics for Bruce Humanoid Robot using MJCF file ----------") + +# Try the same thing with pytorch_kinematics +robot = rf.robolab.RobotModel(model_path, solve_engine="pytorch_kinematics", verbose=False) +pos, rot, ret = robot.get_fk(joint_value, export_link) +rf.logger.beauty_print(f"Position of {export_link}: {pos}") +rf.logger.beauty_print(f"Rotation of {export_link}: {rot}") +pprint.pprint(ret, width=1) + +rf.logger.beauty_print("---------- Forward kinematics for Bruce Humanoid Robot using MJCF file ----------") model_path = "../../rofunc/simulator/assets/mjcf/bruce/bruce.xml" joint_value = [0.0 for _ in range(16)] @@ -56,13 +56,11 @@ # rf.logger.beauty_print(f"Rotation of {export_link}: {rot}") # pprint.pprint(ret, width=1) - - - +rf.logger.beauty_print("---------- Forward kinematics for United Digital Human (UDH) using MJCF file ----------") model_path = "../../rofunc/simulator/assets/mjcf/hotu/hotu_humanoid.xml" joint_value = [0.1 for _ in range(34)] -export_link = "left_hand_link_2" +export_link = "left_hand_2" # # Build the robot model with pytorch_kinematics, kinpy is not supported for MJCF files robot = rf.robolab.RobotModel(model_path, solve_engine="pytorch_kinematics", verbose=True) diff --git a/examples/robolab/example_inverse_kinematics.py b/examples/robolab/example_inverse_kinematics.py index dccc895d7..690d80004 100644 --- a/examples/robolab/example_inverse_kinematics.py +++ b/examples/robolab/example_inverse_kinematics.py @@ -5,18 +5,21 @@ Inverse kinematics from URDF or MuJoCo XML files. """ -import pprint - -import math -import torch import rofunc as rf rf.logger.beauty_print("########## Inverse kinematics from URDF or MuJoCo XML files with RobotModel class ##########") rf.logger.beauty_print("---------- Inverse kinematics for Franka Panda using URDF file ----------") -model_path = "/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/rofunc/simulator/assets/urdf/franka_description/robots/franka_panda.urdf" +model_path = "../../rofunc/simulator/assets/urdf/franka_description/robots/franka_panda.urdf" ee_pose = [0, 0, 0, 0, 0, 0, 1] -export_link = "panda_hand_frame" -robot = rf.robolab.RobotModel(model_path, solve_engine="kinpy", verbose=True) +export_link = "panda_hand" +robot = rf.robolab.RobotModel(model_path, solve_engine="pytorch_kinematics", verbose=True) + +# Get ik solution ret = robot.get_ik(ee_pose, export_link) -print(ret) +print(ret.solutions) + +# Get ik solution near the current configuration +cur_configs = [[-1.7613, 2.7469, -3.5611, -3.8847, 2.7940, 1.9055, 1.9879]] +ret = robot.get_ik(ee_pose, export_link, cur_configs=cur_configs) +print(ret.solutions) diff --git a/examples/robolab/example_jacobain.py b/examples/robolab/example_jacobain.py new file mode 100644 index 000000000..8ad595205 --- /dev/null +++ b/examples/robolab/example_jacobain.py @@ -0,0 +1,26 @@ +""" +Jacobian from models +======================== + +Jacobian from URDF or MuJoCo XML files. +""" + +import rofunc as rf + +rf.logger.beauty_print("########## Jacobian from URDF or MuJoCo XML files with RobotModel class ##########") +model_path = "../../rofunc/simulator/assets/mjcf/bruce/bruce.xml" +joint_value = [0.1 for _ in range(16)] + +export_link = "elbow_pitch_link_r" + +# # Build the robot model with pytorch_kinematics, kinpy is not supported for MJCF files +robot = rf.robolab.RobotModel(model_path, solve_engine="pytorch_kinematics", verbose=True) + +# Get the jacobian of export_link +J = robot.get_jacobian(joint_value, export_link) +print(J) + +# Get the jacobian at a point offset from the export_link +point = [0.1, 0.1, 0.1] +J = robot.get_jacobian(joint_value, export_link, locations=point) +print(J) diff --git a/examples/robolab/example_rdf.py b/examples/robolab/example_rdf.py deleted file mode 100644 index 726c1bdf8..000000000 --- a/examples/robolab/example_rdf.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -Robot distance field (RDF) -======================== - -This example demonstrates how to use the RDF class to train a Bernstein Polynomial model for the robot distance field -from URDF/MJCF files and visualize the reconstructed whole body. -""" - -import argparse -import os -import time - -import numpy as np -import torch - -import rofunc as rf - - -def rdf_from_robot_model(args): - rdf_bp = rf.robolab.rdf.RDF(args) - - # train Bernstein Polynomial model - if args.train: - rdf_bp.train() - - # load trained model - rdf_model_path = os.path.join(args.robot_asset_root, 'rdf/BP', f'BP_{args.n_func}.pt') - rdf_model = torch.load(rdf_model_path) - - # visualize the Bernstein Polynomial model for each robot link - rdf_bp.create_surface_mesh(rdf_model, nbData=128, vis=False, save_mesh_name=f'BP_{args.n_func}') - - joint_max = rdf_bp.robot.joint_limit_max - joint_min = rdf_bp.robot.joint_limit_min - num_joint = rdf_bp.robot.num_joint - # joint_value = torch.rand(num_joint).to(args.device) * (joint_max - joint_min) + joint_min - joint_value = torch.zeros(num_joint).to(args.device) - - trans_dict = rdf_bp.robot.get_trans_dict(joint_value) - # visualize the Bernstein Polynomial model for the whole body - rdf_bp.visualize_reconstructed_whole_body(rdf_model, trans_dict, tag=f'BP_{args.n_func}') - - # run RDF - x = torch.rand(10, 3).to(args.device) * 2.0 - 1.0 - joint_value = torch.rand(100, rdf_bp.robot.num_joint).to(args.device).float() - base_trans = torch.from_numpy(np.identity(4)).to(args.device).reshape(-1, 4, 4).expand(len(joint_value), 4, - 4).float().to(args.device) - - start_time = time.time() - sdf, gradient = rdf_bp.get_whole_body_sdf_batch(x, joint_value, rdf_model, base_trans=base_trans, - use_derivative=True) - print('Time cost:', time.time() - start_time) - print('sdf:', sdf.shape, 'gradient:', gradient.shape) - - start_time = time.time() - sdf, joint_grad = rdf_bp.get_whole_body_sdf_with_joints_grad_batch(x, joint_value, rdf_model, base_trans=base_trans) - print('Time cost:', time.time() - start_time) - print('sdf:', sdf.shape, 'joint gradient:', joint_grad.shape) - - # visualize the 2D & 3D SDF with gradient - # joint_value = torch.zeros(num_joint).to(args.device).reshape((-1, num_joint)) - - joint_value = (torch.rand(num_joint).to(args.device).reshape((-1, num_joint))*0.5 * (joint_max - joint_min) + joint_min) - rf.robolab.rdf.plot_2D_panda_sdf(joint_value, rdf_bp, nbData=80, model=rdf_model, device=args.device) - rf.robolab.rdf.plot_3D_panda_with_gradient(joint_value, rdf_bp, model=rdf_model, device=args.device) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--device', default='cuda', type=str) - parser.add_argument('--domain_max', default=1.0, type=float) - parser.add_argument('--domain_min', default=-1.0, type=float) - parser.add_argument('--n_func', default=8, type=int) - parser.add_argument('--train_epoch', default=200, type=int) - parser.add_argument('--train', action='store_true') - parser.add_argument('--save_mesh_dict', action='store_false') - parser.add_argument('--sampled_points', action='store_false') - parser.add_argument('--parallel', action='store_false') - # parser.add_argument('--robot_asset_root', default="../../rofunc/simulator/assets/urdf/alicia", type=str) - # parser.add_argument('--robot_asset_name', default="Alicia_0624.xml", type=str) - # parser.add_argument('--robot_asset_root', default="../../rofunc/simulator/assets/urdf/franka_description", type=str) - # parser.add_argument('--robot_asset_name', default="robots/franka_panda.urdf", type=str) - parser.add_argument('--robot_asset_root', default="../../rofunc/simulator/assets/mjcf/bruce", type=str) - parser.add_argument('--robot_asset_name', default="bruce.xml", type=str) - # parser.add_argument('--robot_asset_root', default="../../rofunc/simulator/assets/mjcf/hotu", type=str) - # parser.add_argument('--robot_asset_name', default="hotu_humanoid.xml", type=str) - parser.add_argument('--rdf_model_path', default=None) - parser.add_argument('--joint_conf_path', default=None) - parser.add_argument('--sampled_points_dir', default=None, type=str) - args = parser.parse_args() - args.rdf_model_path = f"{args.robot_asset_root}/rdf/BP/BP_{args.n_func}.pt" - args.joint_conf_path = f"{args.robot_asset_root}/rdf/BP/joint_conf.pt" - - rdf_from_robot_model(args) diff --git a/examples/robolab/example_rdf_bbo_planning.py b/examples/robolab/example_rdf_bbo_planning.py deleted file mode 100644 index c90e456b4..000000000 --- a/examples/robolab/example_rdf_bbo_planning.py +++ /dev/null @@ -1,110 +0,0 @@ -""" -Bimanual box carrying using Robot distance field (RDF) -======================== - -This example plans the contacts of a bimanual box carrying task via optimization based on RDF. -""" -import argparse - -import numpy as np -import torch -import trimesh - -import rofunc as rf - - -def box_carrying_contact_rdf(args): - box_size = np.array([0.18, 0.1, 0.16]) - box_pos = np.array([0.7934301890820722, 0.0, 0.3646743147850761]) - box_rotation = np.array([[0, 1.57, 0], - [-1.57, 0, 0], - [0, 0, 1]]) - - rdf_model = torch.load(args.rdf_model_path) - bbo_planner = rf.robolab.rdf.BBOPlanner(args, rdf_model, box_size, box_pos, box_rotation) - num_joint = bbo_planner.rdf_bp.robot.num_joint - - # contact points - contact_points = bbo_planner.contact_points - p_l, p_r, n_l, n_r = contact_points[0], contact_points[1], contact_points[2], contact_points[3] - - # initial joint value - joint_max = bbo_planner.rdf_bp.robot.joint_limit_max - joint_min = bbo_planner.rdf_bp.robot.joint_limit_min - mid_l = torch.rand(num_joint).to(args.device) * (joint_max - joint_min) + joint_min - mid_r = torch.rand(num_joint).to(args.device) * (joint_max - joint_min) + joint_min - - # planning for both arm - base_pose_l = torch.from_numpy(np.identity(4)).to(args.device).reshape(-1, 4, 4).float() - base_pose_r = torch.from_numpy(np.identity(4)).to(args.device).reshape(-1, 4, 4).float() - base_pose_l[0] = rf.robolab.homo_matrix_from_quat_tensor([-0.436865, 0.49775, 0.054428, 0.747283], - [0.396519, 0.07, 0.644388])[0].to(args.device) - base_pose_r[0] = rf.robolab.homo_matrix_from_quat_tensor([0.436865, 0.49775, -0.054428, 0.747283], - [0.396519, -0.07, 0.644388])[0].to(args.device) - # base_pose_l[0, :3, 3] = torch.tensor([0.4, 0.3, 0]).to(args.device) - # base_pose_r[0, :3, 3] = torch.tensor([0.4, -0.3, 0]).to(args.device) - - joint_value_left = bbo_planner.optimizer(p_l, n_l, mid_l, base_trans=base_pose_l, batch=64) - joint_value_right = bbo_planner.optimizer(p_r, n_r, mid_r, base_trans=base_pose_r, batch=64) - joint_conf = { - 'joint_value_left': joint_value_left, - 'joint_value_right': joint_value_right - } - torch.save(joint_conf, args.joint_conf_path) - - # load planned joint conf - data = torch.load(args.joint_conf_path) - joint_value_left = data['joint_value_left'] - joint_value_right = data['joint_value_right'] - print('joint_value_left', joint_value_left.shape, 'joint_value_right', joint_value_right.shape) - - # visualize planning results - scene = trimesh.Scene() - pc1 = trimesh.PointCloud(bbo_planner.object_internal_points.detach().cpu().numpy(), colors=[0, 255, 0]) - pc2 = trimesh.PointCloud(p_l.detach().cpu().numpy(), colors=[255, 0, 0]) - pc3 = trimesh.PointCloud(p_r.detach().cpu().numpy(), colors=[255, 0, 0]) - scene.add_geometry([pc1, pc2, pc3]) - scene.add_geometry(bbo_planner.object_mesh) - - # visualize the final joint configuration - for t_l, t_r in zip(joint_value_left, joint_value_right): - print('t left:', t_l) - robot_l = bbo_planner.rdf_bp.robot.get_forward_robot_mesh(t_l.reshape(-1, num_joint), base_pose_l)[0] - robot_l = np.sum(robot_l) - robot_l.visual.face_colors = [150, 150, 200, 200] - scene.add_geometry(robot_l, node_name='robot_l') - - print('t right:', t_r) - robot_r = bbo_planner.rdf_bp.robot.get_forward_robot_mesh(t_r.reshape(-1, num_joint), base_pose_r)[0] - robot_r = np.sum(robot_r) - robot_r.visual.face_colors = [150, 200, 150, 200] - scene.add_geometry(robot_r, node_name='robot_r') - scene.show() - - scene.delete_geometry('robot_l') - scene.delete_geometry('robot_r') - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--device', default='cuda', type=str) - parser.add_argument('--domain_max', default=1.0, type=float) - parser.add_argument('--domain_min', default=-1.0, type=float) - parser.add_argument('--n_func', default=8, type=int) - parser.add_argument('--train_epoch', default=200, type=int) - parser.add_argument('--train', action='store_true') - parser.add_argument('--save_mesh_dict', action='store_false') - parser.add_argument('--load_sampled_points', action='store_false') - parser.add_argument('--parallel', action='store_true') - parser.add_argument('--robot_asset_root', default="../../rofunc/simulator/assets/urdf/alicia", type=str) - parser.add_argument('--robot_asset_name', default="Alicia_0624.xml", type=str) - # parser.add_argument('--robot_asset_root', default="../../rofunc/simulator/assets/urdf/franka_description", type=str) - # parser.add_argument('--robot_asset_name', default="robots/franka_panda.urdf", type=str) - parser.add_argument('--rdf_model_path', default=None) - parser.add_argument('--joint_conf_path', default=None) - parser.add_argument('--sampled_points_dir', default=None, type=str) - args = parser.parse_args() - args.rdf_model_path = f"{args.robot_asset_root}/rdf/BP/BP_{args.n_func}.pt" - args.joint_conf_path = f"{args.robot_asset_root}/rdf/BP/joint_conf.pt" - - box_carrying_contact_rdf(args) diff --git a/examples/simulator/example_robot_play.py b/examples/simulator/example_robot_play.py index 7a434ac27..a142e9fd2 100644 --- a/examples/simulator/example_robot_play.py +++ b/examples/simulator/example_robot_play.py @@ -13,7 +13,7 @@ def hotu_random(): - file = "/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2hotu.npy" + file = "../examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2hotu.npy" motion = SkeletonMotion.from_file(file) body_links = {"right_hand": gymapi.AXIS_ALL, "left_hand": gymapi.AXIS_ALL, "right_foot": gymapi.AXIS_ROTATION, "left_foot": gymapi.AXIS_ROTATION, @@ -90,7 +90,7 @@ def hotu_random(): def h1_random(): - file = "/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2h1.npy" + file = "../examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2h1.npy" motion = SkeletonMotion.from_file(file) body_links = {"torso_link": gymapi.AXIS_ALL, "right_elbow_link": gymapi.AXIS_ROTATION, "left_elbow_link": gymapi.AXIS_ROTATION, @@ -172,7 +172,7 @@ def h1_random(): def zju_random(): - file = "/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2zju.npy" + file = "../examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2zju.npy" motion = SkeletonMotion.from_file(file) body_links = {"pelvis": gymapi.AXIS_ROTATION, "FOREARM_R": gymapi.AXIS_ROTATION, "FOREARM_L": gymapi.AXIS_ROTATION, @@ -254,7 +254,7 @@ def zju_random(): def walker_random(): - file = "/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2walker.npy" + file = "../examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2walker.npy" motion = SkeletonMotion.from_file(file) body_links = {"torso": gymapi.AXIS_ROTATION, "right_limb_l4": gymapi.AXIS_ROTATION, "left_limb_l4": gymapi.AXIS_ROTATION, @@ -304,7 +304,7 @@ def walker_random(): def bruce_random(): - file = "/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2bruce.npy" + file = "../examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2bruce.npy" motion = SkeletonMotion.from_file(file) body_links = { "hand_l": gymapi.AXIS_TRANSLATION, "hand_r": gymapi.AXIS_TRANSLATION, @@ -354,7 +354,7 @@ def bruce_random(): def curi_random(): - file = "/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2curi.npy" + file = "../examples/data/hotu2/20240509/Ramdom (good)_Take 2024-05-09 04.49.16 PM_optitrack2curi.npy" motion = SkeletonMotion.from_file(file) body_links = { # "torso_base2": gymapi.AXIS_ROTATION, "root": gymapi.AXIS_ROTATION, diff --git a/rofunc/__init__.py b/rofunc/__init__.py index 4f8426414..eb30dabb4 100644 --- a/rofunc/__init__.py +++ b/rofunc/__init__.py @@ -10,19 +10,20 @@ warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.simplefilter('ignore', DeprecationWarning) -try: - import pbdlib -except ImportError: - print("pbdlib is not installed. Install it automatically...") - pip.main( - ['install', 'https://github.com/Skylark0924/Rofunc/releases/download/v0.0.2.3/pbdlib-0.1-py3-none-any.whl']) +# try: +# import pbdlib +# except ImportError: +# print("pbdlib is not installed. Install it automatically...") +# pip.main( +# ['install', 'https://github.com/Skylark0924/Rofunc/releases/download/v0.0.2.3/pbdlib-0.1-py3-none-any.whl']) from .devices import zed, xsens, optitrack, mmodal, emg from . import simulator as sim -from .learning import ml, RofuncIL, RofuncRL -from .planning_control import lqt, lqr +# from .learning import ml +from .learning import RofuncIL, RofuncRL +# from .planning_control import lqt, lqr from .utils import visualab, robolab, logger, oslab from .utils.datalab import primitive, data_generator from . import config -from .learning.ml import tpgmm, gmr, tpgmr +# from .learning.ml import tpgmm, gmr, tpgmr diff --git a/rofunc/devices/emg/export.py b/rofunc/devices/emg/export.py index 2b0e1d8c0..913ad830d 100644 --- a/rofunc/devices/emg/export.py +++ b/rofunc/devices/emg/export.py @@ -1,5 +1,4 @@ import matplotlib.pyplot as plt -import neurokit2 as nk import numpy as np SAMPING_RATE = 2000 @@ -19,6 +18,8 @@ def process_one_channel(data, sampling_rate, k): data_mvc: Calculate the Maximum Voluntary Contraction (MVC) of the EMG signals data_abs: Take the absolute value of the EMG signals """ + import neurokit2 as nk + data_filter = [] for i in range(0, len(data) - k + 1, k): data_new = data[i] @@ -50,6 +51,8 @@ def process_all_channels(data, n, sampling_rate, k): DATA_MVC: Calculate the Maximum Voluntary Contraction (MVC) of the EMG signals DATA_ABS: Take the absolute value of the EMG signals """ + import neurokit2 as nk + DATA_FILTER = [] DATA_CLEAN = [] DATA_MVC = [] diff --git a/rofunc/learning/RofuncRL/models/utils.py b/rofunc/learning/RofuncRL/models/utils.py index 97f2241fc..8eeec9f37 100644 --- a/rofunc/learning/RofuncRL/models/utils.py +++ b/rofunc/learning/RofuncRL/models/utils.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Union, List +from collections.abc import Mapping import gym import gymnasium @@ -143,6 +144,8 @@ def get_space_dim(space): dim = 0 for i in range(len(space)): dim += get_space_dim(space[i]) + elif isinstance(space, Mapping): + dim = get_space_dim(space["policy"]) elif isinstance(space, gym.Space) or isinstance(space, gymnasium.Space): dim = space.shape if isinstance(dim, tuple) and len(dim) == 1: diff --git a/rofunc/learning/RofuncRL/state_encoders/graph_encoders.py b/rofunc/learning/RofuncRL/state_encoders/graph_encoders.py index 488884428..0ce8cca51 100644 --- a/rofunc/learning/RofuncRL/state_encoders/graph_encoders.py +++ b/rofunc/learning/RofuncRL/state_encoders/graph_encoders.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import dgl.nn.pytorch as dglnn import torch.nn as nn from .base_encoders import BaseEncoder @@ -20,6 +19,8 @@ class HomoGraphEncoder(BaseEncoder): def __init__(self, in_dim, hidden_dim): + import dgl.nn.pytorch as dglnn + super(HomoGraphEncoder, self).__init__(hidden_dim) # init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. # constant_(x, 0), nn.init.calculate_gain('relu')) @@ -35,6 +36,8 @@ def __init__(self, in_dim, hidden_dim): self.linear = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) def forward(self, g, inputs): + import torch.nn.functional as F + import dgl # 应用图卷积和激活函数 h = self.conv1(g, inputs) h = h.view(-1, h.size(1) * h.size(2)) diff --git a/rofunc/learning/RofuncRL/tasks/__init__.py b/rofunc/learning/RofuncRL/tasks/__init__.py index fa7728663..cf2e639cc 100644 --- a/rofunc/learning/RofuncRL/tasks/__init__.py +++ b/rofunc/learning/RofuncRL/tasks/__init__.py @@ -22,12 +22,12 @@ def __init__(self, env_type="isaacgym"): from .isaacgymenv.physhoi.humanoid_physhoi import HumanoidPhysHOITask # from .isaacgymenv.physhoi.physhoi import PhysHOI_BallPlay # from .isaacgymenv.hotu.humanoid_hotu import HumanoidHOTUTask - from .isaacgymenv.hotu.humanoid_hotu_getup import HumanoidHOTUGetupTask - from .isaacgymenv.hotu.humanoid_hotu_perturb import HumanoidHOTUPerturbTask - from .isaacgymenv.hotu.humanoid_view_motion import HumanoidHOTUViewMotionTask - from .isaacgymenv.hotu.humanoid_hotu_heading import HumanoidHOTUHeadingTask - from .isaacgymenv.hotu.humanoid_hotu_location import HumanoidHOTULocationTask - from .isaacgymenv.hotu.humanoid_hotu_style import HumanoidHOTUStyleTask + # from .isaacgymenv.hotu.humanoid_hotu_getup import HumanoidHOTUGetupTask + # from .isaacgymenv.hotu.humanoid_hotu_perturb import HumanoidHOTUPerturbTask + # from .isaacgymenv.hotu.humanoid_view_motion import HumanoidHOTUViewMotionTask + # from .isaacgymenv.hotu.humanoid_hotu_heading import HumanoidHOTUHeadingTask + # from .isaacgymenv.hotu.humanoid_hotu_location import HumanoidHOTULocationTask + # from .isaacgymenv.hotu.humanoid_hotu_style import HumanoidHOTUStyleTask from .isaacgymenv.hands.shadow_hand_block_stack import ShadowHandBlockStackTask from .isaacgymenv.hands.shadow_hand_bottle_cap import ShadowHandBottleCapTask @@ -76,12 +76,12 @@ def __init__(self, env_type="isaacgym"): "HumanoidASEViewMotion": HumanoidASEViewMotionTask, "HumanoidPhysHOI": HumanoidPhysHOITask, # "HumanoidPhysHOI": PhysHOI_BallPlay, - "HumanoidHOTUGetup": HumanoidHOTUGetupTask, - "HumanoidHOTUPerturb": HumanoidHOTUPerturbTask, - "HumanoidHOTUViewMotion": HumanoidHOTUViewMotionTask, - "HumanoidHOTUHeading": HumanoidHOTUHeadingTask, - "HumanoidHOTULocation": HumanoidHOTULocationTask, - "HumanoidHOTUStyle": HumanoidHOTUStyleTask, + # "HumanoidHOTUGetup": HumanoidHOTUGetupTask, + # "HumanoidHOTUPerturb": HumanoidHOTUPerturbTask, + # "HumanoidHOTUViewMotion": HumanoidHOTUViewMotionTask, + # "HumanoidHOTUHeading": HumanoidHOTUHeadingTask, + # "HumanoidHOTULocation": HumanoidHOTULocationTask, + # "HumanoidHOTUStyle": HumanoidHOTUStyleTask, "BiShadowHandOver": ShadowHandOverTask, "BiShadowHandBlockStack": ShadowHandBlockStackTask, diff --git a/rofunc/learning/RofuncRL/tasks/utils/env_loaders.py b/rofunc/learning/RofuncRL/tasks/utils/env_loaders.py index 00cf2adfc..482affaf4 100644 --- a/rofunc/learning/RofuncRL/tasks/utils/env_loaders.py +++ b/rofunc/learning/RofuncRL/tasks/utils/env_loaders.py @@ -1,12 +1,17 @@ import os -import sys import queue +import sys from contextlib import contextmanager +from typing import Optional, Sequence + +from rofunc.utils.logger.beauty_logger import beauty_print __all__ = ["load_isaacgym_env_preview2", "load_isaacgym_env_preview3", "load_isaacgym_env_preview4", - "load_omniverse_isaacgym_env"] + "load_omniverse_isaacgym_env", + "load_isaac_orbit_env", + "load_isaaclab_env"] @contextmanager @@ -25,6 +30,7 @@ def cwd(new_path: str) -> None: finally: os.chdir(current_path) + def _omegaconf_to_dict(config) -> dict: """Convert OmegaConf config to dict @@ -42,6 +48,7 @@ def _omegaconf_to_dict(config) -> dict: d[k] = _omegaconf_to_dict(v) if isinstance(v, DictConfig) else v return d + def _print_cfg(d, indent=0) -> None: """Print the environment configuration @@ -89,21 +96,25 @@ def load_isaacgym_env_preview2(task_name: str = "", isaacgymenvs_path: str = "", if defined: arg_index = sys.argv.index("--task") + 1 if arg_index >= len(sys.argv): - raise ValueError("No task name defined. Set the task_name parameter or use --task as command line argument") + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument") if task_name and task_name != sys.argv[arg_index]: - print("[WARNING] Overriding task ({}) with command line argument ({})".format(task_name, sys.argv[arg_index])) + print( + "[WARNING] Overriding task ({}) with command line argument ({})".format(task_name, sys.argv[arg_index])) # get task name from function arguments else: if task_name: sys.argv.append("--task") sys.argv.append(task_name) else: - raise ValueError("No task name defined. Set the task_name parameter or use --task as command line argument") + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument") # get isaacgym envs path from isaacgym package metadata if not isaacgymenvs_path: if not hasattr(isaacgym, "__path__"): - raise RuntimeError("isaacgym package is not installed or could not be accessed by the current Python environment") + raise RuntimeError( + "isaacgym package is not installed or could not be accessed by the current Python environment") path = isaacgym.__path__ path = os.path.join(path[0], "..", "rlgpu") else: @@ -120,8 +131,9 @@ def load_isaacgym_env_preview2(task_name: str = "", isaacgymenvs_path: str = "", status = False print("[ERROR] Failed to import required packages: {}".format(e)) if not status: - raise RuntimeError("The path ({}) is not valid or the isaacgym package is not installed in editable mode (pip install -e .)" \ - .format(path)) + raise RuntimeError( + "The path ({}) is not valid or the isaacgym package is not installed in editable mode (pip install -e .)" \ + .format(path)) args = get_args() @@ -142,6 +154,7 @@ def load_isaacgym_env_preview2(task_name: str = "", isaacgymenvs_path: str = "", return env + def load_isaacgym_env_preview3(task_name: str = "", isaacgymenvs_path: str = "", show_cfg: bool = True): """Load an Isaac Gym environment (preview 3) @@ -169,7 +182,6 @@ def load_isaacgym_env_preview3(task_name: str = "", isaacgymenvs_path: str = "", from omegaconf import OmegaConf - import isaacgym import isaacgymenvs # check task from command line arguments @@ -182,13 +194,14 @@ def load_isaacgym_env_preview3(task_name: str = "", isaacgymenvs_path: str = "", if defined: if task_name and task_name != arg.split("task=")[1].split(" ")[0]: print("[WARNING] Overriding task name ({}) with command line argument ({})" \ - .format(task_name, arg.split("task=")[1].split(" ")[0])) + .format(task_name, arg.split("task=")[1].split(" ")[0])) # get task name from function arguments else: if task_name: sys.argv.append("task={}".format(task_name)) else: - raise ValueError("No task name defined. Set task_name parameter or use task= as command line argument") + raise ValueError( + "No task name defined. Set task_name parameter or use task= as command line argument") # get isaacgymenvs path from isaacgymenvs package metadata if isaacgymenvs_path == "": @@ -248,6 +261,7 @@ def load_isaacgym_env_preview3(task_name: str = "", isaacgymenvs_path: str = "", return env + def load_isaacgym_env_preview4(task_name: str = "", isaacgymenvs_path: str = "", show_cfg: bool = True): """Load an Isaac Gym environment (preview 4) @@ -271,6 +285,7 @@ def load_isaacgym_env_preview4(task_name: str = "", isaacgymenvs_path: str = "", """ return load_isaacgym_env_preview3(task_name, isaacgymenvs_path, show_cfg) + def load_omniverse_isaacgym_env(task_name: str = "", omniisaacgymenvs_path: str = "", show_cfg: bool = True, @@ -322,13 +337,14 @@ def load_omniverse_isaacgym_env(task_name: str = "", if defined: if task_name and task_name != arg.split("task=")[1].split(" ")[0]: print("[WARNING] Overriding task name ({}) with command line argument ({})" \ - .format(task_name, arg.split("task=")[1].split(" ")[0])) + .format(task_name, arg.split("task=")[1].split(" ")[0])) # get task name from function arguments else: if task_name: sys.argv.append("task={}".format(task_name)) else: - raise ValueError("No task name defined. Set task_name parameter or use task= as command line argument") + raise ValueError( + "No task name defined. Set task_name parameter or use task= as command line argument") # get rofunc.learning.RofuncRL.tasks.omniisaacgym path from rofunc.learning.RofuncRL.tasks.omniisaacgym package metadata if omniisaacgymenvs_path == "": @@ -362,7 +378,8 @@ def load_omniverse_isaacgym_env(task_name: str = "", # internal classes class _OmniIsaacGymVecEnv(VecEnvBase): def step(self, actions): - actions = torch.clamp(actions, -self._task.clip_actions, self._task.clip_actions).to(self._task.device).clone() + actions = torch.clamp(actions, -self._task.clip_actions, self._task.clip_actions).to( + self._task.device).clone() self._task.pre_physics_step(actions) for _ in range(self._task.control_frequency_inv): @@ -371,7 +388,8 @@ def step(self, actions): observations, rewards, dones, info = self._task.post_physics_step() - return {"obs": torch.clamp(observations, -self._task.clip_obs, self._task.clip_obs).to(self._task.rl_device).clone()}, \ + return {"obs": torch.clamp(observations, -self._task.clip_obs, self._task.clip_obs).to( + self._task.rl_device).clone()}, \ rewards.to(self._task.rl_device).clone(), dones.to(self._task.rl_device).clone(), info.copy() def reset(self): @@ -397,7 +415,8 @@ def run(self, trainer=None): super().run(_OmniIsaacGymTrainerMT() if trainer is None else trainer) def _parse_data(self, data): - self._observations = torch.clamp(data["obs"], -self._task.clip_obs, self._task.clip_obs).to(self._task.rl_device).clone() + self._observations = torch.clamp(data["obs"], -self._task.clip_obs, self._task.clip_obs).to( + self._task.rl_device).clone() self._rewards = data["rew"].to(self._task.rl_device).clone() self._dones = data["reset"].to(self._task.rl_device).clone() self._info = data["extras"].copy() @@ -449,6 +468,7 @@ def close(self): return env + def load_isaac_orbit_env(task_name: str = "", show_cfg: bool = True): """Load an Isaac Orbit environment @@ -488,16 +508,19 @@ def load_isaac_orbit_env(task_name: str = "", show_cfg: bool = True): if defined: arg_index = sys.argv.index("--task") + 1 if arg_index >= len(sys.argv): - raise ValueError("No task name defined. Set the task_name parameter or use --task as command line argument") + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument") if task_name and task_name != sys.argv[arg_index]: - print("[WARNING] Overriding task ({}) with command line argument ({})".format(task_name, sys.argv[arg_index])) + print( + "[WARNING] Overriding task ({}) with command line argument ({})".format(task_name, sys.argv[arg_index])) # get task name from function arguments else: if task_name: sys.argv.append("--task") sys.argv.append(task_name) else: - raise ValueError("No task name defined. Set the task_name parameter or use --task as command line argument") + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument") # parse arguments parser = argparse.ArgumentParser("Welcome to Orbit: Omniverse Robotics Environments!") @@ -543,3 +566,146 @@ def close_the_simulator(): env = gym.make(args.task, cfg=cfg, headless=args.headless) return env + + +def load_isaaclab_env(task_name: str = "", + num_envs: Optional[int] = None, + headless: Optional[bool] = None, + cli_args: Sequence[str] = [], + show_cfg: bool = True): + """Load an Isaac Lab environment + + Isaac Lab: https://isaac-sim.github.io/IsaacLab + + This function includes the definition and parsing of command line arguments used by Isaac Lab: + + - ``--headless``: Force display off at all times + - ``--cpu``: Use CPU pipeline + - ``--num_envs``: Number of environments to simulate + - ``--task``: Name of the task + - ``--num_envs``: Seed used for the environment + + :param task_name: The name of the task (default: ``""``). + If not specified, the task name is taken from the command line argument (``--task TASK_NAME``). + Command line argument has priority over function parameter if both are specified + :type task_name: str, optional + :param num_envs: Number of parallel environments to create (default: ``None``). + If not specified, the default number of environments defined in the task configuration is used. + Command line argument has priority over function parameter if both are specified + :type num_envs: int, optional + :param headless: Whether to use headless mode (no rendering) (default: ``None``). + If not specified, the default task configuration is used. + Command line argument has priority over function parameter if both are specified + :type headless: bool, optional + :param cli_args: Isaac Lab configuration and command line arguments (default: ``[]``) + :type cli_args: list of str, optional + :param show_cfg: Whether to print the configuration (default: ``True``) + :type show_cfg: bool, optional + + :raises ValueError: The task name has not been defined, neither by the function parameter nor by the command line arguments + + :return: Isaac Lab environment + :rtype: gym.Env + """ + import argparse + import atexit + import gymnasium as gym + + # check task from command line arguments + defined = False + for arg in sys.argv: + if arg.startswith("--task"): + defined = True + break + # get task name from command line arguments + if defined: + arg_index = sys.argv.index("--task") + 1 + if arg_index >= len(sys.argv): + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument") + if task_name and task_name != sys.argv[arg_index]: + beauty_print(f"Overriding task ({task_name}) with command line argument ({sys.argv[arg_index]})", "warning") + # get task name from function arguments + else: + if task_name: + sys.argv.append("--task") + sys.argv.append(task_name) + else: + raise ValueError( + "No task name defined. Set the task_name parameter or use --task as command line argument") + + # check num_envs from command line arguments + defined = False + for arg in sys.argv: + if arg.startswith("--num_envs"): + defined = True + break + # get num_envs from command line arguments + if defined: + if num_envs is not None: + beauty_print.warning("Overriding num_envs with command line argument (--num_envs)", "warning") + # get num_envs from function arguments + elif num_envs is not None and num_envs > 0: + sys.argv.append("--num_envs") + sys.argv.append(str(num_envs)) + + # check headless from command line arguments + defined = False + for arg in sys.argv: + if arg.startswith("--headless"): + defined = True + break + # get headless from command line arguments + if defined: + if headless is not None: + beauty_print("Overriding headless with command line argument (--headless)", "warning") + # get headless from function arguments + elif headless is not None: + sys.argv.append("--headless") + + # others command line arguments + sys.argv += cli_args + + # parse arguments + parser = argparse.ArgumentParser("Isaac Lab: Omniverse Robotics Environments!") + parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.") + parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") + parser.add_argument("--task", type=str, default=None, help="Name of the task.") + parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") + parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.") + parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).") + parser.add_argument("--video_interval", type=int, default=2000, + help="Interval between video recordings (in steps).") + parser.add_argument("--disable_fabric", action="store_true", default=False, + help="Disable fabric and use USD I/O operations.") + parser.add_argument("--distributed", action="store_true", default=False, + help="Run training with multiple GPUs or nodes.") + + # launch the simulation app + from omni.isaac.lab.app import AppLauncher + + AppLauncher.add_app_launcher_args(parser) + args = parser.parse_args() + app_launcher = AppLauncher(args) + + @atexit.register + def close_the_simulator(): + app_launcher.app.close() + + import omni.isaac.lab_tasks # type: ignore + from omni.isaac.lab_tasks.utils import parse_env_cfg # type: ignore + + cfg = parse_env_cfg(args.task, use_gpu=not args.cpu, num_envs=args.num_envs, use_fabric=not args.disable_fabric) + + # print config + if show_cfg: + print(f"\nIsaac Lab environment ({args.task})") + try: + _print_cfg(cfg) + except AttributeError as e: + pass + + # load environment + env = gym.make(args.task, cfg=cfg, render_mode="rgb_array" if args.video else None) + + return env diff --git a/rofunc/learning/RofuncRL/trainers/__init__.py b/rofunc/learning/RofuncRL/trainers/__init__.py index e75417104..3ff5d7ca9 100644 --- a/rofunc/learning/RofuncRL/trainers/__init__.py +++ b/rofunc/learning/RofuncRL/trainers/__init__.py @@ -7,7 +7,7 @@ def __init__(self): from .amp_trainer import AMPTrainer from .ase_trainer import ASETrainer from .dtrans_trainer import DTransTrainer - from .hotu_trainer import HOTUTrainer + # from .hotu_trainer import HOTUTrainer from .physhoi_trainer import PhysHOITrainer self.trainer_map = { @@ -18,7 +18,7 @@ def __init__(self): "amp": AMPTrainer, "ase": ASETrainer, "dtrans": DTransTrainer, - "hotu": HOTUTrainer, + # "hotu": HOTUTrainer, "physhoi": PhysHOITrainer, } diff --git a/rofunc/learning/RofuncRL/trainers/ppo_trainer.py b/rofunc/learning/RofuncRL/trainers/ppo_trainer.py index 90e0a99fc..343687d83 100644 --- a/rofunc/learning/RofuncRL/trainers/ppo_trainer.py +++ b/rofunc/learning/RofuncRL/trainers/ppo_trainer.py @@ -26,7 +26,8 @@ def __init__(self, cfg, env, device, env_name, inference=False): device, self.exp_dir, self.rofunc_logger) def pre_interaction(self): - self.env.reset_done() + # self.env.reset_done() + pass def post_interaction(self): self._rollout += 1 diff --git a/rofunc/learning/__init__.py b/rofunc/learning/__init__.py index 05cf8e86b..43b82fa2f 100644 --- a/rofunc/learning/__init__.py +++ b/rofunc/learning/__init__.py @@ -3,6 +3,5 @@ shutup.please() -from .ml import * from .RofuncIL import * from .RofuncRL import * diff --git a/rofunc/learning/utils/env_wrappers.py b/rofunc/learning/utils/env_wrappers.py index 0d90bd677..f9a9aef65 100644 --- a/rofunc/learning/utils/env_wrappers.py +++ b/rofunc/learning/utils/env_wrappers.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Tuple, Any, Optional +import collections +from typing import Tuple, Any, Optional import gym import gymnasium -import collections import numpy as np -from packaging import version - import torch +from packaging import version from rofunc.utils.logger.beauty_logger import beauty_print @@ -28,7 +27,7 @@ class Wrapper(object): - def __init__(self, env: Any, device=None) -> None: + def __init__(self, env: Any) -> None: """Base wrapper class for RL environments :param env: The environment to wrap @@ -41,8 +40,14 @@ def __init__(self, env: Any, device=None) -> None: self.device = torch.device(self._env.device) else: self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - if device is not None: - self.device = torch.device(device) + # spaces + try: + self._action_space = self._env.single_action_space + self._observation_space = self._env.single_observation_space + except AttributeError: + self._action_space = self._env.action_space + self._observation_space = self._env.observation_space + self._state_space = self._env.state_space if hasattr(self._env, "state_space") else self._observation_space def __getattr__(self, key: str) -> Any: """Get an attribute from the wrapped environment @@ -57,8 +62,7 @@ def __getattr__(self, key: str) -> Any: """ if hasattr(self._env, key): return getattr(self._env, key) - raise AttributeError("Wrapped environment ({}) does not have attribute '{}'" \ - .format(self._env.__class__.__name__, key)) + raise AttributeError(f"Wrapped environment ({self._env.__class__.__name__}) does not have attribute '{key}'") def reset(self) -> Tuple[torch.Tensor, Any]: """Reset the environment @@ -85,17 +89,13 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch def render(self, *args, **kwargs) -> None: """Render the environment - - :raises NotImplementedError: Not implemented """ - raise NotImplementedError + pass def close(self) -> None: """Close the environment - - :raises NotImplementedError: Not implemented """ - raise NotImplementedError + pass @property def num_envs(self) -> int: @@ -105,6 +105,14 @@ def num_envs(self) -> int: """ return self._env.num_envs if hasattr(self._env, "num_envs") else 1 + @property + def num_agents(self) -> int: + """Number of agents + + If the wrapped environment does not have the ``num_agents`` property, it will be set to 1 + """ + return self._env.num_agents if hasattr(self._env, "num_agents") else 1 + @property def state_space(self) -> gym.Space: """State space @@ -112,19 +120,19 @@ def state_space(self) -> gym.Space: If the wrapped environment does not have the ``state_space`` property, the value of the ``observation_space`` property will be used """ - return self._env.state_space if hasattr(self._env, "state_space") else self._env.observation_space + return self._state_space @property def observation_space(self) -> gym.Space: """Observation space """ - return self._env.observation_space + return self._observation_space @property def action_space(self) -> gym.Space: """Action space """ - return self._env.action_space + return self._action_space class IsaacGymPreview2Wrapper(Wrapper): @@ -250,7 +258,8 @@ def run(self, trainer: Optional["omni.isaac.gym.vec_env.vec_env_mt.TrainerMT"] = self._env.run(trainer) def _process_data(self): - self._obs = torch.clamp(self._obs, -self._env._task.clip_obs, self._env._task.clip_obs).to(self._env._task.rl_device).clone() + self._obs = torch.clamp(self._obs, -self._env._task.clip_obs, self._env._task.clip_obs).to( + self._env._task.rl_device).clone() self._rew = self._rew.to(self._env._task.rl_device).clone() self._states = torch.clamp(self._states, -self._env._task.clip_obs, self._env._task.clip_obs).to( self._env._task.rl_device).clone() @@ -292,7 +301,8 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch # self._obs_dict, reward, terminated, info = self._env.step(actions) # truncated = torch.zeros_like(terminated) - return self._obs_dict["obs"], self._rew.view(-1, 1), self._resets.view(-1, 1), self._resets.view(-1, 1), self._extras + return self._obs_dict["obs"], self._rew.view(-1, 1), self._resets.view(-1, 1), self._resets.view(-1, + 1), self._extras def reset(self) -> Tuple[torch.Tensor, Any]: """Reset the environment @@ -995,6 +1005,54 @@ def close(self) -> None: self._env.close() +class IsaacLabWrapper(Wrapper): + def __init__(self, env: Any) -> None: + """Isaac Lab environment wrapper + + :param env: The environment to wrap + :type env: Any supported Isaac Lab environment + """ + super().__init__(env) + + self._reset_once = True + self._obs_dict = None + + self._observation_space = self._observation_space["policy"] + + def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: + """Perform a step in the environment + + :param actions: The actions to perform + :type actions: torch.Tensor + + :return: Observation, reward, terminated, truncated, info + :rtype: tuple of torch.Tensor and any other info + """ + self._obs_dict, reward, terminated, truncated, info = self._env.step(actions) + return self._obs_dict["policy"], reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), info + + def reset(self) -> Tuple[torch.Tensor, Any]: + """Reset the environment + + :return: Observation, info + :rtype: torch.Tensor and any other info + """ + if self._reset_once: + self._obs_dict, info = self._env.reset() + self._reset_once = False + return self._obs_dict["policy"], {} + + def render(self, *args, **kwargs) -> None: + """Render the environment + """ + pass + + def close(self) -> None: + """Close the environment + """ + self._env.close() + + def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True, logger=None, seed=None) -> Wrapper: """ Wrap an environment to use a common interface @@ -1028,6 +1086,8 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True, logger=None, +--------------------+-------------------------+ |Isaac Sim (orbit) |``"isaac-orbit"`` | +--------------------+-------------------------+ + |Isaac Lab |``"isaaclab"`` | + +--------------------+-------------------------+ :param verbose: Whether to print the wrapper type (default: True) :param logger: rofunc logger (default: None) :param seed: random seed for env (default: None) @@ -1109,5 +1169,9 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True, logger=None, if verbose: logger.info("Environment wrapper: Isaac Orbit") return IsaacOrbitWrapper(env) + elif wrapper == "isaaclab": + if verbose: + logger.info("Environment wrapper: Isaac Lab") + return IsaacLabWrapper(env) else: raise ValueError("Unknown {} wrapper type".format(wrapper)) diff --git a/rofunc/simulator/assets/mjcf/bruce/bruce.xml b/rofunc/simulator/assets/mjcf/bruce/bruce.xml index 51db09eeb..ca9212294 100644 --- a/rofunc/simulator/assets/mjcf/bruce/bruce.xml +++ b/rofunc/simulator/assets/mjcf/bruce/bruce.xml @@ -36,7 +36,7 @@ - + diff --git a/rofunc/simulator/assets/mjcf/hotu/hotu_humanoid.xml b/rofunc/simulator/assets/mjcf/hotu/hotu_humanoid.xml index c034dd772..7c68eb0a0 100644 --- a/rofunc/simulator/assets/mjcf/hotu/hotu_humanoid.xml +++ b/rofunc/simulator/assets/mjcf/hotu/hotu_humanoid.xml @@ -20,9 +20,9 @@ - + - + diff --git a/rofunc/simulator/utils/dae2stl.py b/rofunc/simulator/utils/dae2stl.py index 82a4f8ef7..479688ea5 100644 --- a/rofunc/simulator/utils/dae2stl.py +++ b/rofunc/simulator/utils/dae2stl.py @@ -12,8 +12,8 @@ def dae2stl(dae_files, stl_save_path): if __name__ == '__main__': - dae_folder = '/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/rofunc/simulator/assets/urdf/curi/meshes' + dae_folder = './simulator/assets/urdf/curi/meshes' dae_files = rf.oslab.list_absl_path(dae_folder, recursive=True, suffix='.dae') - stl_save_path = '/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/rofunc/simulator/assets/urdf/curi/all_visual' + stl_save_path = './simulator/assets/urdf/curi/all_visual' dae2stl(dae_files[6:12], stl_save_path) diff --git a/rofunc/simulator/utils/get_inertia.py b/rofunc/simulator/utils/get_inertia.py index c27f8fb0a..98a23f3c9 100644 --- a/rofunc/simulator/utils/get_inertia.py +++ b/rofunc/simulator/utils/get_inertia.py @@ -34,7 +34,7 @@ def calculate_inertial_tag(file_name=None, mass=-1, pr=8, scale_factor=100): if __name__ == '__main__': - path = "/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/rofunc/simulator/assets/urdf/zju_humanoid/low_meshes" + path = "./simulator/assets/urdf/zju_humanoid/low_meshes" name = "WRIST_UPDOWN_R.STL" path = os.path.join(path, name) calculate_inertial_tag(path, 1) diff --git a/rofunc/utils/oslab/path.py b/rofunc/utils/oslab/path.py index 84db7cd9a..cb4ef94e0 100644 --- a/rofunc/utils/oslab/path.py +++ b/rofunc/utils/oslab/path.py @@ -20,7 +20,7 @@ import rofunc as rf -def get_rofunc_path(): +def get_rofunc_path(extra_path=None): """ Get the path of the rofunc package. @@ -29,9 +29,13 @@ def get_rofunc_path(): if not hasattr(rf, "__path__"): raise RuntimeError("rofunc package is not installed") rofunc_path = list(rf.__path__)[0] + + if extra_path is not None: + rofunc_path = os.path.join(rofunc_path, extra_path) return rofunc_path + def get_elegantrl_path(): """ Get the path of the elegantrl package. diff --git a/rofunc/utils/robolab/formatter/mjcf.py b/rofunc/utils/robolab/formatter/mjcf.py new file mode 100644 index 000000000..d64cb0547 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf.py @@ -0,0 +1,174 @@ +from typing import Union, Optional + +import mujoco +import pytorch_kinematics.transforms as tf +from mujoco._structs import _MjModelBodyViews as MjModelBodyViews +from pytorch_kinematics import chain, frame + +# Converts from MuJoCo joint types to pytorch_kinematics joint types +JOINT_TYPE_MAP = { + mujoco.mjtJoint.mjJNT_HINGE: 'revolute', + mujoco.mjtJoint.mjJNT_SLIDE: "prismatic" +} + + +def get_body_geoms(m: mujoco.MjModel, body: MjModelBodyViews, base: Optional[tf.Transform3d] = None): + # Find all geoms which have body as parent + base = base or tf.Transform3d() + visuals = [] + for geom_id in range(m.ngeom): + geom = m.geom(geom_id) + if geom.bodyid == body.id: + if geom.type == "capsule": + param = (geom.size[0], geom.fromto) + elif geom.type == "sphere": + param = geom.size[0] + else: + param = geom.size + visuals.append(frame.Visual(offset=tf.Transform3d(rot=geom.quat, pos=geom.pos), geom_type=geom.type, + geom_param=param)) + return visuals + + +def body_to_link(body, base: Optional[tf.Transform3d] = None): + base = base or tf.Transform3d() + return frame.Link(body.name, offset=tf.Transform3d(rot=body.quat, pos=body.pos)) + + +def joint_to_joint(joint, base: Optional[tf.Transform3d] = None): + base = base or tf.Transform3d() + return frame.Joint( + joint.name, + offset=tf.Transform3d(pos=joint.pos), + joint_type=JOINT_TYPE_MAP[joint.type], + axis=joint.axis, + ) + + +def add_composite_joint(root_frame, joints, base: Optional[tf.Transform3d] = None): + base = base or tf.Transform3d() + if len(joints) > 0: + root_frame.children = root_frame.children + [ + frame.Frame(link=frame.Link(name=root_frame.link.name + "_child"), joint=joint_to_joint(joints[0], base)) + ] + ret, offset = add_composite_joint(root_frame.children[-1], joints[1:]) + return ret, root_frame.joint.offset * offset + else: + return root_frame, root_frame.joint.offset + + +def _build_chain_recurse(m, parent_frame, parent_body): + parent_frame.link.visuals = get_body_geoms(m, parent_body) + # iterate through all bodies that are children of parent_body + for body_id in range(m.nbody): + body = m.body(body_id) + if body.parentid == parent_body.id and body_id != parent_body.id: + n_joints = body.jntnum + if n_joints > 1: + # Support for composite joints + old_parent_frame = parent_frame + for i in range(int(n_joints)): + joint = m.joint(body.jntadr[0] + i) + if i == 0: + joint_offset = tf.Transform3d(pos=joint.pos) + child_joint = frame.Joint(joint.name, offset=joint_offset, axis=joint.axis, + joint_type=JOINT_TYPE_MAP[joint.type[0]], + limits=(joint.range[0], joint.range[1])) + else: + child_joint = frame.Joint(joint.name, axis=joint.axis, + joint_type=JOINT_TYPE_MAP[joint.type[0]], + limits=(joint.range[0], joint.range[1])) + if i == 0: + child_link = frame.Link(body.name + "_" + str(i), + offset=tf.Transform3d(rot=body.quat, pos=body.pos)) + else: + child_link = frame.Link(body.name + "_" + str(i)) + child_frame = frame.Frame(name=body.name + "_" + str(i), link=child_link, joint=child_joint) + parent_frame.children = parent_frame.children + [child_frame, ] + parent_frame = child_frame + parent_frame = old_parent_frame + elif n_joints == 1: + # Find the joint for this body, again assuming there's only one joint per body. + joint = m.joint(body.jntadr[0]) + joint_offset = tf.Transform3d(pos=joint.pos) + child_joint = frame.Joint(joint.name, offset=joint_offset, axis=joint.axis, + joint_type=JOINT_TYPE_MAP[joint.type[0]], + limits=(joint.range[0], joint.range[1])) + child_link = frame.Link(body.name, offset=tf.Transform3d(rot=body.quat, pos=body.pos)) + child_frame = frame.Frame(name=body.name, link=child_link, joint=child_joint) + parent_frame.children = parent_frame.children + [child_frame, ] + else: + child_joint = frame.Joint(body.name + "_fixed_joint") + child_link = frame.Link(body.name, offset=tf.Transform3d(rot=body.quat, pos=body.pos)) + child_frame = frame.Frame(name=body.name, link=child_link, joint=child_joint) + parent_frame.children = parent_frame.children + [child_frame, ] + _build_chain_recurse(m, child_frame, body) + + # # iterate through all sites that are children of parent_body + # for site_id in range(m.nsite): + # site = m.site(site_id) + # if site.bodyid == parent_body.id: + # site_link = frame.Link(site.name, offset=tf.Transform3d(rot=site.quat, pos=site.pos)) + # site_frame = frame.Frame(name=site.name, link=site_link) + # parent_frame.children = parent_frame.children + [site_frame, ] + + +# def _build_chain_recurse(m, root_frame, root_body): +# base = root_frame.link.offset +# cur_frame, cur_base = add_composite_joint(root_frame, root_body.joint, base) +# jbase = cur_base.inverse() * base +# if len(root_body.joint) > 0: +# cur_frame.link.visuals = get_body_geoms(m, root_body.geom, jbase) +# else: +# cur_frame.link.visuals = get_body_geoms(m, root_body.geom) +# for b in root_body.body: +# cur_frame.children = cur_frame.children + [frame.Frame()] +# next_frame = cur_frame.children[-1] +# next_frame.name = b.name + "_frame" +# next_frame.link = body_to_link(b, jbase) +# _build_chain_recurse(m, next_frame, b) + + +def build_chain_from_mjcf(data, body: Union[None, str, int] = None): + """ + Build a Chain object from MJCF data. + + :param data: MJCF string data + :param body: the name or index of the body to use as the root of the chain. If None, body idx=0 is used. + :return: Chain object created from MJCF + """ + # import xml.etree.ElementTree as ET + # root = ET.parse(path).getroot() + # + # ASSETS = dict() + # mesh_dir = root.find("compiler").attrib["meshdir"] + # for asset in root.findall("asset"): + # for mesh in asset.findall("mesh"): + # filename = mesh.attrib["file"] + # with open(os.path.join(os.path.dirname(path), mesh_dir, filename), 'rb') as f: + # ASSETS[filename] = f.read() + + m = mujoco.MjModel.from_xml_path(data) + if body is None: + root_body = m.body(0) + else: + root_body = m.body(body) + root_frame = frame.Frame(root_body.name, + link=body_to_link(root_body), + joint=frame.Joint()) + _build_chain_recurse(m, root_frame, root_body) + return chain.Chain(root_frame) + + +def build_serial_chain_from_mjcf(data, end_link_name, root_link_name=""): + """ + Build a SerialChain object from MJCF data. + + :param data: MJCF string data + :param end_link_name: the name of the link that is the end effector + :param root_link_name: the name of the root link + :return: SerialChain object created from MJCF + """ + mjcf_chain = build_chain_from_mjcf(data) + serial_chain = chain.SerialChain(mjcf_chain, end_link_name, "" if root_link_name == "" else root_link_name) + return serial_chain diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/__init__.py b/rofunc/utils/robolab/formatter/mjcf_parser/__init__.py new file mode 100644 index 000000000..31b05c4ad --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""PyMJCF: an MJCF object-model library.""" + +from rofunc.utils.robolab.formatter.mjcf_parser.parser import * diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/attribute.py b/rofunc/utils/robolab/formatter/mjcf_parser/attribute.py new file mode 100644 index 000000000..26f8cfc30 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/attribute.py @@ -0,0 +1,572 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Classes representing various MJCF attribute data types.""" + +import abc +import collections +import hashlib +import io +import os + +import numpy as np +from rofunc.utils.robolab.formatter.mjcf_parser import util +from rofunc.utils.robolab.formatter.mjcf_parser import io as resources + +from rofunc.utils.robolab.formatter.mjcf_parser import base +from rofunc.utils.robolab.formatter.mjcf_parser import constants +from rofunc.utils.robolab.formatter.mjcf_parser import debugging +from rofunc.utils.robolab.formatter.mjcf_parser import skin + +# Copybara placeholder for internal file handling dependency. + +_INVALID_REFERENCE_TYPE = ( + 'Reference should be an MJCF Element whose type is {valid_type!r}: ' + 'got {actual_type!r}.') + +_MESH_EXTENSIONS = ('.stl', '.msh', '.obj') + +# MuJoCo's compiler enforces this. +_INVALID_MESH_EXTENSION = ( + 'Mesh files must have one of the following extensions: {}, got {{}}.' + .format(_MESH_EXTENSIONS)) + + +class _Attribute(metaclass=abc.ABCMeta): + """Abstract base class for MJCF attribute data types.""" + + def __init__(self, name, required, parent, value, + conflict_allowed, conflict_behavior): + self._name = name + self._required = required + self._parent = parent + self._value = None + self._conflict_allowed = conflict_allowed + self._conflict_behavior = conflict_behavior + self._check_and_assign(value) + + def _check_and_assign(self, new_value): + if new_value is None: + self.clear() + elif isinstance(new_value, str): + self._assign_from_string(new_value) + else: + self._assign(new_value) + if debugging.debug_mode(): + self._last_modified_stack = debugging.get_current_stack_trace() + + @property + def last_modified_stack(self): + if debugging.debug_mode(): + return self._last_modified_stack + + @property + def value(self): + return self._value + + @value.setter + def value(self, new_value): + self._check_and_assign(new_value) + + @abc.abstractmethod + def _assign(self, value): + raise NotImplementedError # pragma: no cover + + def clear(self): + if self._required: + raise AttributeError( + 'Attribute {!r} of element <{}> is required' + .format(self._name, self._parent.tag)) + else: + self._force_clear() + + def _force_clear(self): + self._before_clear() + self._value = None + if debugging.debug_mode(): + self._last_modified_stack = debugging.get_current_stack_trace() + + def _before_clear(self): + pass + + def _assign_from_string(self, string): + self._assign(string) + + def to_xml_string(self, prefix_root, **kwargs): # pylint: disable=unused-argument + if self._value is None: + return None + else: + return str(self._value) + + @property + def conflict_allowed(self): + return self._conflict_allowed + + @property + def conflict_behavior(self): + return self._conflict_behavior + + +class String(_Attribute): + """A string MJCF attribute.""" + + def _assign(self, value): + if not isinstance(value, str): + raise ValueError('Expect a string value: got {}'.format(value)) + elif not value: + self.clear() + else: + self._value = value + + +class Integer(_Attribute): + """An integer MJCF attribute.""" + + def _assign(self, value): + try: + float_value = float(value) + int_value = int(float(value)) + if float_value != int_value: + raise ValueError + except ValueError: + raise ValueError( + 'Expect an integer value: got {}'.format(value)) from None + self._value = int_value + + +class Float(_Attribute): + """An float MJCF attribute.""" + + def _assign(self, value): + try: + float_value = float(value) + except ValueError: + raise ValueError('Expect a float value: got {}'.format(value)) from None + self._value = float_value + + def to_xml_string(self, prefix_root=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0, + **kwargs): + if self._value is None: + return None + else: + out = io.BytesIO() + value = self._value + if abs(value) < zero_threshold: + value = 0.0 + np.savetxt(out, [value], fmt=f'%.{precision:d}g', newline=' ') + return util.to_native_string(out.getvalue())[:-1] # Strip trailing space. + + +class Keyword(_Attribute): + """A keyword MJCF attribute.""" + + def __init__(self, name, required, parent, value, + conflict_allowed, conflict_behavior, valid_values): + self._valid_values = collections.OrderedDict( + (value.lower(), value) for value in valid_values) + super().__init__(name, required, parent, value, conflict_allowed, + conflict_behavior) + + def _assign(self, value): + if value is None or value == '': # pylint: disable=g-explicit-bool-comparison + self.clear() + else: + try: + self._value = self._valid_values[str(value).lower()] + except KeyError: + raise ValueError('Expect keyword to be one of {} but got: {}'.format( + list(self._valid_values.values()), value)) from None + + @property + def valid_values(self): + return list(self._valid_values.keys()) + + +class Array(_Attribute): + """An array MJCF attribute.""" + + def __init__(self, name, required, parent, value, + conflict_allowed, conflict_behavior, length, dtype): + self._length = length + self._dtype = dtype + super().__init__(name, required, parent, value, conflict_allowed, + conflict_behavior) + + def _assign(self, value): + self._value = self._check_shape(np.array(value, dtype=self._dtype)) + + def _assign_from_string(self, string): + self._assign(np.fromstring(string, dtype=self._dtype, sep=' ')) + + def to_xml_string(self, prefix_root=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0, + **kwargs): + if self._value is None: + return None + else: + out = io.BytesIO() + value = self._value + if zero_threshold: + value = np.copy(value) + value[np.abs(value) < zero_threshold] = 0 + np.savetxt(out, value, fmt=f'%.{precision:d}g', newline=' ') + return util.to_native_string(out.getvalue())[:-1] # Strip trailing space. + + def _check_shape(self, array): + actual_length = array.shape[0] + if len(array.shape) > 1: + raise ValueError('Expect one-dimensional array: got {}'.format(array)) + if self._length and actual_length > self._length: + raise ValueError('Expect array with no more than {} entries: got {}' + .format(self._length, array)) + return array + + +class Identifier(_Attribute): + """A string attribute that represents a unique identifier of an element.""" + + def _assign(self, value): + if not isinstance(value, str): + raise ValueError('Expect a string value: got {}'.format(value)) + elif not value: + self.clear() + elif self._parent.spec.namespace == 'body' and value == 'world': + raise ValueError('A body cannot be named \'world\'. ' + 'The name \'world\' is used by MuJoCo to refer to the ' + '.') + elif constants.PREFIX_SEPARATOR in value: + raise ValueError( + 'An identifier cannot contain a {!r}, ' + 'as this is reserved for scoping purposes: got {!r}' + .format(constants.PREFIX_SEPARATOR, value)) + else: + old_value = self._value + if value != old_value: + self._parent.namescope.add( + self._parent.spec.namespace, value, self._parent) + if old_value: + self._parent.namescope.remove(self._parent.spec.namespace, old_value) + self._value = value + + def _before_clear(self): + if self._value: + self._parent.namescope.remove(self._parent.spec.namespace, self._value) + + def _defaults_string(self, prefix_root): + prefix = self._parent.namescope.full_prefix(prefix_root, as_list=True) + prefix.append(self._value or '') + return constants.PREFIX_SEPARATOR.join(prefix) or constants.PREFIX_SEPARATOR + + def to_xml_string(self, prefix_root=None, **kwargs): + if self._parent.tag == constants.DEFAULT: + return self._defaults_string(prefix_root) + elif self._value: + prefix = self._parent.namescope.full_prefix(prefix_root, as_list=True) + prefix.append(self._value) + return constants.PREFIX_SEPARATOR.join(prefix) + else: + return self._value + + +class Reference(_Attribute): + """A string attribute that represents a reference to an identifier.""" + + def __init__(self, name, required, parent, value, + conflict_allowed, conflict_behavior, reference_namespace): + self._reference_namespace = reference_namespace + super().__init__(name, required, parent, value, conflict_allowed, + conflict_behavior) + + def _check_dead_reference(self): + if isinstance(self._value, base.Element) and self._value.is_removed: + self.clear() + + @property + def value(self): + self._check_dead_reference() + return super().value + + @value.setter + def value(self, new_value): + super(Reference, self.__class__).value.fset(self, new_value) + + @property + def reference_namespace(self): + if isinstance(self._reference_namespace, _Attribute): + return constants.INDIRECT_REFERENCE_ATTRIB.get( + self._reference_namespace.value, self._reference_namespace.value) + else: + return self._reference_namespace + + def _assign(self, value): + if not isinstance(value, (base.Element, str)): + raise ValueError( + 'Expect a string or `mjcf.Element` value: got {}'.format(value)) + elif not value: + self.clear() + else: + if isinstance(value, base.Element): + value_namespace = ( + value.spec.namespace.split(constants.NAMESPACE_SEPARATOR)[0]) + if value_namespace != self.reference_namespace: + raise ValueError(_INVALID_REFERENCE_TYPE.format( + valid_type=self.reference_namespace, + actual_type=value_namespace)) + self._value = value + + def _before_clear(self): + if isinstance(self._value, base.Element): + if isinstance(self._reference_namespace, _Attribute): + self._reference_namespace._force_clear() # pylint: disable=protected-access + + def _defaults_string(self, prefix_root): + """Generates the XML string if this is a reference to a defaults class. + + To prevent global defaults from clashing, we turn all global defaults + into a properly named defaults class. Therefore, care must be taken when + this attribute is not explicitly defined. If the parent element can be + traced up to a body with a nontrivial 'childclass' then must continue to + leave this attribute undefined. + + Args: + prefix_root: A `NameScope` object to be treated as root + for the purpose of calculating the prefix. + + Returns: + A string to be used in the generated XML. + """ + self._check_dead_reference() + prefix = self._parent.namescope.full_prefix(prefix_root) + if not self._value: + defaults_root = self._parent.parent + while defaults_root is not None: + if (hasattr(defaults_root, constants.CHILDCLASS) + and defaults_root.childclass): + break + defaults_root = defaults_root.parent + if defaults_root is None: + # This element doesn't belong to a childclass'd body. + global_class = self._parent.root.default.dclass or '' + out_string = (prefix + global_class) or constants.PREFIX_SEPARATOR + else: + out_string = None + else: + out_string = prefix + self._value + return out_string + + def to_xml_string(self, prefix_root, **kwargs): + self._check_dead_reference() + if isinstance(self._value, base.Element): + return self._value.prefixed_identifier(prefix_root) + elif (self.reference_namespace == constants.DEFAULT + and self._name != constants.CHILDCLASS): + return self._defaults_string(prefix_root) + elif self._value: + return self._parent.namescope.full_prefix(prefix_root) + self._value + else: + return None + + +class BasePath(_Attribute): + """A string attribute that represents a base path for an asset type.""" + + def __init__(self, name, required, parent, value, + conflict_allowed, conflict_behavior, path_namespace): + self._path_namespace = path_namespace + super().__init__(name, required, parent, value, conflict_allowed, + conflict_behavior) + + def _assign(self, value): + if not isinstance(value, str): + raise ValueError('Expect a string value: got {}'.format(value)) + elif not value: + self.clear() + else: + self._parent.namescope.replace( + constants.BASEPATH, self._path_namespace, value) + self._value = value + + def _before_clear(self): + if self._value: + self._parent.namescope.remove(constants.BASEPATH, self._path_namespace) + + def to_xml_string(self, prefix_root=None, **kwargs): + return None + + +class BaseAsset: + """Base class for binary assets.""" + + __slots__ = ('extension', 'prefix') + + def __init__(self, extension, prefix=''): + self.extension = extension + self.prefix = prefix + + def __eq__(self, other): + return self.get_vfs_filename() == other.get_vfs_filename() + + def get_vfs_filename(self): + """Returns the name of the asset file as registered in MuJoCo's VFS.""" + # Hash the contents of the asset to get a unique identifier. + hash_string = hashlib.sha1(util.to_binary_string(self.contents)).hexdigest() + # Prepend the prefix, if one exists. + if self.prefix: + prefix = self.prefix + raw_length = len(prefix) + len(hash_string) + len(self.extension) + 1 + if raw_length > constants.MAX_VFS_FILENAME_LENGTH: + trim_amount = raw_length - constants.MAX_VFS_FILENAME_LENGTH + prefix = prefix[:-trim_amount] + filename = '-'.join([prefix, hash_string]) + else: + filename = hash_string + + # An extension is needed because MuJoCo's compiler looks at this when + # deciding how to load meshes and heightfields. + return filename + self.extension + + +class Asset(BaseAsset): + """Class representing a binary asset.""" + + __slots__ = ('contents',) + + def __init__(self, contents, extension, prefix=''): + """Initializes a new `Asset`. + + Args: + contents: The contents of the file as a bytestring. + extension: A string specifying the file extension (e.g. '.png', '.stl'). + prefix: (optional) A prefix applied to the filename given in MuJoCo's VFS. + """ + self.contents = contents + super().__init__(extension, prefix) + + +class SkinAsset(BaseAsset): + """Class representing a binary asset corresponding to a skin.""" + + __slots__ = ('skin', 'parent', '_cached_revision', '_cached_contents') + + def __init__(self, contents, parent, extension, prefix=''): + self.skin = skin.parse( + contents, lambda body_name: parent.root.find('body', body_name)) + self.parent = parent + self._cached_revision = -1 + self._cached_contents = None + super().__init__(extension, prefix) + + @property + def contents(self): + if self._cached_revision < self.parent.namescope.revision: + self._cached_contents = skin.serialize(self.skin) + self._cached_revision = self.parent.namescope.revision + return self._cached_contents + + +class File(_Attribute): + """Attribute representing an asset file.""" + + def __init__(self, name, required, parent, value, + conflict_allowed, conflict_behavior, path_namespace): + self._path_namespace = path_namespace + super().__init__(name, required, parent, value, conflict_allowed, + conflict_behavior) + parent.namescope.files.add(self) + + def _assign(self, value): + if not value: + self.clear() + else: + if isinstance(value, str): + asset = self._get_asset_from_path(value) + elif isinstance(value, Asset): + asset = value + else: + raise ValueError('Expect either a string or `Asset` value: got {}' + .format(value)) + self._validate_extension(asset.extension) + self._value = asset + + def _get_asset_from_path(self, path): + """Constructs a `Asset` given a file path.""" + _, basename = os.path.split(path) + filename, extension = os.path.splitext(basename) + + assetdir = None + if self._parent.namescope.has_identifier( + constants.BASEPATH, constants.ASSETDIR_NAMESPACE + ): + assetdir = self._parent.namescope.get( + constants.BASEPATH, constants.ASSETDIR_NAMESPACE + ) + + if path in self._parent.namescope.assets: + # Look in the dict of pre-loaded assets before checking the filesystem. + contents = self._parent.namescope.assets[path] + else: + # Construct the full path to the asset file, prefixed by the path to the + # model directory, and by `meshdir` or `texturedir` if appropriate. + path_parts = [] + if self._parent.namescope.model_dir: + path_parts.append(self._parent.namescope.model_dir) + + if self._parent.namescope.has_identifier( + constants.BASEPATH, self._path_namespace + ): + base_path = self._parent.namescope.get( + constants.BASEPATH, self._path_namespace + ) + path_parts.append(base_path) + elif ( + self._path_namespace + in (constants.TEXTUREDIR_NAMESPACE, constants.MESHDIR_NAMESPACE) + and assetdir is not None + ): + path_parts.append(assetdir) + path_parts.append(path) + full_path = os.path.join(*path_parts) # pylint: disable=no-value-for-parameter + contents = resources.GetResource(full_path) + + if self._parent.tag == constants.SKIN: + return SkinAsset(contents=contents, parent=self._parent, + extension=extension, prefix=filename) + else: + return Asset(contents=contents, extension=extension, prefix=filename) + + def _validate_extension(self, extension): + if self._parent.tag == constants.MESH: + if extension.lower() not in _MESH_EXTENSIONS: + raise ValueError(_INVALID_MESH_EXTENSION.format(extension)) + + def get_contents(self): + """Returns a bytestring representing the contents of the asset.""" + if self._value is None: + raise RuntimeError('You must assign a value to this attribute before ' + 'querying the contents.') + return self._value.contents + + def to_xml_string(self, prefix_root=None, **kwargs): + """Returns the asset filename as it will appear in the generated XML.""" + del prefix_root # Unused + if self._value is not None: + return self._value.get_vfs_filename() + else: + return None diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/attribute_test.py b/rofunc/utils/robolab/formatter/mjcf_parser/attribute_test.py new file mode 100644 index 000000000..66e5bf7fb --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/attribute_test.py @@ -0,0 +1,497 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for `dm_control.mjcf.attribute`.""" + +import contextlib +import hashlib +import os + +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +from rofunc.utils.robolab.formatter.mjcf_parser import attribute +from rofunc.utils.robolab.formatter.mjcf_parser import element +from rofunc.utils.robolab.formatter.mjcf_parser import namescope +from rofunc.utils.robolab.formatter.mjcf_parser import schema + +ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'test_assets') +FAKE_SCHEMA_FILENAME = 'attribute_test_schema.xml' + +ORIGINAL_SCHEMA_PATH = os.path.join(os.path.dirname(__file__), 'schema.xml') + + +class AttributeTest(parameterized.TestCase): + """Test for Attribute classes. + + Our tests here reflect actual usages of the Attribute classes, namely that we + never directly create attributes but instead access them through Elements. + """ + + def setUp(self): + super().setUp() + schema.override_schema(os.path.join(ASSETS_DIR, FAKE_SCHEMA_FILENAME)) + self._alpha = namescope.NameScope('alpha', None) + self._beta = namescope.NameScope('beta', None) + self._beta.parent = self._alpha + self._mujoco = element.RootElement() + self._mujoco.namescope.parent = self._beta + + def tearDown(self): + super().tearDown() + schema.override_schema(ORIGINAL_SCHEMA_PATH) + + def assertXMLStringIsNone(self, mjcf_element, attribute_name): + for prefix_root in (self._alpha, self._beta, self._mujoco.namescope, None): + self.assertIsNone( + mjcf_element.get_attribute_xml_string(attribute_name, prefix_root)) + + def assertXMLStringEqual(self, mjcf_element, attribute_name, expected): + for prefix_root in (self._alpha, self._beta, self._mujoco.namescope, None): + self.assertEqual( + mjcf_element.get_attribute_xml_string(attribute_name, prefix_root), + expected) + + def assertXMLStringIsCorrectlyScoped( + self, mjcf_element, attribute_name, expected): + for prefix_root in (self._alpha, self._beta, self._mujoco.namescope, None): + self.assertEqual( + mjcf_element.get_attribute_xml_string(attribute_name, prefix_root), + self._mujoco.namescope.full_prefix(prefix_root) + expected) + + def assertCorrectXMLStringForDefaultsClass( + self, mjcf_element, attribute_name, expected): + for prefix_root in (self._alpha, self._beta, self._mujoco.namescope, None): + self.assertEqual( + mjcf_element.get_attribute_xml_string(attribute_name, prefix_root), + (self._mujoco.namescope.full_prefix(prefix_root) + expected) or '/') + + def assertElementIsIdentifiedByName(self, mjcf_element, expected): + self.assertEqual(mjcf_element.name, expected) + self.assertEqual(self._mujoco.find(mjcf_element.spec.namespace, expected), + mjcf_element) + + @contextlib.contextmanager + def assertAttributeIsNoneWhenDone(self, mjcf_element, attribute_name): + yield + self.assertIsNone(getattr(mjcf_element, attribute_name)) + self.assertXMLStringIsNone(mjcf_element, attribute_name) + + def assertCorrectClearBehavior(self, mjcf_element, attribute_name, required): + if required: + return self.assertRaisesRegex(AttributeError, 'is required') + else: + return self.assertAttributeIsNoneWhenDone(mjcf_element, attribute_name) + + def assertCorrectClearBehaviorByAllMethods( + self, mjcf_element, attribute_name, required): + original_value = getattr(mjcf_element, attribute_name) + + def reset_value(): + setattr(mjcf_element, attribute_name, original_value) + if original_value is not None: + self.assertIsNotNone(getattr(mjcf_element, attribute_name)) + + # clear by using del + with self.assertCorrectClearBehavior( + mjcf_element, attribute_name, required): + delattr(mjcf_element, attribute_name) + + # clear by assigning None + reset_value() + with self.assertCorrectClearBehavior( + mjcf_element, attribute_name, required): + setattr(mjcf_element, attribute_name, None) + + if isinstance(original_value, str): + # clear by assigning empty string + reset_value() + with self.assertCorrectClearBehavior( + mjcf_element, attribute_name, required): + setattr(mjcf_element, attribute_name, '') + + def assertCanBeCleared(self, mjcf_element, attribute_name): + self.assertCorrectClearBehaviorByAllMethods( + mjcf_element, attribute_name, required=False) + + def assertCanNotBeCleared(self, mjcf_element, attribute_name): + self.assertCorrectClearBehaviorByAllMethods( + mjcf_element, attribute_name, required=True) + + def testFloatScalar(self): + mujoco = self._mujoco + mujoco.optional.float = 0.357357 + self.assertEqual(mujoco.optional.float, 0.357357) + self.assertEqual(type(mujoco.optional.float), float) + with self.assertRaisesRegex(ValueError, 'Expect a float value'): + mujoco.optional.float = 'five' + # failed assignment should not change the value + self.assertEqual(mujoco.optional.float, 0.357357) + self.assertEqual( + mujoco.optional.get_attribute_xml_string('float', precision=1), + '0.4') + self.assertEqual( + mujoco.optional.get_attribute_xml_string('float', precision=2), + '0.36') + self.assertEqual( + mujoco.optional.get_attribute_xml_string('float', precision=3), + '0.357') + self.assertEqual( + mujoco.optional.get_attribute_xml_string('float', precision=4), + '0.3574') + self.assertEqual( + mujoco.optional.get_attribute_xml_string('float', precision=5), + '0.35736') + self.assertEqual( + mujoco.optional.get_attribute_xml_string('float', precision=6), + '0.357357') + self.assertEqual( + mujoco.optional.get_attribute_xml_string('float', precision=7), + '0.357357') + self.assertEqual( + mujoco.optional.get_attribute_xml_string('float', precision=8), + '0.357357') + + def testIntScalar(self): + mujoco = self._mujoco + mujoco.optional.int = 12345 + self.assertEqual(mujoco.optional.int, 12345) + self.assertEqual(type(mujoco.optional.int), int) + with self.assertRaisesRegex(ValueError, 'Expect an integer value'): + mujoco.optional.int = 10.5 + # failed assignment should not change the value + self.assertEqual(mujoco.optional.int, 12345) + self.assertXMLStringEqual(mujoco.optional, 'int', '12345') + self.assertCanBeCleared(mujoco.optional, 'int') + + def testStringScalar(self): + mujoco = self._mujoco + mujoco.optional.string = 'foobar' + self.assertEqual(mujoco.optional.string, 'foobar') + self.assertXMLStringEqual(mujoco.optional, 'string', 'foobar') + with self.assertRaisesRegex(ValueError, 'Expect a string value'): + mujoco.optional.string = mujoco.optional + self.assertCanBeCleared(mujoco.optional, 'string') + + def testFloatArray(self): + mujoco = self._mujoco + mujoco.optional.float_array = [3, 2, 1] + np.testing.assert_array_equal(mujoco.optional.float_array, [3, 2, 1]) + self.assertEqual(mujoco.optional.float_array.dtype, float) + with self.assertRaisesRegex(ValueError, 'no more than 3 entries'): + mujoco.optional.float_array = [0, 0, 0, -10] + with self.assertRaisesRegex(ValueError, 'one-dimensional array'): + mujoco.optional.float_array = np.array([[1, 2], [3, 4]]) + # failed assignments should not change the value + np.testing.assert_array_equal(mujoco.optional.float_array, [3, 2, 1]) + # XML string should not be affected by global print options + np.set_printoptions(precision=3, suppress=True) + mujoco.optional.float_array = [np.pi, 2, 1e-16] + self.assertXMLStringEqual(mujoco.optional, 'float_array', + '3.1415926535897931 2 9.9999999999999998e-17') + self.assertEqual( + mujoco.optional.get_attribute_xml_string('float_array', precision=5), + '3.1416 2 1e-16') + self.assertEqual( + mujoco.optional.get_attribute_xml_string( + 'float_array', precision=5, zero_threshold=1e-10), + '3.1416 2 0') + self.assertCanBeCleared(mujoco.optional, 'float_array') + + def testFormatVeryLargeArray(self): + mujoco = self._mujoco + array = np.arange(2000, dtype=np.double) + mujoco.optional.huge_float_array = array + xml_string = mujoco.optional.get_attribute_xml_string('huge_float_array') + self.assertNotIn('...', xml_string) + # Check that array <--> string conversion is a round trip. + mujoco.optional.huge_float_array = None + self.assertIsNone(mujoco.optional.huge_float_array) + mujoco.optional.huge_float_array = xml_string + np.testing.assert_array_equal(mujoco.optional.huge_float_array, array) + + def testIntArray(self): + mujoco = self._mujoco + mujoco.optional.int_array = [2, 2] + np.testing.assert_array_equal(mujoco.optional.int_array, [2, 2]) + self.assertEqual(mujoco.optional.int_array.dtype, int) + with self.assertRaisesRegex(ValueError, 'no more than 2 entries'): + mujoco.optional.int_array = [0, 0, 10] + # failed assignment should not change the value + np.testing.assert_array_equal(mujoco.optional.int_array, [2, 2]) + self.assertXMLStringEqual(mujoco.optional, 'int_array', '2 2') + self.assertCanBeCleared(mujoco.optional, 'int_array') + + def testKeyword(self): + mujoco = self._mujoco + + valid_values = ['Alpha', 'Beta', 'Gamma'] + for value in valid_values: + mujoco.optional.keyword = value.lower() + self.assertEqual(mujoco.optional.keyword, value) + self.assertXMLStringEqual(mujoco.optional, 'keyword', value) + + mujoco.optional.keyword = value.upper() + self.assertEqual(mujoco.optional.keyword, value) + self.assertXMLStringEqual(mujoco.optional, 'keyword', value) + + with self.assertRaisesRegex(ValueError, str(valid_values)): + mujoco.optional.keyword = 'delta' + # failed assignment should not change the value + self.assertXMLStringEqual(mujoco.optional, 'keyword', valid_values[-1]) + self.assertCanBeCleared(mujoco.optional, 'keyword') + + def testKeywordFalseTrueAuto(self): + mujoco = self._mujoco + for value in ('false', 'False', False): + mujoco.optional.fta = value + self.assertEqual(mujoco.optional.fta, 'false') + self.assertXMLStringEqual(mujoco.optional, 'fta', 'false') + for value in ('true', 'True', True): + mujoco.optional.fta = value + self.assertEqual(mujoco.optional.fta, 'true') + self.assertXMLStringEqual(mujoco.optional, 'fta', 'true') + for value in ('auto', 'AUTO'): + mujoco.optional.fta = value + self.assertEqual(mujoco.optional.fta, 'auto') + self.assertXMLStringEqual(mujoco.optional, 'fta', 'auto') + for value in (None, ''): + mujoco.optional.fta = value + self.assertIsNone(mujoco.optional.fta) + self.assertXMLStringEqual(mujoco.optional, 'fta', None) + + def testIdentifier(self): + mujoco = self._mujoco + + entity = mujoco.worldentity.add('entity') + subentity_1 = entity.add('subentity', name='foo') + subentity_2 = entity.add('subentity_alias', name='bar') + + self.assertIsNone(entity.name) + self.assertElementIsIdentifiedByName(subentity_1, 'foo') + self.assertElementIsIdentifiedByName(subentity_2, 'bar') + self.assertXMLStringIsCorrectlyScoped(subentity_1, 'name', 'foo') + self.assertXMLStringIsCorrectlyScoped(subentity_2, 'name', 'bar') + + with self.assertRaisesRegex(ValueError, 'Expect a string value'): + subentity_2.name = subentity_1 + with self.assertRaisesRegex(ValueError, 'reserved for scoping'): + subentity_2.name = 'foo/bar' + with self.assertRaisesRegex(ValueError, 'Duplicated identifier'): + subentity_2.name = 'foo' + # failed assignment should not change the value + self.assertElementIsIdentifiedByName(subentity_2, 'bar') + + with self.assertRaisesRegex(ValueError, 'cannot be named \'world\''): + mujoco.worldentity.add('body', name='world') + + subentity_1.name = 'baz' + self.assertElementIsIdentifiedByName(subentity_1, 'baz') + self.assertIsNone(mujoco.find('subentity', 'foo')) + + # 'foo' is now unused, so we should be allowed to use it + subentity_2.name = 'foo' + self.assertElementIsIdentifiedByName(subentity_2, 'foo') + + # duplicate name should be allowed when in different namespaces + entity.name = 'foo' + self.assertElementIsIdentifiedByName(entity, 'foo') + self.assertCanBeCleared(entity, 'name') + + def testStringReference(self): + mujoco = self._mujoco + mujoco.optional.reference = 'foo' + self.assertEqual(mujoco.optional.reference, 'foo') + self.assertXMLStringIsCorrectlyScoped(mujoco.optional, 'reference', 'foo') + self.assertCanBeCleared(mujoco.optional, 'reference') + + def testElementReferenceWithFixedNamespace(self): + mujoco = self._mujoco + # `mujoco.optional.fixed_type_ref` must be an element in the 'optional' + # namespace. 'identified' elements are part of the 'optional' namespace. + bar = mujoco.add('identified', identifier='bar') + mujoco.optional.fixed_type_ref = bar + self.assertXMLStringIsCorrectlyScoped( + mujoco.optional, 'fixed_type_ref', 'bar') + # Removing the referenced entity should cause the `fixed_type_ref` to be set + # to None. + bar.remove() + self.assertIsNone(mujoco.optional.fixed_type_ref) + + def testElementReferenceWithVariableNamespace(self): + mujoco = self._mujoco + + # `mujoco.optional.reference` can be an element in either the 'entity' or + # or 'optional' namespaces. First we assign an 'identified' element to the + # reference attribute. These are part of the 'optional' namespace. + bar = mujoco.add('identified', identifier='bar') + mujoco.optional.reftype = 'optional' + mujoco.optional.reference = bar + self.assertXMLStringIsCorrectlyScoped(mujoco.optional, 'reference', 'bar') + + # Assigning to `mujoco.optional.reference` should also change the value of + # `mujoco.optional.reftype` to match the namespace of the element that was + # assigned to `mujoco.optional.reference` + self.assertXMLStringEqual(mujoco.optional, 'reftype', 'optional') + + # Now assign an 'entity' element to the reference attribute. These are part + # of the 'entity' namespace. + baz = mujoco.worldentity.add('entity', name='baz') + mujoco.optional.reftype = 'entity' + mujoco.optional.reference = baz + self.assertXMLStringIsCorrectlyScoped(mujoco.optional, 'reference', 'baz') + # The `reftype` should change to 'entity' accordingly. + self.assertXMLStringEqual(mujoco.optional, 'reftype', 'entity') + + # Removing the referenced entity should cause the `reference` and `reftype` + # to be set to None. + baz.remove() + self.assertIsNone(mujoco.optional.reference) + self.assertIsNone(mujoco.optional.reftype) + + def testInvalidReference(self): + mujoco = self._mujoco + bar = mujoco.worldentity.add('entity', name='bar') + baz = bar.add('subentity', name='baz') + mujoco.optional.reftype = 'entity' + with self.assertRaisesWithLiteralMatch( + ValueError, attribute._INVALID_REFERENCE_TYPE.format( + valid_type='entity', actual_type='subentity')): + mujoco.optional.reference = baz + with self.assertRaisesWithLiteralMatch( + ValueError, attribute._INVALID_REFERENCE_TYPE.format( + valid_type='optional', actual_type='subentity')): + mujoco.optional.fixed_type_ref = baz + + def testDefaults(self): + mujoco = self._mujoco + + # Unnamed global defaults class should become a properly named and scoped + # class with a trailing slash + self.assertIsNone(mujoco.default.dclass) + self.assertCorrectXMLStringForDefaultsClass(mujoco.default, 'class', '') + + # An element without an explicit dclass should be assigned to the properly + # scoped global defaults class + entity = mujoco.worldentity.add('entity') + subentity = entity.add('subentity') + self.assertIsNone(subentity.dclass) + self.assertCorrectXMLStringForDefaultsClass(subentity, 'class', '') + + # Named global defaults class should gain scoping prefix + mujoco.default.dclass = 'main' + self.assertEqual(mujoco.default.dclass, 'main') + self.assertCorrectXMLStringForDefaultsClass(mujoco.default, 'class', 'main') + self.assertCorrectXMLStringForDefaultsClass(subentity, 'class', 'main') + + # Named subordinate defaults class should gain scoping prefix + sub_default = mujoco.default.add('default', dclass='sub') + self.assertEqual(sub_default.dclass, 'sub') + self.assertCorrectXMLStringForDefaultsClass(sub_default, 'class', 'sub') + + # An element without an explicit dclass but belongs to a childclassed + # parent should be left alone + entity.childclass = 'sub' + self.assertEqual(entity.childclass, 'sub') + self.assertCorrectXMLStringForDefaultsClass(entity, 'childclass', 'sub') + self.assertXMLStringIsNone(subentity, 'class') + + # An element WITH an explicit dclass should be left alone have it properly + # scoped regardless of whether it belongs to a childclassed parent or not. + subentity.dclass = 'main' + self.assertCorrectXMLStringForDefaultsClass(subentity, 'class', 'main') + + @parameterized.named_parameters( + ('NoBasepath', '', os.path.join(ASSETS_DIR, FAKE_SCHEMA_FILENAME)), + ('WithBasepath', ASSETS_DIR, FAKE_SCHEMA_FILENAME)) + def testFileFromPath(self, basepath, value): + mujoco = self._mujoco + full_path = os.path.join(basepath, value) + with open(full_path, 'rb') as f: + contents = f.read() + _, basename = os.path.split(value) + prefix, extension = os.path.splitext(basename) + expected_xml = prefix + '-' + hashlib.sha1(contents).hexdigest() + extension + mujoco.files.text_path = basepath + text_file = mujoco.files.add('text', file=value) + expected_value = attribute.Asset( + contents=contents, extension=extension, prefix=prefix) + self.assertEqual(text_file.file, expected_value) + self.assertXMLStringEqual(text_file, 'file', expected_xml) + self.assertCanBeCleared(text_file, 'file') + self.assertCanBeCleared(mujoco.files, 'text_path') + + def testFileFromPlaceholder(self): + mujoco = self._mujoco + contents = b'Fake contents' + extension = '.whatever' + expected_xml = hashlib.sha1(contents).hexdigest() + extension + placeholder = attribute.Asset(contents=contents, extension=extension) + text_file = mujoco.files.add('text', file=placeholder) + self.assertEqual(text_file.file, placeholder) + self.assertXMLStringEqual(text_file, 'file', expected_xml) + self.assertCanBeCleared(text_file, 'file') + + def testFileFromAssetsDict(self): + prefix = 'fake_filename' + extension = '.whatever' + path = 'invalid/path/' + prefix + extension + contents = 'Fake contents' + assets = {path: contents} + mujoco = element.RootElement(assets=assets) + text_file = mujoco.files.add('text', file=path) + expected_value = attribute.Asset( + contents=contents, extension=extension, prefix=prefix) + self.assertEqual(text_file.file, expected_value) + + def testFileExceptions(self): + mujoco = self._mujoco + text_file = mujoco.files.add('text') + with self.assertRaisesRegex(ValueError, + 'Expect either a string or `Asset` value'): + text_file.file = mujoco.optional + + def testBasePathExceptions(self): + mujoco = self._mujoco + with self.assertRaisesRegex(ValueError, 'Expect a string value'): + mujoco.files.text_path = mujoco.optional + + def testRequiredAttributes(self): + mujoco = self._mujoco + attributes = ( + ('float', 1.0), ('int', 2), ('string', 'foobar'), + ('float_array', [1.5, 2.5, 3.5]), ('int_array', [4, 5]), + ('keyword', 'alpha'), ('identifier', 'thing'), + ('reference', 'other_thing'), ('basepath', ASSETS_DIR), + ('file', FAKE_SCHEMA_FILENAME) + ) + + # Removing any one of the required attributes should cause initialization + # of a new element to fail + for name, _ in attributes: + attributes_dict = {key: value for key, value in attributes if key != name} + with self.assertRaisesRegex(AttributeError, name + '.+ is required'): + mujoco.add('required', **attributes_dict) + + attributes_dict = {key: value for key, value in attributes} + mujoco.add('required', **attributes_dict) + # Should not be allowed to clear each required attribute after the fact + for name, _ in attributes: + self.assertCanNotBeCleared(mujoco.required, name) + + +if __name__ == '__main__': + absltest.main() diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/base.py b/rofunc/utils/robolab/formatter/mjcf_parser/base.py new file mode 100644 index 000000000..37f90adef --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/base.py @@ -0,0 +1,279 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Base class for all MJCF elements in the object model.""" + +import abc + +from rofunc.utils.robolab.formatter.mjcf_parser import constants + + +class Element(metaclass=abc.ABCMeta): + """Abstract base class for an MJCF element. + + This class is provided so that `isinstance(foo, Element)` is `True` for all + Element-like objects. We do not implement the actual element here because + the actual object returned from traversing the object hierarchy is a + weakproxy-like proxy to an actual element. This is because we do not allow + orphaned non-root elements, so when a particular element is removed from the + tree, all references held automatically become invalid. + """ + __slots__ = [] + + @abc.abstractmethod + def get_init_stack(self): + """Gets the stack trace where this element was first initialized.""" + + @abc.abstractmethod + def get_last_modified_stacks_for_all_attributes(self): + """Gets a dict of stack traces where each attribute was last modified.""" + + @abc.abstractmethod + def is_same_as(self, other): + """Checks whether another element is semantically equivalent to this one. + + Two elements are considered equivalent if they have the same + specification (i.e. same tag appearing in the same context), the same + attribute values, and all of their children are equivalent. The ordering + of non-repeated children is not important for this comparison, while + the ordering of repeated children are important only amongst the same + type* of children. In other words, for two bodies to be considered + equivalent, their child sites must appear in the same order, and their + child geoms must appear in the same order, but permutations between sites + and geoms are disregarded. (The only exception is in tendon definition, + where strict ordering of all children is necessary for equivalence.) + + *Note that the notion of "same type" in this function is very loose: + for example different actuator element subtypes are treated as separate + types when children ordering is considered. Therefore, two + elements might be considered equivalent even though they result in different + orderings of `mjData.ctrl` when compiled. As it stands, this function + is designed primarily as a testing aid and should not be used to guarantee + that models are actually identical. + + Args: + other: An `mjcf.Element` + + Returns: + `True` if `other` element is semantically equivalent to this one. + """ + + @property + @abc.abstractmethod + def tag(self): + pass + + @property + @abc.abstractmethod + def spec(self): + pass + + @property + @abc.abstractmethod + def parent(self): + pass + + @property + @abc.abstractmethod + def namescope(self): + pass + + @property + @abc.abstractmethod + def root(self): + pass + + @abc.abstractmethod + def prefixed_identifier(self, prefix_root): + pass + + @property + @abc.abstractmethod + def full_identifier(self): + """Fully-qualified identifier used for this element in the generated XML.""" + + @abc.abstractmethod + def find(self, namespace, identifier): + """Finds an element with a particular identifier. + + This function allows the direct access to an arbitrarily deeply nested + child element by name, without the need to manually traverse through the + object tree. The `namespace` argument specifies the kind of element to + find. In most cases, this corresponds to the element's XML tag name. + However, if an element has multiple specialized tags, then the namespace + corresponds to the tag name of the most general element of that kind. + For example, `namespace='joint'` would search for `` and + ``, while `namespace='actuator'` would search for ``, + ``, ``, ``, and ``. + + Args: + namespace: A string specifying the namespace being searched. See the + docstring above for explanation. + identifier: The identifier string of the desired element. + + Returns: + An `mjcf.Element` object, or `None` if an element with the specified + identifier is not found. + + Raises: + ValueError: if either `namespace` or `identifier` is not a string, or if + `namespace` is not a valid namespace. + """ + + @abc.abstractmethod + def find_all(self, namespace, + immediate_children_only=False, exclude_attachments=False): + """Finds all elements of a particular kind. + + The `namespace` argument specifies the kind of element to + find. In most cases, this corresponds to the element's XML tag name. + However, if an element has multiple specialized tags, then the namespace + corresponds to the tag name of the most general element of that kind. + For example, `namespace='joint'` would search for `` and + ``, while `namespace='actuator'` would search for ``, + ``, ``, ``, and ``. + + Args: + namespace: A string specifying the namespace being searched. See the + docstring above for explanation. + immediate_children_only: (optional) A boolean, if `True` then only + the immediate children of this element are returned. + exclude_attachments: (optional) A boolean, if `True` then elements + belonging to attached models are excluded. + + Returns: + A list of `mjcf.Element`. + + Raises: + ValueError: if `namespace` is not a valid namespace. + """ + + @abc.abstractmethod + def enter_scope(self, scope_identifier): + """Finds the root element of the given scope and returns it. + + This function allows the access to a nested scope that is a child of this + element. The `scope_identifier` argument specifies the path to the child + scope element. + + Args: + scope_identifier: The path of the desired scope element. + + Returns: + An `mjcf.Element` object, or `None` if a scope element with the + specified path is not found. + """ + + @abc.abstractmethod + def get_attribute_xml_string(self, attribute_name, prefix_root=None): + pass + + @abc.abstractmethod + def get_attributes(self): + pass + + @abc.abstractmethod + def set_attributes(self, **kwargs): + pass + + @abc.abstractmethod + def get_children(self, element_name): + pass + + @abc.abstractmethod + def add(self, element_name, **kwargs): + """Add a new child element to this element. + + Args: + element_name: The tag of the element to add. + **kwargs: Attributes of the new element being created. + + Raises: + ValueError: If the 'element_name' is not a valid child, or if an invalid + attribute is specified in `kwargs`. + + Returns: + An `mjcf.Element` corresponding to the newly created child element. + """ + + @abc.abstractmethod + def remove(self, affect_attachments=False): + """Removes this element from the model.""" + + @property + @abc.abstractmethod + def is_removed(self): + pass + + @abc.abstractmethod + def all_children(self): + pass + + @abc.abstractmethod + def to_xml(self, prefix_root=None, debug_context=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0): + """Generates an etree._Element corresponding to this MJCF element. + + Args: + prefix_root: (optional) A `NameScope` object to be treated as root + for the purpose of calculating the prefix. + If `None` then no prefix is included. + debug_context: (optional) A `debugging.DebugContext` object to which + the debugging information associated with the generated XML is written. + This is intended for internal use within PyMJCF; users should never need + manually pass this argument. + precision: (optional) Number of digits to output for floating point + quantities. + zero_threshold: (optional) When outputting XML, floating point quantities + whose absolute value falls below this threshold will be treated as zero. + + Returns: + An etree._Element object. + """ + + @abc.abstractmethod + def to_xml_string(self, prefix_root=None, + self_only=False, pretty_print=True, debug_context=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0): + """Generates an XML string corresponding to this MJCF element. + + Args: + prefix_root: (optional) A `NameScope` object to be treated as root + for the purpose of calculating the prefix. + If `None` then no prefix is included. + self_only: (optional) A boolean, whether to generate an XML corresponding + only to this element without any children. + pretty_print: (optional) A boolean, whether to the XML string should be + properly indented. + debug_context: (optional) A `debugging.DebugContext` object to which + the debugging information associated with the generated XML is written. + This is intended for internal use within PyMJCF; users should never need + manually pass this argument. + precision: (optional) Number of digits to output for floating point + quantities. + zero_threshold: (optional) When outputting XML, floating point quantities + whose absolute value falls below this threshold will be treated as zero. + + Returns: + A string. + """ + + @abc.abstractmethod + def resolve_references(self): + pass diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/code_for_debugging_test.py b/rofunc/utils/robolab/formatter/mjcf_parser/code_for_debugging_test.py new file mode 100644 index 000000000..22f54cdc8 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/code_for_debugging_test.py @@ -0,0 +1,82 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Constructs models for debugging_test.py. + +The purpose of this file is to provide "golden" source line numbers for test +cases in debugging_test.py. When this module is loaded, it inspects its own +source code to look for lines that begin with `# !!LINE_REF`, and stores the +following line number in a dict. This allows test cases to look up the line +number by name, rather than brittly hard-coding in the line number. +""" + +import collections +import os +import re + +from rofunc.utils.robolab.formatter import mjcf_parser as mjcf + +SourceLine = collections.namedtuple( + 'SourceLine', ('line_number', 'text')) + +LINE_REF = {} + + +def make_valid_model(): + # !!LINE_REF make_valid_model.mjcf_model + mjcf_model = mjcf.RootElement() + # !!LINE_REF make_valid_model.my_body + my_body = mjcf_model.worldbody.add('body', name='my_body') + my_body.add('inertial', mass=1, pos=[0, 0, 0], diaginertia=[1, 1, 1]) + # !!LINE_REF make_valid_model.my_joint + my_joint = my_body.add('joint', name='my_joint', type='hinge') + # !!LINE_REF make_valid_model.my_actuator + mjcf_model.actuator.add('velocity', name='my_actuator', joint=my_joint) + return mjcf_model + + +def make_broken_model(): + # !!LINE_REF make_broken_model.mjcf_model + mjcf_model = mjcf.RootElement() + # !!LINE_REF make_broken_model.my_body + my_body = mjcf_model.worldbody.add('body', name='my_body') + my_body.add('inertial', mass=1, pos=[0, 0, 0], diaginertia=[1, 1, 1]) + # !!LINE_REF make_broken_model.my_joint + my_body.add('joint', name='my_joint', type='hinge') + # !!LINE_REF make_broken_model.my_actuator + mjcf_model.actuator.add('velocity', name='my_actuator', joint='invalid_joint') + return mjcf_model + + +def break_valid_model(mjcf_model): + # !!LINE_REF break_valid_model.my_actuator.joint + mjcf_model.find('actuator', 'my_actuator').joint = 'invalid_joint' + return mjcf_model + + +def _parse_line_refs(): + line_ref_pattern = re.compile(r'\s*# !!LINE_REF\s*([^\s]+)') + filename, _ = os.path.splitext(__file__) # __file__ can be `.pyc`. + with open(filename + '.py') as f: + src = f.read() + src_lines = src.split('\n') + for line_number, line in enumerate(src_lines): + match = line_ref_pattern.match(line) + if match: + LINE_REF[match.group(1)] = SourceLine( + line_number + 2, src_lines[line_number + 1].strip()) + + +_parse_line_refs() diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/constants.py b/rofunc/utils/robolab/formatter/mjcf_parser/constants.py new file mode 100644 index 000000000..ed24b0da5 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/constants.py @@ -0,0 +1,83 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Magic constants used within `dm_control.mjcf`.""" + +PREFIX_SEPARATOR = '/' +PREFIX_SEPARATOR_ESCAPE = '\\' + +# Used to disambiguate namespaces between attachment frames. +NAMESPACE_SEPARATOR = '@' + +# Magic attribute names +BASEPATH = 'basepath' +CHILDCLASS = 'childclass' +CLASS = 'class' +DEFAULT = 'default' +DCLASS = 'dclass' + +# Magic tags +ACTUATOR = 'actuator' +BODY = 'body' +DEFAULT = 'default' +MESH = 'mesh' +SITE = 'site' +SKIN = 'skin' +TENDON = 'tendon' +WORLDBODY = 'worldbody' + +# Path namespaces. +MESHDIR_NAMESPACE = 'mesh' +TEXTUREDIR_NAMESPACE = 'texture' +ASSETDIR_NAMESPACE = 'asset' + +MJDATA_TRIGGERS_DIRTY = [ + 'qpos', 'qvel', 'act', 'ctrl', 'qfrc_applied', 'xfrc_applied'] +MJMODEL_DOESNT_TRIGGER_DIRTY = [ + 'rgba', 'matid', 'emission', 'specular', 'shininess', 'reflectance', + 'needstage', +] + +# When writing into `model.{body,geom,site}_{pos,quat}` we must ensure that the +# corresponding rows in `model.{body,geom,site}_sameframe` are set to zero, +# otherwise MuJoCo will use the body or inertial frame instead of our modified +# pos/quat values. We must do the same for `body_{ipos,iquat}` and +# `body_simple`. +MJMODEL_DISABLE_ON_WRITE = { + # Field name in MjModel: (attribute names of Binding instance to be zeroed) + 'body_pos': ('sameframe',), + 'body_quat': ('sameframe',), + 'geom_pos': ('sameframe',), + 'geom_quat': ('sameframe',), + 'site_pos': ('sameframe',), + 'site_quat': ('sameframe',), + 'body_ipos': ('simple', 'sameframe'), + 'body_iquat': ('simple', 'sameframe'), +} + +MAX_VFS_FILENAME_LENGTH = 998 + +# The prefix used in the schema to denote reference_namespace that are defined +# via another attribute. +INDIRECT_REFERENCE_NAMESPACE_PREFIX = 'attrib:' + +INDIRECT_REFERENCE_ATTRIB = { + 'xbody': 'body', +} + +# 17 decimal digits is sufficient to represent a double float without loss +# of precision. +# https://en.wikipedia.org/wiki/IEEE_754#Character_representation +XML_DEFAULT_PRECISION = 17 diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/copier.py b/rofunc/utils/robolab/formatter/mjcf_parser/copier.py new file mode 100644 index 000000000..b8ca2501a --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/copier.py @@ -0,0 +1,71 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Helper object for keeping track of new elements created when copying MJCF.""" + +from rofunc.utils.robolab.formatter.mjcf_parser import constants + + +class Copier: + """Helper for keeping track of new elements created when copying MJCF.""" + + def __init__(self, source): + if source._attachments: # pylint: disable=protected-access + raise NotImplementedError('Cannot copy from elements with attachments') + self._source = source + + def copy_into(self, destination, override_attributes=False): + """Copies this copier's element into a destination MJCF element.""" + newly_created_elements = {} + destination._check_valid_attachment(self._source) # pylint: disable=protected-access + if override_attributes: + destination.set_attributes(**self._source.get_attributes()) + else: + destination._sync_attributes(self._source, copying=True) # pylint: disable=protected-access + for source_child in self._source.all_children(): + dest_child = None + # First, if source_child has an identifier, we look for an existing child + # element of self with the same identifier to override. + if source_child.spec.identifier and override_attributes: + identifier_attr = source_child.spec.identifier + if identifier_attr == constants.CLASS: + identifier_attr = constants.DCLASS + identifier = getattr(source_child, identifier_attr) + if identifier: + dest_child = destination.find(source_child.spec.namespace, identifier) + if dest_child is not None and dest_child.parent is not destination: + raise ValueError( + '<{}> with identifier {!r} is already a child of another element' + .format(source_child.spec.namespace, identifier)) + # Next, we cover the case where either the child is not a repeated element + # or if source_child has an identifier attribute but it isn't set. + if not source_child.spec.repeated and dest_child is None: + dest_child = destination.get_children(source_child.tag) + + # Add a new element if dest_child doesn't exist, either because it is + # supposed to be a repeated child, or because it's an uncreated on-demand. + if dest_child is None: + dest_child = destination.add( + source_child.tag, **source_child.get_attributes()) + newly_created_elements[source_child] = dest_child + override_child_attributes = True + else: + override_child_attributes = override_attributes + + # Finally, copy attributes into dest_child. + child_copier = Copier(source_child) + newly_created_elements.update( + child_copier.copy_into(dest_child, override_child_attributes)) + return newly_created_elements diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/copier_test.py b/rofunc/utils/robolab/formatter/mjcf_parser/copier_test.py new file mode 100644 index 000000000..15c03fc4f --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/copier_test.py @@ -0,0 +1,85 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for `dm_control.mjcf.copier`.""" + +import os + +import numpy as np +from absl.testing import absltest +from rofunc.utils.robolab.formatter import mjcf_parser as mjcf + +from rofunc.utils.robolab.formatter.mjcf_parser import parser + +_ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'test_assets') +_TEST_MODEL_XML = os.path.join(_ASSETS_DIR, 'test_model.xml') +_MODEL_WITH_ASSETS_XML = os.path.join(_ASSETS_DIR, 'model_with_assets.xml') + + +class CopierTest(absltest.TestCase): + + def testSimpleCopy(self): + mjcf_model = parser.from_path(_TEST_MODEL_XML) + mixin = mjcf.RootElement(model='test_mixin') + mixin.compiler.boundmass = 1 + mjcf_model.include_copy(mixin) + self.assertEqual(mjcf_model.model, 'test') # Model name should not change + self.assertEqual(mjcf_model.compiler.boundmass, mixin.compiler.boundmass) + mixin.compiler.boundinertia = 2 + mjcf_model.include_copy(mixin) + self.assertEqual(mjcf_model.compiler.boundinertia, + mixin.compiler.boundinertia) + mixin.compiler.boundinertia = 1 + with self.assertRaisesRegex(ValueError, 'Conflicting values'): + mjcf_model.include_copy(mixin) + mixin.worldbody.add('body', name='b_0', pos=[0, 1, 2]) + mjcf_model.include_copy(mixin, override_attributes=True) + self.assertEqual(mjcf_model.compiler.boundmass, mixin.compiler.boundmass) + self.assertEqual(mjcf_model.compiler.boundinertia, + mixin.compiler.boundinertia) + np.testing.assert_array_equal(mjcf_model.worldbody.body['b_0'].pos, + [0, 1, 2]) + + def testCopyingWithReference(self): + sensor_mixin = mjcf.RootElement('sensor_mixin') + touch_site = sensor_mixin.worldbody.add('site', name='touch_site') + sensor_mixin.sensor.add('touch', name='touch_sensor', site=touch_site) + + mjcf_model = mjcf.RootElement('model') + mjcf_model.include_copy(sensor_mixin) + + # Copied reference should be updated to the copied site. + self.assertIs(mjcf_model.find('sensor', 'touch_sensor').site, + mjcf_model.find('site', 'touch_site')) + + def testCopyingWithAssets(self): + mjcf_model = parser.from_path(_MODEL_WITH_ASSETS_XML) + copied = mjcf.RootElement() + copied.include_copy(mjcf_model) + + original_assets = (mjcf_model.find_all('mesh') + + mjcf_model.find_all('texture') + + mjcf_model.find_all('hfield')) + copied_assets = (copied.find_all('mesh') + + copied.find_all('texture') + + copied.find_all('hfield')) + + self.assertLen(copied_assets, len(original_assets)) + for original_asset, copied_asset in zip(original_assets, copied_assets): + self.assertIs(copied_asset.file, original_asset.file) + + +if __name__ == '__main__': + absltest.main() diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/debugging.py b/rofunc/utils/robolab/formatter/mjcf_parser/debugging.py new file mode 100644 index 000000000..a63f42a70 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/debugging.py @@ -0,0 +1,368 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Implements PyMJCF debug mode. + +PyMJCF debug mode stores a stack trace each time the MJCF object is modified. +If Mujoco raises a compile error on the generated XML model, we would then be +able to find the original source line that created the offending element. +""" + +import collections +import contextlib +import copy +import os +import re +import sys +import traceback + +from absl import flags +from lxml import etree + +FLAGS = flags.FLAGS +flags.DEFINE_boolean( + 'mjcf_parser_debug', False, + 'Enables PyMJCF debug mode (SLOW!). In this mode, a stack trace is logged ' + 'each the MJCF object is modified. This may be helpful in locating the ' + 'Python source line corresponding to a problematic element in the ' + 'generated XML.') +flags.DEFINE_string( + 'mjcf_parser_debug_full_dump_dir', '', + 'Path to dump full debug info when Mujoco error is encountered.') + +StackTraceEntry = collections.namedtuple( + 'StackTraceEntry', ('filename', 'line_number', 'function_name', 'text')) + +ElementDebugInfo = collections.namedtuple( + 'ElementDebugInfo', ('element', 'init_stack', 'attribute_stacks')) + +MODULE_PATH = os.path.dirname(sys.modules[__name__].__file__) +DEBUG_METADATA_PREFIX = 'pymjcfdebug' + +_DEBUG_METADATA_TAG_PREFIX = ''.format(DEBUG_METADATA_PREFIX)) + +# Modified by `freeze_current_stack_trace`. +_CURRENT_FROZEN_STACK = None + +# These globals will take their default values from the `--mjcf_parser_debug` and +# `--pymjcf_debug_full_dump_dir` flags respectively. We cannot use `FLAGS` as +# global variables because flag parsing might not have taken place (e.g. when +# running `nosetests`). +_DEBUG_MODE_ENABLED = None +_DEBUG_FULL_DUMP_DIR = None + + +def debug_mode(): + """Returns a boolean that indicates whether PyMJCF debug mode is enabled.""" + global _DEBUG_MODE_ENABLED + if _DEBUG_MODE_ENABLED is None: + if FLAGS.is_parsed(): + _DEBUG_MODE_ENABLED = FLAGS.mjcf_parser_debug + else: + _DEBUG_MODE_ENABLED = FLAGS['mjcf_parser_debug'].default + return _DEBUG_MODE_ENABLED + + +def enable_debug_mode(): + """Enables PyMJCF debug mode.""" + global _DEBUG_MODE_ENABLED + _DEBUG_MODE_ENABLED = True + + +def disable_debug_mode(): + """Disables PyMJCF debug mode.""" + global _DEBUG_MODE_ENABLED + _DEBUG_MODE_ENABLED = False + + +def get_full_dump_dir(): + """Gets the directory to dump full debug info files.""" + global _DEBUG_FULL_DUMP_DIR + if _DEBUG_FULL_DUMP_DIR is None: + if FLAGS.is_parsed(): + _DEBUG_FULL_DUMP_DIR = FLAGS.pymjcf_debug_full_dump_dir + else: + _DEBUG_FULL_DUMP_DIR = FLAGS['pymjcf_debug_full_dump_dir'].default + return _DEBUG_FULL_DUMP_DIR + + +def set_full_dump_dir(dump_path): + """Sets the directory to dump full debug info files.""" + global _DEBUG_FULL_DUMP_DIR + _DEBUG_FULL_DUMP_DIR = dump_path + + +def get_current_stack_trace(): + """Returns the stack trace of the current execution frame. + + Returns: + A list of `StackTraceEntry` named tuples corresponding to the current stack + trace of the process, truncated to immediately before entry into + PyMJCF internal code. + """ + if _CURRENT_FROZEN_STACK: + return copy.deepcopy(_CURRENT_FROZEN_STACK) + else: + return _get_actual_current_stack_trace() + + +def _get_actual_current_stack_trace(): + """Returns the stack trace of the current execution frame. + + Returns: + A list of `StackTraceEntry` named tuples corresponding to the current stack + trace of the process, truncated to immediately before entry into + PyMJCF internal code. + """ + raw_stack = traceback.extract_stack() + processed_stack = [] + for raw_stack_item in raw_stack: + stack_item = StackTraceEntry(*raw_stack_item) + if (stack_item.filename.startswith(MODULE_PATH) + and not stack_item.filename.endswith('_test.py')): + break + else: + processed_stack.append(stack_item) + return processed_stack + + +@contextlib.contextmanager +def freeze_current_stack_trace(): + """A context manager that freezes the stack trace. + + AVOID USING THIS CONTEXT MANAGER OUTSIDE OF INTERNAL PYMJCF IMPLEMENTATION, + AS IT REDUCES THE USEFULNESS OF DEBUG MODE. + + If PyMJCF debug mode is enabled, calls to `debugging.get_current_stack_trace` + within this context will always return the stack trace from when this context + was entered. + + The frozen stack is global to this debugging module. That is, if the context + is entered while another one is still active, then the stack trace of the + outermost one is returned. + + This context significantly speeds up bulk operations in debug mode, e.g. + parsing an existing XML string or creating a deeply-nested element, as it + prevents the same stack trace from being repeatedly constructed. + + Yields: + `None` + """ + global _CURRENT_FROZEN_STACK + if debug_mode() and _CURRENT_FROZEN_STACK is None: + _CURRENT_FROZEN_STACK = _get_actual_current_stack_trace() + yield + _CURRENT_FROZEN_STACK = None + else: + yield + + +class DebugContext: + """A helper object to store debug information for a generated XML string. + + This class is intended for internal use within the PyMJCF implementation. + """ + + def __init__(self): + self._xml_string = None + self._debug_info_for_element_ids = {} + + def register_element_for_debugging(self, elem): + """Registers an `Element` and returns debugging metadata for the XML. + + Args: + elem: An `mjcf.Element`. + + Returns: + An `lxml.etree.Comment` that represents debugging metadata in the + generated XML. + """ + if not debug_mode(): + return None + else: + self._debug_info_for_element_ids[id(elem)] = ElementDebugInfo( + elem, + copy.deepcopy(elem.get_init_stack()), + copy.deepcopy(elem.get_last_modified_stacks_for_all_attributes())) + return etree.Comment('{}:{}'.format(DEBUG_METADATA_PREFIX, id(elem))) + + def commit_xml_string(self, xml_string): + """Commits the XML string associated with this debug context. + + This function also formats the XML string to make sure that the debugging + metadata appears on the same line as the corresponding XML element. + + Args: + xml_string: A pretty-printed XML string. + + Returns: + A reformatted XML string where all debugging metadata appears on the same + line as the corresponding XML element. + """ + formatted = re.sub(r'\n\s*' + _DEBUG_METADATA_TAG_PREFIX, + _DEBUG_METADATA_TAG_PREFIX, xml_string) + self._xml_string = formatted + return formatted + + def process_and_raise_last_exception(self): + """Processes and re-raises the last ValueError caught. + + This function will insert the relevant line from the source XML to the error + message. If debug mode is enabled, additional debugging information is + appended to the error message. If debug mode is not enabled, the error + message instructs the user to enable it by rerunning the executable with an + appropriate flag. + """ + err_type, err, tb = sys.exc_info() + line_number_match = re.search(r'[Ll][Ii][Nn][Ee]\s*[:=]?\s*(\d+)', str(err)) + if line_number_match: + xml_line_number = int(line_number_match.group(1)) + xml_line = self._xml_string.split('\n')[xml_line_number - 1] + stripped_xml_line = xml_line.strip() + comment_match = re.search(_DEBUG_METADATA_TAG_PREFIX, stripped_xml_line) + if comment_match: + stripped_xml_line = stripped_xml_line[:comment_match.start()] + else: + xml_line = '' + stripped_xml_line = '' + + message_lines = [] + if debug_mode(): + if get_full_dump_dir(): + self.dump_full_debug_info_to_disk() + message_lines.extend([ + 'Compile error raised by Mujoco.', + str(err)]) + if xml_line: + message_lines.extend([ + stripped_xml_line, + self._generate_debug_message_from_xml_line(xml_line)]) + else: + message_lines.extend([ + 'Compile error raised by Mujoco; ' + + 'run again with --mjcf_parser_debug for additional debug information.', + str(err) + ]) + if xml_line: + message_lines.append(stripped_xml_line) + + raise err_type('\n'.join(message_lines)).with_traceback(tb) + + @property + def default_dump_dir(self): + return get_full_dump_dir() + + @property + def debug_mode(self): + return debug_mode() + + def dump_full_debug_info_to_disk(self, dump_dir=None): + """Dumps full debug information to disk. + + Full debug information consists of an XML file whose elements are tagged + with a unique ID, and a stack trace file for each element ID. Each stack + trace file consists of a stack trace for when the element was created, and + when each attribute was last modified. + + Args: + dump_dir: Full path to the directory in which dump files are created. + + Raises: + ValueError: If neither `dump_dir` nor the global dump path is given. The + global dump path can be specified either via the + --pymjcf_debug_full_dump_dir flag or via `debugging.set_full_dump_dir`. + """ + dump_dir = dump_dir or self.default_dump_dir + if not dump_dir: + raise ValueError('`dump_dir` is not specified') + section_separator = '\n' + ('=' * 80) + '\n' + + def dump_stack(header, stack, f): + indent = ' ' + f.write(header + '\n') + for stack_entry in stack: + f.write(indent + '`{}` at {}:{}\n' + .format(stack_entry.function_name, + stack_entry.filename, stack_entry.line_number)) + f.write((indent * 2) + str(stack_entry.text) + '\n') + f.write(section_separator) + + with open(os.path.join(dump_dir, 'model.xml'), 'w') as f: + f.write(self._xml_string) + for elem_id, debug_info in self._debug_info_for_element_ids.items(): + with open(os.path.join(dump_dir, str(elem_id) + '.dump'), 'w') as f: + f.write('{}:{}\n'.format(DEBUG_METADATA_PREFIX, elem_id)) + f.write(str(debug_info.element) + '\n') + dump_stack('Element creation', debug_info.init_stack, f) + for attrib_name, stack in debug_info.attribute_stacks.items(): + attrib_value = ( + debug_info.element.get_attribute_xml_string(attrib_name)) + if stack[-1] == debug_info.init_stack[-1]: + if attrib_value is not None: + f.write('Attribute {}="{}"\n'.format(attrib_name, attrib_value)) + f.write(' was set when the element was created\n') + f.write(section_separator) + else: + if attrib_value is not None: + dump_stack('Attribute {}="{}"'.format(attrib_name, attrib_value), + stack, f) + else: + dump_stack( + 'Attribute {} was CLEARED'.format(attrib_name), stack, f) + + def _generate_debug_message_from_xml_line(self, xml_line): + """Generates a debug message by parsing the metadata on an XML line.""" + metadata_match = _DEBUG_METADATA_SEARCH_PATTERN.search(xml_line) + if metadata_match: + elem_id = int(metadata_match.group(1)) + return self._generate_debug_message_from_element_id(elem_id) + else: + return '' + + def _generate_debug_message_from_element_id(self, elem_id): + """Generates a debug message for the specified Element.""" + out = [] + debug_info = self._debug_info_for_element_ids[elem_id] + + out.append('Debug summary for element:') + if not get_full_dump_dir(): + out.append( + ' * Full debug info can be dumped to disk by setting the ' + 'flag --pymjcf_debug_full_dump_dir=path/to/dump>') + out.append(' * Element object was created by `{}` at {}:{}' + .format(debug_info.init_stack[-1].function_name, + debug_info.init_stack[-1].filename, + debug_info.init_stack[-1].line_number)) + + for attrib_name, stack in debug_info.attribute_stacks.items(): + attrib_value = debug_info.element.get_attribute_xml_string(attrib_name) + if stack[-1] == debug_info.init_stack[-1]: + if attrib_value is not None: + out.append(' * {}="{}" was set when the element was created' + .format(attrib_name, attrib_value)) + else: + if attrib_value is not None: + out.append(' * {}="{}" was set by `{}` at `{}:{}`' + .format(attrib_name, attrib_value, + stack[-1].function_name, stack[-1].filename, + stack[-1].line_number)) + else: + out.append(' * {} was CLEARED by `{}` at {}:{}' + .format(attrib_name, stack[-1].function_name, + stack[-1].filename, stack[-1].line_number)) + + return '\n'.join(out) diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/debugging_test.py b/rofunc/utils/robolab/formatter/mjcf_parser/debugging_test.py new file mode 100644 index 000000000..4f3f9731b --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/debugging_test.py @@ -0,0 +1,177 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for `mjcf.debugging`.""" + +import contextlib +import os +import re +import shutil +import sys + +from absl.testing import absltest +from rofunc.utils.robolab.formatter import mjcf_parser as mjcf + +from rofunc.utils.robolab.formatter.mjcf_parser import code_for_debugging_test as test_code +from rofunc.utils.robolab.formatter.mjcf_parser import debugging + +ORIGINAL_DEBUG_MODE = debugging.debug_mode() + + +class DebuggingTest(absltest.TestCase): + + def tearDown(self): + super().tearDown() + if ORIGINAL_DEBUG_MODE: + debugging.enable_debug_mode() + else: + debugging.disable_debug_mode() + + def setup_debug_mode(self, debug_mode_enabled, full_dump_enabled=False): + if debug_mode_enabled: + debugging.enable_debug_mode() + else: + debugging.disable_debug_mode() + if full_dump_enabled: + base_dir = absltest.get_default_test_tmpdir() + self.dump_dir = os.path.join(base_dir, 'mjcf_debugging_test') + shutil.rmtree(self.dump_dir, ignore_errors=True) + os.mkdir(self.dump_dir) + else: + self.dump_dir = '' + debugging.set_full_dump_dir(self.dump_dir) + + def assertStackFromTestCode(self, stack, function_name, line_ref): + self.assertEqual(stack[-1].function_name, function_name) + self.assertStartsWith(test_code.__file__, stack[-1].filename) + line_info = test_code.LINE_REF['.'.join([function_name, line_ref])] + self.assertEqual(stack[-1].line_number, line_info.line_number) + self.assertEqual(stack[-1].text, line_info.text) + + @contextlib.contextmanager + def assertRaisesTestCodeRef(self, line_ref): + filename, _ = os.path.splitext(test_code.__file__) + expected_message = ( + filename + '.py:' + str(test_code.LINE_REF[line_ref].line_number)) + print(expected_message) + with self.assertRaisesRegex(ValueError, expected_message): + yield + + def test_get_current_stack_trace(self): + self.setup_debug_mode(debug_mode_enabled=True) + stack_trace = debugging.get_current_stack_trace() + self.assertStartsWith( + sys.modules[__name__].__file__, stack_trace[-1].filename) + self.assertEqual(stack_trace[-1].function_name, + 'test_get_current_stack_trace') + self.assertEqual(stack_trace[-1].text, + 'stack_trace = debugging.get_current_stack_trace()') + + def test_disable_debug_mode(self): + self.setup_debug_mode(debug_mode_enabled=False) + mjcf_model = test_code.make_valid_model() + test_code.break_valid_model(mjcf_model) + self.assertFalse(mjcf_model.get_init_stack()) + + my_actuator = mjcf_model.find('actuator', 'my_actuator') + my_actuator_attrib_stacks = ( + my_actuator.get_last_modified_stacks_for_all_attributes()) + for stack in my_actuator_attrib_stacks.values(): + self.assertFalse(stack) + + def test_element_and_attribute_stacks(self): + self.setup_debug_mode(debug_mode_enabled=True) + mjcf_model = test_code.make_valid_model() + test_code.break_valid_model(mjcf_model) + + self.assertStackFromTestCode(mjcf_model.get_init_stack(), + 'make_valid_model', 'mjcf_model') + + my_actuator = mjcf_model.find('actuator', 'my_actuator') + self.assertStackFromTestCode(my_actuator.get_init_stack(), + 'make_valid_model', 'my_actuator') + + my_actuator_attrib_stacks = ( + my_actuator.get_last_modified_stacks_for_all_attributes()) + # `name` attribute was assigned at the same time as the element was created. + self.assertEqual(my_actuator_attrib_stacks['name'], + my_actuator.get_init_stack()) + # `joint` attribute was modified later on. + self.assertStackFromTestCode(my_actuator_attrib_stacks['joint'], + 'break_valid_model', 'my_actuator.joint') + + def test_valid_physics(self): + self.setup_debug_mode(debug_mode_enabled=True) + mjcf_model = test_code.make_valid_model() + mjcf.Physics.from_mjcf_model(mjcf_model) # Should not raise + + def test_physics_error_message_outside_of_debug_mode(self): + self.setup_debug_mode(debug_mode_enabled=False) + mjcf_model = test_code.make_broken_model() + # Make sure that we advertise debug mode if it's currently disabled. + with self.assertRaisesRegex(ValueError, '--mjcf_parser_debug'): + mjcf.Physics.from_mjcf_model(mjcf_model) + + def test_physics_error_message_in_debug_mode(self): + self.setup_debug_mode(debug_mode_enabled=True) + mjcf_model_1 = test_code.make_broken_model() + with self.assertRaisesTestCodeRef('make_broken_model.my_actuator'): + mjcf.Physics.from_mjcf_model(mjcf_model_1) + mjcf_model_2 = test_code.make_valid_model() + physics = mjcf.Physics.from_mjcf_model(mjcf_model_2) # Should not raise. + test_code.break_valid_model(mjcf_model_2) + with self.assertRaisesTestCodeRef('break_valid_model.my_actuator.joint'): + physics.reload_from_mjcf_model(mjcf_model_2) + + def test_full_debug_dump(self): + self.setup_debug_mode(debug_mode_enabled=True, full_dump_enabled=False) + mjcf_model = test_code.make_valid_model() + test_code.break_valid_model(mjcf_model) + # Make sure that we advertise full dump mode if it's currently disabled. + with self.assertRaisesRegex(ValueError, '--pymjcf_debug_full_dump_dir'): + mjcf.Physics.from_mjcf_model(mjcf_model) + self.setup_debug_mode(debug_mode_enabled=True, full_dump_enabled=True) + with self.assertRaises(ValueError): + mjcf.Physics.from_mjcf_model(mjcf_model) + + with open(os.path.join(self.dump_dir, 'model.xml')) as f: + dumped_xml = f.read() + dumped_xml = [line.strip() for line in dumped_xml.strip().split('\n')] + + xml_line_pattern = re.compile(r'^(.*)$') + uninstrumented_pattern = re.compile(r'({})'.format( + '|'.join([ + r'', + r'', + r'', + r'' + ]))) + + for xml_line in dumped_xml: + print(xml_line) + xml_line_match = xml_line_pattern.match(xml_line) + if not xml_line_match: + # Only uninstrumented lines are allowed to have no metadata. + self.assertIsNotNone(uninstrumented_pattern.match(xml_line)) + else: + xml_element = xml_line_match.group(1) + debug_id = int(xml_line_match.group(2)) + with open(os.path.join(self.dump_dir, str(debug_id) + '.dump')) as f: + element_dump = f.read() + self.assertIn(xml_element, element_dump) + + +if __name__ == '__main__': + absltest.main() diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/element.py b/rofunc/utils/robolab/formatter/mjcf_parser/element.py new file mode 100644 index 000000000..477d914b9 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/element.py @@ -0,0 +1,1414 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Classes to represent MJCF elements in the object model.""" + +import collections +import copy +import os +import sys + +import numpy as np +from lxml import etree + +from rofunc.utils.robolab.formatter.mjcf_parser import attribute as attribute_types +from rofunc.utils.robolab.formatter.mjcf_parser import base +from rofunc.utils.robolab.formatter.mjcf_parser import constants +from rofunc.utils.robolab.formatter.mjcf_parser import copier +from rofunc.utils.robolab.formatter.mjcf_parser import debugging +from rofunc.utils.robolab.formatter.mjcf_parser import namescope +from rofunc.utils.robolab.formatter.mjcf_parser import schema +from rofunc.utils.robolab.formatter.mjcf_parser import util + +_raw_property = property # pylint: disable=invalid-name + +_UNITS = ('K', 'M', 'G', 'T', 'P', 'E') + + +def _to_bytes(value_str): + """Converts a `str` value representing a size in bytes to `int`. + + Args: + value_str: `str` value to be converted. + + Returns: + `int` corresponding size in bytes. + + Raises: + ValueError: if the `str` value passed has an unsupported unit. + """ + if value_str.isdigit(): + value_int = int(value_str) + else: + value_int = int(value_str[:-1]) + unit = value_str[-1].upper() + if unit not in _UNITS: + raise ValueError( + f'unit of `size.memory` should be one of [{", ".join(_UNITS)}], got' + f' {unit}') + power = 10 * (_UNITS.index(unit) + 1) + value_int *= 2 ** power + return value_int + + +def _max_bytes(first, second): + return str(max(_to_bytes(first), _to_bytes(second))) + + +_CONFLICT_BEHAVIOR_FUNC = {'min': min, 'max': max, 'max_bytes': _max_bytes} + + +def property(method): # pylint: disable=redefined-builtin + """Modifies `@property` to keep track of any `AttributeError` raised. + + Our `Element` implementations overrides the `__getattr__` method. This does + not interact well with `@property`: if a `property`'s code is buggy so as to + raise an `AttributeError`, then Python would silently discard it and redirect + to our `__getattr__` instead, leading to an uninformative stack trace. This + makes it very difficult to debug issues that involve properties. + + To remedy this, we modify `@property` within this module to store any + `AttributeError` raised within the respective `Element` object. Then, in our + `__getattr__` logic, we could re-raise it to preserve the original stack + trace. + + The reason that this is not implemented as a different decorator is that we + could accidentally use @property on a new method. This would work fine until + someone triggers a subtle bug. This is when a proper trace would be most + useful, but we would still end up with a strange undebuggable stack trace + anyway. + + Note that at the end of this module, we have a `del property` to prevent this + override from being broadcasted externally. + + Args: + method: The method that is being decorated. + + Returns: + A `property` corresponding to the decorated method. + """ + + def _mjcf_property(self): + try: + return method(self) + except: + _, err, tb = sys.exc_info() + err_with_next_tb = err.with_traceback(tb.tb_next) + if isinstance(err, AttributeError): + self._last_attribute_error = err_with_next_tb # pylint: disable=protected-access + raise err_with_next_tb # pylint: disable=raise-missing-from + + return _raw_property(_mjcf_property) + + +def _make_element(spec, parent, attributes=None): + """Helper function to generate the right kind of Element given a spec.""" + if (spec.name == constants.WORLDBODY + or (spec.name == constants.SITE + and (parent.tag == constants.BODY + or parent.tag == constants.WORLDBODY))): + return _AttachableElement(spec, parent, attributes) + elif isinstance(parent, _AttachmentFrame): + return _AttachmentFrameChild(spec, parent, attributes) + elif spec.name == constants.DEFAULT: + return _DefaultElement(spec, parent, attributes) + elif spec.name == constants.ACTUATOR: + return _ActuatorElement(spec, parent, attributes) + else: + return _ElementImpl(spec, parent, attributes) + + +_DEFAULT_NAME_FROM_FILENAME = frozenset(['mesh', 'hfield', 'texture']) + + +class _ElementImpl(base.Element): + """Actual implementation of a generic MJCF element object.""" + __slots__ = ['__weakref__', '_spec', '_parent', '_attributes', '_children', + '_own_attributes', '_attachments', '_is_removed', '_init_stack', + '_is_worldbody', '_cached_namescope', '_cached_root', + '_cached_full_identifier', '_cached_revision', + '_last_attribute_error'] + + def __init__(self, spec, parent, attributes=None): + attributes = attributes or {} + + # For certain `asset` elements the `name` attribute can be omitted, in which + # case the name will be the filename without the leading path and extension. + # See http://www.mujoco.org/book/XMLreference.html#asset. + if ('name' not in attributes + and 'file' in attributes + and spec.name in _DEFAULT_NAME_FROM_FILENAME): + _, filename = os.path.split(attributes['file']) + basename, _ = os.path.splitext(filename) + attributes['name'] = basename + + self._spec = spec + self._parent = parent + self._attributes = collections.OrderedDict() + self._own_attributes = None + self._children = [] + self._attachments = collections.OrderedDict() + self._is_removed = False + self._is_worldbody = (self.tag == 'worldbody') + + if self._parent: + self._cached_namescope = self._parent.namescope + self._cached_root = self._parent.root + self._cached_full_identifier = '' + self._cached_revision = -1 + + self._last_attribute_error = None + + if debugging.debug_mode(): + self._init_stack = debugging.get_current_stack_trace() + + with debugging.freeze_current_stack_trace(): + for child_spec in self._spec.children.values(): + if not (child_spec.repeated or child_spec.on_demand): + self._children.append(_make_element(spec=child_spec, parent=self)) + + if constants.DCLASS in attributes: + attributes[constants.CLASS] = attributes[constants.DCLASS] + del attributes[constants.DCLASS] + + for attribute_name in attributes.keys(): + self._check_valid_attribute(attribute_name) + + for attribute_spec in self._spec.attributes.values(): + value = None + # Some Reference attributes refer to a namespace that is specified + # via another attribute. We therefore have to set things up for + # the additional indirection. + if attribute_spec.type is attribute_types.Reference: + reference_namespace = ( + attribute_spec.other_kwargs['reference_namespace']) + if reference_namespace.startswith( + constants.INDIRECT_REFERENCE_NAMESPACE_PREFIX): + attribute_spec = copy.deepcopy(attribute_spec) + namespace_attrib_name = reference_namespace[ + len(constants.INDIRECT_REFERENCE_NAMESPACE_PREFIX):] + attribute_spec.other_kwargs['reference_namespace'] = ( + self._attributes[namespace_attrib_name]) + if attribute_spec.name in attributes: + value = attributes[attribute_spec.name] + try: + self._attributes[attribute_spec.name] = attribute_spec.type( + name=attribute_spec.name, + required=attribute_spec.required, + conflict_allowed=attribute_spec.conflict_allowed, + conflict_behavior=attribute_spec.conflict_behavior, + parent=self, value=value, **attribute_spec.other_kwargs) + except: + # On failure, clear attributes already created + for attribute_obj in self._attributes.values(): + attribute_obj._force_clear() # pylint: disable=protected-access + # Then raise a meaningful error + err_type, err, tb = sys.exc_info() + raise err_type( # pylint: disable=raise-missing-from + f'during initialization of attribute {attribute_spec.name!r} of ' + f'element <{self._spec.name}>: {err}').with_traceback(tb) + + def get_init_stack(self): + """Gets the stack trace where this element was first initialized.""" + if debugging.debug_mode(): + return self._init_stack + + def get_last_modified_stacks_for_all_attributes(self): + """Gets a dict of stack traces where each attribute was last modified.""" + return collections.OrderedDict( + [(name, self._attributes[name].last_modified_stack) + for name in self._spec.attributes]) + + def is_same_as(self, other): + """Checks whether another element is semantically equivalent to this one. + + Two elements are considered equivalent if they have the same + specification (i.e. same tag appearing in the same context), the same + attribute values, and all of their children are equivalent. The ordering + of non-repeated children is not important for this comparison, while + the ordering of repeated children are important only amongst the same + type* of children. In other words, for two bodies to be considered + equivalent, their child sites must appear in the same order, and their + child geoms must appear in the same order, but permutations between sites + and geoms are disregarded. (The only exception is in tendon definition, + where strict ordering of all children is necessary for equivalence.) + + *Note that the notion of "same type" in this function is very loose: + for example different actuator element subtypes are treated as separate + types when children ordering is considered. Therefore, two + elements might be considered equivalent even though they result in different + orderings of `mjData.ctrl` when compiled. As it stands, this function + is designed primarily as a testing aid and should not be used to guarantee + that models are actually identical. + + Args: + other: An `mjcf.Element` + + Returns: + `True` if `other` element is semantically equivalent to this one. + """ + if other is None or other.spec != self._spec: + return False + + for attribute_name in self._spec.attributes.keys(): + attribute = self._attributes[attribute_name] + other_attribute = getattr(other, attribute_name) + if isinstance(attribute.value, base.Element): + if attribute.value.full_identifier != other_attribute.full_identifier: + return False + elif not np.all(attribute.value == other_attribute): + return False + + if (self._parent and + self._parent.tag == constants.TENDON and + self._parent.parent == self.root): + return self._tendon_has_same_children_as(other) + else: + return self._has_same_children_as(other) + + def _has_same_children_as(self, other): + """Helper function to check whether another element has the same children. + + See docstring for `is_same_as` for explanation about the treatment of + children ordering. + + Args: + other: An `mjcf.Element` + + Returns: + A boolean + """ + for child_name, child_spec in self._spec.children.items(): + child = self.get_children(child_name) + other_child = getattr(other, child_name) + if not child_spec.repeated: + if ((child is None and other_child is not None) or + (child is not None and not child.is_same_as(other_child))): + return False + else: + if len(child) != len(other_child): + return False + else: + for grandchild, other_grandchild in zip(child, other_child): + if not grandchild.is_same_as(other_grandchild): + return False + return True + + def _tendon_has_same_children_as(self, other): + return all(child.is_same_as(other_child) + for child, other_child + in zip(self.all_children(), other.all_children())) + + def _alias_attributes_dict(self, other): + if self._own_attributes is None: + self._own_attributes = self._attributes + self._attributes = other + + def _restore_attributes_dict(self): + if self._own_attributes is not None: + for attribute_name, attribute in self._attributes.items(): + self._own_attributes[attribute_name].value = attribute.value + self._attributes = self._own_attributes + self._own_attributes = None + + @property + def tag(self): + return self._spec.name + + @property + def spec(self): + return self._spec + + @property + def parent(self): + return self._parent + + @property + def namescope(self): + return self._cached_namescope + + @property + def root(self): + return self._cached_root + + def prefixed_identifier(self, prefix_root): + if not self._spec.identifier and not self._is_worldbody: + return None + elif self._is_worldbody: + prefix = self.namescope.full_prefix(prefix_root=prefix_root) + return prefix or 'world' + else: + full_identifier = ( + self._attributes[self._spec.identifier].to_xml_string( + prefix_root=prefix_root)) + if full_identifier: + return full_identifier + else: + prefix = self.namescope.full_prefix(prefix_root=prefix_root) + prefix = prefix or constants.PREFIX_SEPARATOR + return prefix + self._default_identifier + + @property + def full_identifier(self): + """Fully-qualified identifier used for this element in the generated XML.""" + if self.namescope.revision > self._cached_revision: + self._cached_full_identifier = self.prefixed_identifier( + prefix_root=self.namescope.root) + self._cached_revision = self.namescope.revision + return self._cached_full_identifier + + @property + def _default_identifier(self): + """The default identifier used if this element is not named by the user.""" + if not self._spec.identifier: + return None + else: + siblings = self.root.find_all(self._spec.namespace, + exclude_attachments=True) + return '{separator}unnamed_{namespace}_{index}'.format( + separator=constants.PREFIX_SEPARATOR, + namespace=self._spec.namespace, + index=siblings.index(self)) + + def __dir__(self): + out_dir = set() + classes = (type(self),) + while classes: + super_classes = set() + for klass in classes: + out_dir.update(klass.__dict__) + super_classes.update(klass.__bases__) + classes = super_classes + out_dir.update(self._spec.children) + out_dir.update(self._spec.attributes) + if constants.CLASS in out_dir: + out_dir.remove(constants.CLASS) + out_dir.add(constants.DCLASS) + return sorted(out_dir) + + def find(self, namespace, identifier): + """Finds an element with a particular identifier. + + This function allows the direct access to an arbitrarily deeply nested + child element by name, without the need to manually traverse through the + object tree. The `namespace` argument specifies the kind of element to + find. In most cases, this corresponds to the element's XML tag name. + However, if an element has multiple specialized tags, then the namespace + corresponds to the tag name of the most general element of that kind. + For example, `namespace='joint'` would search for `` and + ``, while `namespace='actuator'` would search for ``, + ``, ``, ``, and ``. + + Args: + namespace: A string specifying the namespace being searched. See the + docstring above for explanation. + identifier: The identifier string of the desired element. + + Returns: + An `mjcf.Element` object, or `None` if an element with the specified + identifier is not found. + + Raises: + ValueError: if either `namespace` or `identifier` is not a string, or if + `namespace` is not a valid namespace. + """ + if not isinstance(namespace, str): + raise ValueError( + '`namespace` should be a string: got {!r}'.format(namespace)) + if not isinstance(identifier, str): + raise ValueError( + '`identifier` should be a string: got {!r}'.format(identifier)) + if namespace not in schema.FINDABLE_NAMESPACES: + raise ValueError('{!r} is not a valid namespace. Available: {}.'.format( + namespace, schema.FINDABLE_NAMESPACES)) + if constants.PREFIX_SEPARATOR in identifier: + scope_name = identifier.split(constants.PREFIX_SEPARATOR)[0] + try: + attachment = self.namescope.get('attached_model', scope_name) + found_element = attachment.find( + namespace, identifier[(len(scope_name) + 1):]) + except (KeyError, ValueError): + found_element = None + else: + try: + found_element = self.namescope.get(namespace, identifier) + except KeyError: + found_element = None + if found_element and self._parent: + next_parent = found_element.parent + while next_parent and next_parent != self: + next_parent = next_parent.parent + if not next_parent: + found_element = None + return found_element + + def find_all(self, namespace, + immediate_children_only=False, exclude_attachments=False): + """Finds all elements of a particular kind. + + The `namespace` argument specifies the kind of element to + find. In most cases, this corresponds to the element's XML tag name. + However, if an element has multiple specialized tags, then the namespace + corresponds to the tag name of the most general element of that kind. + For example, `namespace='joint'` would search for `` and + ``, while `namespace='actuator'` would search for ``, + ``, ``, ``, and ``. + + Args: + namespace: A string specifying the namespace being searched. See the + docstring above for explanation. + immediate_children_only: (optional) A boolean, if `True` then only + the immediate children of this element are returned. + exclude_attachments: (optional) A boolean, if `True` then elements + belonging to attached models are excluded. + + Returns: + A list of `mjcf.Element`. + + Raises: + ValueError: if `namespace` is not a valid namespace. + """ + if namespace not in schema.FINDABLE_NAMESPACES: + raise ValueError('{!r} is not a valid namespace. Available: {}'.format( + namespace, schema.FINDABLE_NAMESPACES)) + out = [] + children = self._children if exclude_attachments else self.all_children() + for child in children: + if (namespace == child.spec.namespace or + # Direct children of attachment frames have custom namespaces of the + # form "joint@attachment_frame_". + child.spec.namespace and child.spec.namespace.startswith( + namespace + constants.NAMESPACE_SEPARATOR) or + # Attachment frames are considered part of the "body" namespace. + namespace == constants.BODY and isinstance(child, _AttachmentFrame)): + out.append(child) + if not immediate_children_only: + out.extend(child.find_all(namespace, + exclude_attachments=exclude_attachments)) + return out + + def enter_scope(self, scope_identifier): + """Finds the root element of the given scope and returns it. + + This function allows the access to a nested scope that is a child of this + element. The `scope_identifier` argument specifies the path to the child + scope element. + + Args: + scope_identifier: The path of the desired scope element. + + Returns: + An `mjcf.Element` object, or `None` if a scope element with the + specified path is not found. + """ + if constants.PREFIX_SEPARATOR in scope_identifier: + scope_name = scope_identifier.split(constants.PREFIX_SEPARATOR)[0] + try: + attachment = self.namescope.get('attached_model', scope_name) + except KeyError: + return None + + scope_suffix = scope_identifier[(len(scope_name) + 1):] + if scope_suffix: + return attachment.enter_scope(scope_suffix) + else: + return attachment + else: + try: + return self.namescope.get('attached_model', scope_identifier) + except KeyError: + return None + + def _check_valid_attribute(self, attribute_name): + if attribute_name not in self._spec.attributes: + raise AttributeError( + '{!r} is not a valid attribute for <{}>'.format( + attribute_name, self._spec.name)) + + def _get_attribute(self, attribute_name): + self._check_valid_attribute(attribute_name) + return self._attributes[attribute_name].value + + def get_attribute_xml_string(self, + attribute_name, + prefix_root=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0): + self._check_valid_attribute(attribute_name) + return self._attributes[attribute_name].to_xml_string( + prefix_root, precision=precision, zero_threshold=zero_threshold) + + def get_attributes(self): + fix_attribute_name = ( + lambda name: constants.DCLASS if name == constants.CLASS else name) + return collections.OrderedDict( + [(fix_attribute_name(name), self._get_attribute(name)) + for name in self._spec.attributes.keys() + if self._get_attribute(name) is not None]) + + def _set_attribute(self, attribute_name, value): + self._check_valid_attribute(attribute_name) + self._attributes[attribute_name].value = value + self.namescope.increment_revision() + + def set_attributes(self, **kwargs): + if constants.DCLASS in kwargs: + kwargs[constants.CLASS] = kwargs[constants.DCLASS] + del kwargs[constants.DCLASS] + old_values = [] + with debugging.freeze_current_stack_trace(): + for attribute_name, new_value in kwargs.items(): + old_value = self._get_attribute(attribute_name) + try: + self._set_attribute(attribute_name, new_value) + old_values.append((attribute_name, old_value)) + except: + # On failure, restore old attribute values for those already set. + for name, old_value in old_values: + self._set_attribute(name, old_value) + # Then raise a meaningful error. + err_type, err, tb = sys.exc_info() + raise err_type( # pylint: disable=raise-missing-from + f'during assignment to attribute {attribute_name!r} of ' + f'element <{self._spec.name}>: {err}').with_traceback(tb) + + def _remove_attribute(self, attribute_name): + self._check_valid_attribute(attribute_name) + self._attributes[attribute_name].clear() + self.namescope.increment_revision() + + def _check_valid_child(self, element_name): + try: + return self._spec.children[element_name] + except KeyError: + raise AttributeError( # pylint: disable=raise-missing-from + '<{}> is not a valid child of <{}>' + .format(element_name, self._spec.name)) + + def get_children(self, element_name): + child_spec = self._check_valid_child(element_name) + if child_spec.repeated: + return _ElementListView(spec=child_spec, parent=self) + else: + for child in self._children: + if child.tag == element_name: + return child + if child_spec.on_demand: + return None + else: + raise RuntimeError( + 'Cannot find the non-repeated child <{}> of <{}>. ' + 'This should never happen, as we pre-create these in __init__. ' + 'Please file an bug report. Thank you.' + .format(element_name, self._spec.name)) + + def add(self, element_name, **kwargs): + """Add a new child element to this element. + + Args: + element_name: The tag of the element to add. + **kwargs: Attributes of the new element being created. + + Raises: + ValueError: If the 'element_name' is not a valid child, or if an invalid + attribute is specified in `kwargs`. + + Returns: + An `mjcf.Element` corresponding to the newly created child element. + """ + return self.insert(element_name, position=None, **kwargs) + + def insert(self, element_name, position, **kwargs): + """Add a new child element to this element. + + Args: + element_name: The tag of the element to add. + position: Where to insert the new element. + **kwargs: Attributes of the new element being created. + + Raises: + ValueError: If the 'element_name' is not a valid child, or if an invalid + attribute is specified in `kwargs`. + + Returns: + An `mjcf.Element` corresponding to the newly created child element. + """ + child_spec = self._check_valid_child(element_name) + if child_spec.on_demand: + need_new_on_demand = self.get_children(element_name) is None + else: + need_new_on_demand = False + if not (child_spec.repeated or need_new_on_demand): + raise ValueError('A <{}> child already exists, please access it directly.' + .format(element_name)) + new_element = _make_element(child_spec, self, attributes=kwargs) + if position is not None: + self._children.insert(position, new_element) + else: + self._children.append(new_element) + self.namescope.increment_revision() + return new_element + + def __getattr__(self, name): + if self._last_attribute_error: + # This means that we got here through a @property raising AttributeError. + # We therefore just re-raise the last AttributeError back to the user. + # Note that self._last_attribute_error was set by our specially + # instrumented @property decorator. + exc = self._last_attribute_error + self._last_attribute_error = None + raise exc # pylint: disable=raising-bad-type + elif name in self._spec.children: + return self.get_children(name) + elif name in self._spec.attributes: + return self._get_attribute(name) + elif name == constants.DCLASS and constants.CLASS in self._spec.attributes: + return self._get_attribute(constants.CLASS) + else: + raise AttributeError('object has no attribute: {}'.format(name)) + + def __setattr__(self, name, value): + # If this name corresponds to a descriptor for a slotted attribute or + # settable property then try to invoke the descriptor to set the attribute + # and return if successful. + klass_attr = getattr(type(self), name, None) + if klass_attr is not None: + try: + return klass_attr.__set__(self, value) + except AttributeError: + pass + # If we did not find a settable descriptor then we look in the attribute + # spec to see if there is a MuJoCo attribute matching this name. + attribute_name = name if name != constants.DCLASS else constants.CLASS + if attribute_name in self._spec.attributes: + self._set_attribute(attribute_name, value) + else: + raise AttributeError('can\'t set attribute: {}'.format(name)) + + def __delattr__(self, name): + if name in self._spec.children: + if self._spec.children[name].repeated: + raise AttributeError( + '`{0}` is a collection of child elements, ' + 'which cannot be deleted. Did you mean to call `{0}.clear()`?' + .format(name)) + else: + return self.get_children(name).remove() + elif name in self._spec.attributes: + return self._remove_attribute(name) + else: + raise AttributeError('object has no attribute: {}'.format(name)) + + def _check_attachments_on_remove(self, affect_attachments): + if not affect_attachments and self._attachments: + raise ValueError( + 'please use remove(affect_attachments=True) as this will affect some ' + 'attributes and/or children belonging to an attached model') + for child in self._children: + child._check_attachments_on_remove(affect_attachments) # pylint: disable=protected-access + + def remove(self, affect_attachments=False): + """Removes this element from the model.""" + self._check_attachments_on_remove(affect_attachments) + if affect_attachments: + for attachment in self._attachments.values(): + attachment.remove(affect_attachments=True) + for child in list(self._children): + child.remove(affect_attachments) + if self._spec.repeated or self._spec.on_demand: + self._parent._children.remove(self) # pylint: disable=protected-access + for attribute in self._attributes.values(): + attribute._force_clear() # pylint: disable=protected-access + self._parent = None + self._is_removed = True + else: + for attribute in self._attributes.values(): + attribute._force_clear() # pylint: disable=protected-access + self.namescope.increment_revision() + + @property + def is_removed(self): + return self._is_removed + + def all_children(self): + all_children = [child for child in self._children] + for attachment in self._attachments.values(): + all_children += [child for child in attachment.all_children() + if child.spec.repeated] + return all_children + + def to_xml(self, prefix_root=None, debug_context=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0): + """Generates an etree._Element corresponding to this MJCF element. + + Args: + prefix_root: (optional) A `NameScope` object to be treated as root + for the purpose of calculating the prefix. + If `None` then no prefix is included. + debug_context: (optional) A `debugging.DebugContext` object to which + the debugging information associated with the generated XML is written. + This is intended for internal use within PyMJCF; users should never need + manually pass this argument. + precision: (optional) Number of digits to output for floating point + quantities. + zero_threshold: (optional) When outputting XML, floating point quantities + whose absolute value falls below this threshold will be treated as zero. + + Returns: + An etree._Element object. + """ + prefix_root = prefix_root or self.namescope + xml_element = etree.Element(self._spec.name) + self._attributes_to_xml(xml_element, prefix_root, debug_context, + precision=precision, zero_threshold=zero_threshold) + self._children_to_xml(xml_element, prefix_root, debug_context, + precision=precision, zero_threshold=zero_threshold) + return xml_element + + def _attributes_to_xml(self, xml_element, prefix_root, debug_context=None, + *, precision, zero_threshold): + del debug_context # Unused. + for attribute_name, attribute in self._attributes.items(): + attribute_value = attribute.to_xml_string(prefix_root, + precision=precision, + zero_threshold=zero_threshold) + if attribute_name == self._spec.identifier and attribute_value is None: + xml_element.set(attribute_name, self.full_identifier) + elif attribute_value is None: + continue + else: + xml_element.set(attribute_name, attribute_value) + + def _children_to_xml(self, xml_element, prefix_root, debug_context=None, + *, precision, zero_threshold): + for child in self.all_children(): + child_xml = child.to_xml(prefix_root, debug_context, + precision=precision, + zero_threshold=zero_threshold) + if (child_xml.attrib or len(child_xml) # pylint: disable=g-explicit-length-test + or child.spec.repeated or child.spec.on_demand): + xml_element.append(child_xml) + if debugging.debug_mode() and debug_context: + debug_comment = debug_context.register_element_for_debugging(child) + xml_element.append(debug_comment) + if len(child_xml) > 0: # pylint: disable=g-explicit-length-test + child_xml.insert(0, copy.deepcopy(debug_comment)) + + def to_xml_string(self, prefix_root=None, + self_only=False, pretty_print=True, debug_context=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0): + """Generates an XML string corresponding to this MJCF element. + + Args: + prefix_root: (optional) A `NameScope` object to be treated as root + for the purpose of calculating the prefix. + If `None` then no prefix is included. + self_only: (optional) A boolean, whether to generate an XML corresponding + only to this element without any children. + pretty_print: (optional) A boolean, whether to the XML string should be + properly indented. + debug_context: (optional) A `debugging.DebugContext` object to which + the debugging information associated with the generated XML is written. + This is intended for internal use within PyMJCF; users should never need + manually pass this argument. + precision: (optional) Number of digits to output for floating point + quantities. + zero_threshold: (optional) When outputting XML, floating point quantities + whose absolute value falls below this threshold will be treated as zero. + + Returns: + A string. + """ + xml_element = self.to_xml(prefix_root, debug_context, + precision=precision, + zero_threshold=zero_threshold) + if self_only and len(xml_element) > 0: # pylint: disable=g-explicit-length-test + etree.strip_elements(xml_element, '*') + xml_element.text = '...' + if (self_only and self._spec.identifier and + not self._attributes[self._spec.identifier].to_xml_string( + prefix_root, precision=precision, zero_threshold=zero_threshold)): + del xml_element.attrib[self._spec.identifier] + xml_string = util.to_native_string( + etree.tostring(xml_element, pretty_print=pretty_print)) + if pretty_print and debug_context: + return debug_context.commit_xml_string(xml_string) + else: + return xml_string + + def __str__(self): + return self.to_xml_string(self_only=True, pretty_print=False) + + def __repr__(self): + return 'MJCF Element: ' + str(self) + + def _check_valid_attachment(self, other): + self_spec = self._spec + if self_spec.name == constants.WORLDBODY: + self_spec = self._spec.children[constants.BODY] + + other_spec = other.spec + if other_spec.name == constants.WORLDBODY: + other_spec = other_spec.children[constants.BODY] + + if other_spec != self_spec: + raise ValueError( + 'The attachment must have the same spec as this element.') + + def _attach(self, other, exclude_worldbody=False, dry_run=False): + """Attaches another element of the same spec to this element. + + All children of `other` will be treated as children of this element. + All XML attributes which are defined in `other` but not defined in this + element will be copied over, and any conflicting XML attribute value causes + an error. After the attachment, any XML attribute modified in this element + will also affect `other` and vice versa. + + Children of this element which are not a repeated elements will also be + attached by the corresponding children of `other`. + + Args: + other: Another Element with the same spec. + exclude_worldbody: (optional) A boolean. If `True`, then don't do anything + if `other` is a worldbody. + dry_run: (optional) A boolean, if `True` only verify that the operation + is valid without actually committing any change. + + Raises: + ValueError: If `other` has a different spec, or if there are conflicting + XML attribute values. + """ + self._check_valid_attachment(other) + if exclude_worldbody and other.tag == constants.WORLDBODY: + return + if dry_run: + self._check_conflicting_attributes(other, copying=False) + else: + self._attachments[other.namescope] = other + self._sync_attributes(other, copying=False) + self._attach_children(other, exclude_worldbody, dry_run) + if other.tag != constants.WORLDBODY and not dry_run: + other._alias_attributes_dict(self._attributes) # pylint: disable=protected-access + + def _detach(self, other_namescope): + """Detaches a model with the specified namescope.""" + attached_element = self._attachments.get(other_namescope) + if attached_element: + attached_element._restore_attributes_dict() # pylint: disable=protected-access + del self._attachments[other_namescope] + for child in self._children: + child._detach(other_namescope) # pylint: disable=protected-access + + def _check_conflicting_attributes(self, other, copying): + for attribute_name, other_attribute in other.get_attributes().items(): + if attribute_name == constants.DCLASS: + attribute_name = constants.CLASS + if ((not self._attributes[attribute_name].conflict_allowed) + and self._attributes[attribute_name].value is not None + and other_attribute is not None + and np.asarray( + self._attributes[attribute_name].value != other_attribute).any()): + raise ValueError( + 'Conflicting values for attribute `{}`: {} vs {}' + .format(attribute_name, + self._attributes[attribute_name].value, + other_attribute)) + + def _sync_attributes(self, other, copying): + self._check_conflicting_attributes(other, copying) + for attribute_name, other_attribute in other.get_attributes().items(): + if attribute_name == constants.DCLASS: + attribute_name = constants.CLASS + + self_attribute = self._attributes[attribute_name] + if other_attribute is not None: + if self_attribute.conflict_behavior in _CONFLICT_BEHAVIOR_FUNC: + if self_attribute.value is not None: + self_attribute.value = ( + _CONFLICT_BEHAVIOR_FUNC[self_attribute.conflict_behavior]( + self_attribute.value, other_attribute)) + else: + self_attribute.value = other_attribute + elif copying or not self_attribute.conflict_allowed: + self_attribute.value = other_attribute + + def _attach_children(self, other, exclude_worldbody, dry_run=False): + for other_child in other.all_children(): + if not other_child.spec.repeated: + self_child = self.get_children(other_child.spec.name) + self_child._attach(other_child, exclude_worldbody, dry_run) # pylint: disable=protected-access + + def resolve_references(self): + for attribute in self._attributes.values(): + if isinstance(attribute, attribute_types.Reference): + if attribute.value and isinstance(attribute.value, str): + referred = self.root.find( + attribute.reference_namespace, attribute.value) + if referred: + attribute.value = referred + for child in self.all_children(): + child.resolve_references() + + def _update_references(self, reference_dict): + for attribute in self._attributes.values(): + if isinstance(attribute, attribute_types.Reference): + if attribute.value in reference_dict: + attribute.value = reference_dict[attribute.value] + for child in self.all_children(): + child._update_references(reference_dict) # pylint: disable=protected-access + + +class _AttachableElement(_ElementImpl): + """Specialized object representing a or element. + + This element defines a frame to which another MJCF model can be attached. + """ + __slots__ = [] + + def attach(self, attachment): + """Attaches another MJCF model at this site. + + An empty will be created as an attachment frame. All children of + `attachment`'s will be treated as children of this frame. + Furthermore, all other elements in `attachment` are merged into the root + of the MJCF model to which this element belongs. + + Args: + attachment: An MJCF `RootElement` + + Returns: + An `mjcf.Element` corresponding to the attachment frame. A joint can be + added directly to this frame to give degrees of freedom to the attachment. + + Raises: + ValueError: If `other` is not a valid attachment to this element. + """ + if not isinstance(attachment, RootElement): + raise ValueError('Expected a mjcf.RootElement: got {}' + .format(attachment)) + if attachment.namescope.parent is not None: + raise ValueError('The model specified is already attached elsewhere') + if attachment.namescope == self.namescope: + raise ValueError('Cannot merge a model to itself') + self.root._attach(attachment, exclude_worldbody=True, dry_run=True) # pylint: disable=protected-access + + if self.namescope.has_identifier('namescope', attachment.model): + id_number = 1 + while self.namescope.has_identifier( + 'namescope', '{}_{}'.format(attachment.model, id_number)): + id_number += 1 + attachment.model = '{}_{}'.format(attachment.model, id_number) + attachment.namescope.parent = self.namescope + + if self.tag == constants.WORLDBODY: + frame_parent = self + frame_siblings = self._children + index = len(frame_siblings) + else: + frame_parent = self._parent + frame_siblings = self._parent._children # pylint: disable=protected-access + index = frame_siblings.index(self) + 1 + while (index < len(frame_siblings) + and isinstance(frame_siblings[index], _AttachmentFrame)): + index += 1 + frame = _AttachmentFrame(frame_parent, self, attachment) + frame_siblings.insert(index, frame) + self.root._attach(attachment, exclude_worldbody=True) # pylint: disable=protected-access + return frame + + +class _AttachmentFrame(_ElementImpl): + """An specialized representing a frame holding an external attachment. + """ + __slots__ = ['_site', '_attachment'] + + def __init__(self, parent, site, attachment): + if parent.tag == constants.WORLDBODY: + spec = schema.WORLD_ATTACHMENT_FRAME + else: + spec = schema.ATTACHMENT_FRAME + + spec_is_copied = False + for child_name, child_spec in spec.children.items(): + if child_spec.namespace: + if not spec_is_copied: + spec = copy.deepcopy(spec) + spec_is_copied = True + spec_as_dict = child_spec._asdict() + spec_as_dict['namespace'] = '{}{}attachment_frame_{}'.format( + child_spec.namespace, constants.NAMESPACE_SEPARATOR, id(self)) + spec.children[child_name] = type(child_spec)(**spec_as_dict) + + attributes = {} + with debugging.freeze_current_stack_trace(): + for attribute_name in spec.attributes.keys(): + if hasattr(site, attribute_name): + attributes[attribute_name] = getattr(site, attribute_name) + super().__init__(spec, parent, attributes) + self._site = site + self._attachment = attachment + self._attachments[attachment.namescope] = attachment.worldbody + self.namescope.add('attachment_frame', attachment.namescope.name, self) + self.namescope.add('attached_model', attachment.namescope.name, attachment) + + def prefixed_identifier(self, prefix_root=None): + prefix = self.namescope.full_prefix(prefix_root) + return prefix + self._attachment.namescope.name + constants.PREFIX_SEPARATOR + + def to_xml(self, prefix_root=None, debug_context=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0): + xml_element = (super().to_xml(prefix_root, debug_context, + precision=precision, + zero_threshold=zero_threshold)) + xml_element.set('name', self.prefixed_identifier(prefix_root)) + return xml_element + + @property + def full_identifier(self): + return self.prefixed_identifier(self.namescope.root) + + def _detach(self, other_namescope): + super()._detach(other_namescope) + if other_namescope is self._attachment.namescope: + self.namescope.remove('attachment_frame', self._attachment.namescope.name) + self.namescope.remove('attached_model', self._attachment.namescope.name) + self.remove() + + +class _AttachmentFrameChild(_ElementImpl): + """A child element of an attachment frame. + + Right now, this is always a or a . The name of the joint + is not freely specifiable, but instead just inherits from the parent frame. + This ensures uniqueness, as attachment frame identifiers always end in '/'. + """ + __slots__ = [] + + def to_xml(self, prefix_root=None, debug_context=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0): + xml_element = (super().to_xml(prefix_root, debug_context, + precision=precision, + zero_threshold=zero_threshold)) + if self.spec.namespace is not None: + if self.name: + name = (self._parent.prefixed_identifier(prefix_root) + + self.name + constants.PREFIX_SEPARATOR) + else: + name = self._parent.prefixed_identifier(prefix_root) + xml_element.set('name', name) + return xml_element + + def prefixed_identifier(self, prefix_root=None): + if self.name: + return (self._parent.prefixed_identifier(prefix_root) + + self.name + constants.PREFIX_SEPARATOR) + else: + return self._parent.prefixed_identifier(prefix_root) + + +class _DefaultElement(_ElementImpl): + """Specialized object representing a element. + + This is necessary for the proper handling of global defaults. + """ + __slots__ = [] + + def _attach(self, other, exclude_worldbody=False, dry_run=False): + self._check_valid_attachment(other) + if ((not isinstance(self._parent, RootElement)) + or (not isinstance(other.parent, RootElement))): + raise ValueError('Only global <{}> can be attached' + .format(constants.DEFAULT)) + if not dry_run: + self._attachments[other.namescope] = other + + def all_children(self): + return [child for child in self._children] + + def to_xml(self, prefix_root=None, debug_context=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0): + prefix_root = prefix_root or self.namescope + xml_element = (super().to_xml(prefix_root, debug_context, + precision=precision, + zero_threshold=zero_threshold)) + if isinstance(self._parent, RootElement): + root_default = etree.Element(self._spec.name) + root_default.append(xml_element) + for attachment in self._attachments.values(): + attachment_xml = attachment.to_xml(prefix_root, debug_context, + precision=precision, + zero_threshold=zero_threshold) + for attachment_child_xml in attachment_xml: + root_default.append(attachment_child_xml) + xml_element = root_default + return xml_element + + +class _ActuatorElement(_ElementImpl): + """Specialized object representing an element.""" + + __slots__ = () + + def _children_to_xml(self, xml_element, prefix_root, debug_context=None, + *, + precision=constants.XML_DEFAULT_PRECISION, + zero_threshold=0): + debug_comments = {} + for child in self.all_children(): + child_xml = child.to_xml(prefix_root, debug_context, + precision=precision, + zero_threshold=zero_threshold) + if debugging.debug_mode() and debug_context: + debug_comment = debug_context.register_element_for_debugging(child) + debug_comments[child_xml] = debug_comment + if len(child_xml) > 0: # pylint: disable=g-explicit-length-test + child_xml.insert(0, copy.deepcopy(debug_comment)) + xml_element.append(child_xml) + if debugging.debug_mode() and debug_context: + xml_element.append(debug_comments[child_xml]) + + +class RootElement(_ElementImpl): + """The root `` element of an MJCF model.""" + __slots__ = ['_namescope'] + + def __init__(self, model=None, model_dir='', assets=None): + model = model or 'unnamed_model' + self._namescope = namescope.NameScope( + model, self, model_dir=model_dir, assets=assets) + super().__init__( + spec=schema.MUJOCO, parent=None, attributes={'model': model}) + + def _attach(self, other, exclude_worldbody=False, dry_run=False): + self._check_valid_attachment(other) + if not dry_run: + self._attachments[other.namescope] = other + self._attach_children(other, exclude_worldbody, dry_run) + self.namescope.increment_revision() + + @property + def namescope(self): + return self._namescope + + @property + def root(self): + return self + + @property + def model(self): + return self._namescope.name + + @model.setter + def model(self, new_name): + old_name = self._namescope.name + self._namescope.name = new_name + self._attributes['model'].value = new_name + if self.parent_model: + self.parent_model.namescope.rename('attachment_frame', old_name, new_name) + self.parent_model.namescope.rename('attached_model', old_name, new_name) + + def attach(self, other): + return self.worldbody.attach(other) + + def detach(self): + parent_model = self.parent_model + if not parent_model: + raise RuntimeError( + 'Cannot `detach` a model that is not attached to some other model.') + else: + parent_model._detach(self.namescope) # pylint: disable=protected-access + self.namescope.parent = None + + def include_copy(self, other, override_attributes=False): + other_copier = copier.Copier(other) + new_elements = other_copier.copy_into(self, override_attributes) + self._update_references(new_elements) + self.namescope.increment_revision() + + @property + def parent_model(self): + """The RootElement of the MJCF model to which this one is attached.""" + namescope_parent = self._namescope.parent + return namescope_parent.mjcf_model if namescope_parent else None + + @property + def root_model(self): + return self.parent_model.root_model if self.parent_model else self + + def get_assets(self): + """Returns a dict containing the binary assets referenced in this model. + + This will contain `{vfs_filename: contents}` pairs. `vfs_filename` will be + the name of the asset in MuJoCo's Virtual File System, which corresponds to + the filename given in the XML returned by `to_xml_string()`. `contents` is a + bytestring. + + This dict can be used together with the result of `to_xml_string()` to + construct a `mujoco.Physics` instance: + + ```python + physics = mujoco.Physics.from_xml_string( + xml_string=mjcf_model.to_xml_string(), + assets=mjcf_model.get_assets()) + ``` + """ + # Get the assets referenced within this `RootElement`'s namescope. + assets = {file_obj.to_xml_string(): file_obj.get_contents() + for file_obj in self.namescope.files + if file_obj.value} + + # Recursively add assets belonging to attachments. + for attached_model in self._attachments.values(): + assets.update(attached_model.get_assets()) + + return assets + + def get_assets_map(self): + + # Get the assets referenced within this `RootElement`'s namescope. + assets = {file_obj._parent.name: file_obj.value.prefix + file_obj.value.extension + for file_obj in self.namescope.files + if file_obj.value} + return assets + + @property + def full_identifier(self): + return self._namescope.full_prefix(self._namescope.root) + + def __copy__(self): + new_model = RootElement(model=self._namescope.name, + model_dir=self.namescope.model_dir) + new_model.include_copy(self) + return new_model + + def __deepcopy__(self, _): + return self.__copy__() + + def is_same_as(self, other): + if other is None or other.spec != self._spec: + return False + return self._has_same_children_as(other) + + +class _ElementListView: + """A hybrid list/dict-like view to a group of repeated MJCF elements.""" + + def __init__(self, spec, parent): + self._spec = spec + self._parent = parent + self._elements = self._parent._children # pylint: disable=protected-access + self._scoped_elements = collections.OrderedDict( + [(scope_namescope.name, getattr(scoped_parent, self._spec.name)) + for scope_namescope, scoped_parent + in self._parent._attachments.items()]) + + @property + def spec(self): + return self._spec + + @property + def tag(self): + return self._spec.name + + @property + def namescope(self): + return self._parent.namescope + + @property + def parent(self): + return self._parent + + def __len__(self): + return len(self._full_list()) + + def __iter__(self): + return iter(self._full_list()) + + def _identifier_not_found_error(self, index): + return KeyError('An element <{}> with {}={!r} does not exist' + .format(self._spec.name, self._spec.identifier, index)) + + def _find_index(self, index): + """Locates an element given the index among siblings with the same tag.""" + if isinstance(index, str) and self._spec.identifier: + for i, element in enumerate(self._elements): + if (element.tag == self._spec.name + and getattr(element, self._spec.identifier) == index): + return i + raise self._identifier_not_found_error(index) + else: + count = 0 + for i, element in enumerate(self._elements): + if element.tag == self._spec.name: + if index == count: + return i + else: + count += 1 + raise IndexError('list index out of range') + + def _full_list(self): + out_list = [element for element in self._elements + if element.tag == self._spec.name] + for scoped_elements in self._scoped_elements.values(): + out_list += scoped_elements[:] + return out_list + + def clear(self): + for child in self._full_list(): + child.remove() + + def __getitem__(self, index): + if (isinstance(index, str) and self._spec.identifier + and constants.PREFIX_SEPARATOR in index): + scope_name = index.split(constants.PREFIX_SEPARATOR)[0] + scoped_elements = self._scoped_elements[scope_name] + try: + return scoped_elements[index[(len(scope_name) + 1):]] + except KeyError: + # Re-raise so that the error shows the full, un-stripped index string + raise self._identifier_not_found_error(index) # pylint: disable=raise-missing-from + elif isinstance(index, slice) or (isinstance(index, int) and index < 0): + return self._full_list()[index] + else: + return self._elements[self._find_index(index)] + + def __delitem__(self, index): + found_index = self._find_index(index) + self._elements[found_index].remove() + + def __str__(self): + return str( + [element.to_xml_string( + prefix_root=self.namescope, self_only=True, pretty_print=False) + for element in self._full_list()]) + + def __repr__(self): + return 'MJCF Elements List: ' + str(self) + + +# This restores @property back to Python's built-in one. +del property +del _raw_property diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/element_test.py b/rofunc/utils/robolab/formatter/mjcf_parser/element_test.py new file mode 100644 index 000000000..bad50c719 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/element_test.py @@ -0,0 +1,1073 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for `dm_control.mjcf.element`.""" + +import copy +import hashlib +import itertools +import os +import sys +import traceback + +import lxml +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized +from rofunc.utils.robolab.formatter import mjcf_parser as mjcf +from rofunc.utils.robolab.formatter.mjcf_parser import util + +from rofunc.utils.robolab.formatter.mjcf_parser import element +from rofunc.utils.robolab.formatter.mjcf_parser import namescope +from rofunc.utils.robolab.formatter.mjcf_parser import parser + +etree = lxml.etree + +_ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'test_assets') +_TEST_MODEL_XML = os.path.join(_ASSETS_DIR, 'test_model.xml') +_TEXTURE_PATH = os.path.join(_ASSETS_DIR, 'textures/deepmind.png') +_MESH_PATH = os.path.join(_ASSETS_DIR, 'meshes/cube.stl') +_MODEL_WITH_INCLUDE_PATH = os.path.join(_ASSETS_DIR, 'model_with_include.xml') + +_MODEL_WITH_INVALID_FILENAMES = os.path.join( + _ASSETS_DIR, 'model_with_invalid_filenames.xml') +_INCLUDED_WITH_INVALID_FILENAMES = os.path.join( + _ASSETS_DIR, 'included_with_invalid_filenames.xml') +_MODEL_WITH_NAMELESS_ASSETS = os.path.join( + _ASSETS_DIR, 'model_with_nameless_assets.xml') + + +class ElementTest(parameterized.TestCase): + + def assertIsSame(self, mjcf_model, other): + self.assertTrue(mjcf_model.is_same_as(other)) + self.assertTrue(other.is_same_as(mjcf_model)) + + def assertIsNotSame(self, mjcf_model, other): + self.assertFalse(mjcf_model.is_same_as(other)) + self.assertFalse(other.is_same_as(mjcf_model)) + + def assertHasAttr(self, obj, attrib): + self.assertTrue(hasattr(obj, attrib)) + + def assertNotHasAttr(self, obj, attrib): + self.assertFalse(hasattr(obj, attrib)) + + def _test_properties(self, mjcf_element, parent, root, recursive=False): + self.assertEqual(mjcf_element.tag, mjcf_element.spec.name) + self.assertEqual(mjcf_element.parent, parent) + self.assertEqual(mjcf_element.root, root) + self.assertEqual(mjcf_element.namescope, root.namescope) + for child_name, child_spec in mjcf_element.spec.children.items(): + if not (child_spec.repeated or child_spec.on_demand): + child = getattr(mjcf_element, child_name) + self.assertEqual(child.tag, child_name) + self.assertEqual(child.spec, child_spec) + if recursive: + self._test_properties(child, parent=mjcf_element, + root=root, recursive=True) + + def testAttributeError(self): + mjcf_model = element.RootElement(model='test') + mjcf_model.worldbody._spec = None + try: + _ = mjcf_model.worldbody.tag + except AttributeError: + _, err, tb = sys.exc_info() + else: + self.fail('AttributeError was not raised.') + # Test that the error comes from the fact that we've set `_spec = None`. + self.assertEqual(str(err), + '\'NoneType\' object has no attribute \'name\'') + _, _, func_name, _ = traceback.extract_tb(tb)[-1] + # Test that the error comes from the `root` property, not `__getattr__`. + self.assertEqual(func_name, 'tag') + + def testProperties(self): + mujoco = element.RootElement(model='test') + self.assertIsInstance(mujoco.namescope, namescope.NameScope) + self._test_properties(mujoco, parent=None, root=mujoco, recursive=True) + + def _test_attributes(self, mjcf_element, + expected_values=None, recursive=False): + attributes = mjcf_element.get_attributes() + self.assertNotIn('class', attributes) + for attribute_name in mjcf_element.spec.attributes.keys(): + if attribute_name == 'class': + attribute_name = 'dclass' + self.assertHasAttr(mjcf_element, attribute_name) + self.assertIn(attribute_name, dir(mjcf_element)) + attribute_value = getattr(mjcf_element, attribute_name) + if attribute_value is not None: + self.assertIn(attribute_name, attributes) + else: + self.assertNotIn(attribute_name, attributes) + if expected_values: + if attribute_name in expected_values: + expected_value = expected_values[attribute_name] + np.testing.assert_array_equal(attribute_value, expected_value) + else: + self.assertIsNone(attribute_value) + if recursive: + for child in mjcf_element.all_children(): + self._test_attributes(child, recursive=True) + + def testAttributes(self): + mujoco = element.RootElement(model='test') + mujoco.default.dclass = 'main' + self._test_attributes(mujoco, recursive=True) + + def _test_children(self, mjcf_element, recursive=False): + children = mjcf_element.all_children() + for child_name, child_spec in mjcf_element.spec.children.items(): + if not (child_spec.repeated or child_spec.on_demand): + self.assertHasAttr(mjcf_element, child_name) + self.assertIn(child_name, dir(mjcf_element)) + child = getattr(mjcf_element, child_name) + self.assertIn(child, children) + with self.assertRaisesRegex(AttributeError, 'can\'t set attribute'): + setattr(mjcf_element, child_name, 'value') + if recursive: + self._test_children(child, recursive=True) + + def testChildren(self): + mujoco = element.RootElement(model='test') + self._test_children(mujoco, recursive=True) + + def testInvalidAttr(self): + mujoco = element.RootElement(model='test') + invalid_attrib_name = 'foobar' + + def test_invalid_attr_recursively(mjcf_element): + self.assertNotHasAttr(mjcf_element, invalid_attrib_name) + self.assertNotIn(invalid_attrib_name, dir(mjcf_element)) + with self.assertRaisesRegex(AttributeError, 'object has no attribute'): + getattr(mjcf_element, invalid_attrib_name) + with self.assertRaisesRegex(AttributeError, 'can\'t set attribute'): + setattr(mjcf_element, invalid_attrib_name, 'value') + with self.assertRaisesRegex(AttributeError, 'object has no attribute'): + delattr(mjcf_element, invalid_attrib_name) + for child in mjcf_element.all_children(): + test_invalid_attr_recursively(child) + + test_invalid_attr_recursively(mujoco) + + def testAdd(self): + mujoco = element.RootElement(model='test') + + # repeated elements + body_foo_attributes = dict(name='foo', pos=[0, 1, 0], quat=[0, 1, 0, 0]) + body_foo = mujoco.worldbody.add('body', **body_foo_attributes) + self.assertEqual(body_foo.tag, 'body') + joint_foo_attributes = dict(name='foo', type='free') + joint_foo = body_foo.add('joint', **joint_foo_attributes) + self.assertEqual(joint_foo.tag, 'joint') + self._test_properties(body_foo, parent=mujoco.worldbody, root=mujoco) + self._test_attributes(body_foo, expected_values=body_foo_attributes) + self._test_children(body_foo) + self._test_properties(joint_foo, parent=body_foo, root=mujoco) + self._test_attributes(joint_foo, expected_values=joint_foo_attributes) + self._test_children(joint_foo) + + # non-repeated, on-demand elements + self.assertIsNone(body_foo.inertial) + body_foo_inertial_attributes = dict(mass=1.0, pos=[0, 0, 0]) + body_foo_inertial = body_foo.add('inertial', **body_foo_inertial_attributes) + self._test_properties(body_foo_inertial, parent=body_foo, root=mujoco) + self._test_attributes(body_foo_inertial, + expected_values=body_foo_inertial_attributes) + self._test_children(body_foo_inertial) + + with self.assertRaisesRegex(ValueError, ' child already exists'): + body_foo.add('inertial', **body_foo_inertial_attributes) + + # non-repeated, non-on-demand elements + with self.assertRaisesRegex(ValueError, ' child already exists'): + mujoco.add('compiler') + self.assertIsNotNone(mujoco.compiler) + with self.assertRaisesRegex(ValueError, ' child already exists'): + mujoco.add('default') + self.assertIsNotNone(mujoco.default) + + def testInsert(self): + mujoco = element.RootElement(model='test') + + # add in order + mujoco.worldbody.add('body', name='0') + mujoco.worldbody.add('body', name='1') + mujoco.worldbody.add('body', name='2') + + # insert into position 0, check order + mujoco.worldbody.insert('body', name='foo', position=0) + expected_order = ['foo', '0', '1', '2'] + for i, child in enumerate(mujoco.worldbody._children): + self.assertEqual(child.name, expected_order[i]) + + # insert into position -1, check order + mujoco.worldbody.insert('body', name='bar', position=-1) + expected_order = ['foo', '0', '1', 'bar', '2'] + for i, child in enumerate(mujoco.worldbody._children): + self.assertEqual(child.name, expected_order[i]) + + def testAddWithInvalidAttribute(self): + mujoco = element.RootElement(model='test') + with self.assertRaisesRegex(AttributeError, 'not a valid attribute'): + mujoco.worldbody.add('body', name='foo', invalid_attribute='some_value') + self.assertFalse(mujoco.worldbody.body) + self.assertIsNone(mujoco.worldbody.find('body', 'foo')) + + def testSameness(self): + mujoco = element.RootElement(model='test') + + body_1 = mujoco.worldbody.add('body', pos=[0, 1, 2], quat=[0, 1, 0, 1]) + site_1 = body_1.add('site', pos=[0, 1, 2], quat=[0, 1, 0, 1]) + geom_1 = body_1.add('geom', pos=[0, 1, 2], quat=[0, 1, 0, 1]) + + for elem in (body_1, site_1, geom_1): + self.assertIsSame(elem, elem) + + # strict ordering NOT required: adding geom and site is different order + body_2 = mujoco.worldbody.add('body', pos=[0, 1, 2], quat=[0, 1, 0, 1]) + geom_2 = body_2.add('geom', pos=[0, 1, 2], quat=[0, 1, 0, 1]) + site_2 = body_2.add('site', pos=[0, 1, 2], quat=[0, 1, 0, 1]) + + elems_1 = (body_1, site_1, geom_1) + elems_2 = (body_2, site_2, geom_2) + for i, j in itertools.product(range(len(elems_1)), range(len(elems_2))): + if i == j: + self.assertIsSame(elems_1[i], elems_2[j]) + else: + self.assertIsNotSame(elems_1[i], elems_2[j]) + + # on-demand child + body_1.add('inertial', pos=[0, 0, 0], mass=1) + self.assertIsNotSame(body_1, body_2) + + body_2.add('inertial', pos=[0, 0, 0], mass=1) + self.assertIsSame(body_1, body_2) + + # different number of children + subbody_1 = body_1.add('body', pos=[0, 0, 1]) + self.assertIsNotSame(body_1, body_2) + + # attribute mismatch + subbody_2 = body_2.add('body') + self.assertIsNotSame(subbody_1, subbody_2) + self.assertIsNotSame(body_1, body_2) + + subbody_2.pos = [0, 0, 1] + self.assertIsSame(subbody_1, subbody_2) + self.assertIsSame(body_1, body_2) + + # grandchild attribute mismatch + subbody_1.add('joint', type='hinge') + subbody_2.add('joint', type='ball') + self.assertIsNotSame(body_1, body_2) + + def testTendonSameness(self): + mujoco = element.RootElement(model='test') + + spatial_1 = mujoco.tendon.add('spatial') + spatial_1.add('site', site='foo') + spatial_1.add('geom', geom='bar') + + spatial_2 = mujoco.tendon.add('spatial') + spatial_2.add('site', site='foo') + spatial_2.add('geom', geom='bar') + + self.assertIsSame(spatial_1, spatial_2) + + # strict ordering is required + spatial_3 = mujoco.tendon.add('spatial') + spatial_3.add('site', site='foo') + spatial_3.add('geom', geom='bar') + + spatial_4 = mujoco.tendon.add('spatial') + spatial_4.add('geom', geom='bar') + spatial_4.add('site', site='foo') + + self.assertIsNotSame(spatial_3, spatial_4) + + def testCopy(self): + mujoco = parser.from_path(_TEST_MODEL_XML) + self.assertIsSame(mujoco, mujoco) + + copy_mujoco = copy.copy(mujoco) + copy_mujoco.model = 'copied_model' + self.assertIsSame(copy_mujoco, mujoco) + self.assertNotEqual(copy_mujoco, mujoco) + + deepcopy_mujoco = copy.deepcopy(mujoco) + deepcopy_mujoco.model = 'deepcopied_model' + self.assertIsSame(deepcopy_mujoco, mujoco) + self.assertNotEqual(deepcopy_mujoco, mujoco) + + self.assertIsSame(deepcopy_mujoco, copy_mujoco) + self.assertNotEqual(deepcopy_mujoco, copy_mujoco) + + def testWorldBodyFullIdentifier(self): + mujoco = parser.from_path(_TEST_MODEL_XML) + mujoco.model = 'model' + self.assertEqual(mujoco.worldbody.full_identifier, 'world') + + submujoco = copy.copy(mujoco) + submujoco.model = 'submodel' + self.assertEqual(submujoco.worldbody.full_identifier, 'world') + + mujoco.attach(submujoco) + self.assertEqual(mujoco.worldbody.full_identifier, 'world') + self.assertEqual(submujoco.worldbody.full_identifier, 'submodel/') + + self.assertNotIn('name', mujoco.worldbody.to_xml_string(self_only=True)) + self.assertNotIn('name', submujoco.worldbody.to_xml_string(self_only=True)) + + def testAttach(self): + mujoco = parser.from_path(_TEST_MODEL_XML) + mujoco.model = 'model' + + submujoco = copy.copy(mujoco) + submujoco.model = 'submodel' + + subsubmujoco = copy.copy(mujoco) + subsubmujoco.model = 'subsubmodel' + + with self.assertRaisesRegex(ValueError, 'Cannot merge a model to itself'): + mujoco.attach(mujoco) + + attachment_site = submujoco.find('site', 'attachment') + attachment_site.attach(subsubmujoco) + subsubmodel_frame = submujoco.find('attachment_frame', 'subsubmodel') + for attribute_name in ('pos', 'axisangle', 'xyaxes', 'zaxis', 'euler'): + np.testing.assert_array_equal( + getattr(subsubmodel_frame, attribute_name), + getattr(attachment_site, attribute_name)) + self._test_properties(subsubmodel_frame, + parent=attachment_site.parent, root=submujoco) + self.assertEqual( + subsubmodel_frame.to_xml_string().split('\n')[0], + '') + self.assertEqual( + subsubmodel_frame.to_xml_string(precision=5).split('\n')[0], + '') + self.assertEqual(subsubmodel_frame.all_children(), + subsubmujoco.worldbody.all_children()) + + with self.assertRaisesRegex(ValueError, 'already attached elsewhere'): + mujoco.attach(subsubmujoco) + + with self.assertRaisesRegex(ValueError, 'Expected a mjcf.RootElement'): + mujoco.attach(submujoco.contact) + + submujoco.option.flag.gravity = 'enable' + with self.assertRaisesRegex( + ValueError, 'Conflicting values for attribute `gravity`'): + mujoco.attach(submujoco) + submujoco.option.flag.gravity = 'disable' + + mujoco.attach(submujoco) + self.assertEqual(subsubmujoco.parent_model, submujoco) + self.assertEqual(submujoco.parent_model, mujoco) + self.assertEqual(subsubmujoco.root_model, mujoco) + self.assertEqual(submujoco.root_model, mujoco) + + self.assertEqual(submujoco.full_identifier, 'submodel/') + self.assertEqual(subsubmujoco.full_identifier, 'submodel/subsubmodel/') + + merged_children = ('contact', 'actuator') + for child_name in merged_children: + for grandchild in getattr(submujoco, child_name).all_children(): + self.assertIn(grandchild, getattr(mujoco, child_name).all_children()) + for grandchild in getattr(subsubmujoco, child_name).all_children(): + self.assertIn(grandchild, getattr(mujoco, child_name).all_children()) + self.assertIn(grandchild, getattr(submujoco, child_name).all_children()) + + base_contact_content = ( + '') + self.assertEqual( + mujoco.contact.to_xml_string(pretty_print=False), + '' + + base_contact_content.format('') + + base_contact_content.format('submodel/') + + base_contact_content.format('submodel/subsubmodel/') + + '') + + actuators_template = ( + '' + '') + self.assertEqual( + mujoco.actuator.to_xml_string(pretty_print=False), + '' + + actuators_template.format('/', '') + + actuators_template.format('submodel/', 'submodel/') + + actuators_template.format('submodel/subsubmodel/', + 'submodel/subsubmodel/') + + '') + + self.assertEqual(mujoco.default.full_identifier, '/') + self.assertEqual(mujoco.default.default[0].full_identifier, 'big_and_green') + self.assertEqual(submujoco.default.full_identifier, 'submodel/') + self.assertEqual(submujoco.default.default[0].full_identifier, + 'submodel/big_and_green') + self.assertEqual(subsubmujoco.default.full_identifier, + 'submodel/subsubmodel/') + self.assertEqual(subsubmujoco.default.default[0].full_identifier, + 'submodel/subsubmodel/big_and_green') + default_xml_lines = (mujoco.default.to_xml_string(pretty_print=False) + .replace('><', '>><<').split('><')) + self.assertEqual(default_xml_lines[0], '') + self.assertEqual(default_xml_lines[1], '') + self.assertEqual(default_xml_lines[4], '') + self.assertEqual(default_xml_lines[6], '') + self.assertEqual(default_xml_lines[7], '') + self.assertEqual(default_xml_lines[8], '') + self.assertEqual(default_xml_lines[11], + '') + self.assertEqual(default_xml_lines[13], '') + self.assertEqual(default_xml_lines[14], '') + self.assertEqual(default_xml_lines[15], + '') + self.assertEqual(default_xml_lines[18], + '') + self.assertEqual(default_xml_lines[-3], '') + self.assertEqual(default_xml_lines[-2], '') + self.assertEqual(default_xml_lines[-1], '') + + def testDetach(self): + root = parser.from_path(_TEST_MODEL_XML) + root.model = 'model' + + submodel = copy.copy(root) + submodel.model = 'submodel' + + unattached_xml_1 = root.to_xml_string() + root.attach(submodel) + attached_xml_1 = root.to_xml_string() + + submodel.detach() + unattached_xml_2 = root.to_xml_string() + root.attach(submodel) + attached_xml_2 = root.to_xml_string() + + self.assertEqual(unattached_xml_1, unattached_xml_2) + self.assertEqual(attached_xml_1, attached_xml_2) + + def testRenameAttachedModel(self): + root = parser.from_path(_TEST_MODEL_XML) + root.model = 'model' + + submodel = copy.copy(root) + submodel.model = 'submodel' + geom = submodel.worldbody.add( + 'geom', name='geom', type='sphere', size=[0.1]) + + frame = root.attach(submodel) + submodel.model = 'renamed' + self.assertEqual(frame.full_identifier, 'renamed/') + self.assertIsSame(root.find('geom', 'renamed/geom'), geom) + + def testAttachmentFrames(self): + mujoco = parser.from_path(_TEST_MODEL_XML) + mujoco.model = 'model' + + submujoco = copy.copy(mujoco) + submujoco.model = 'submodel' + + subsubmujoco = copy.copy(mujoco) + subsubmujoco.model = 'subsubmodel' + + attachment_site = submujoco.find('site', 'attachment') + attachment_site.attach(subsubmujoco) + mujoco.attach(submujoco) + + # attachments directly on worldbody can have a + submujoco_frame = mujoco.find('attachment_frame', 'submodel') + self.assertStartsWith(submujoco_frame.to_xml_string(pretty_print=False), + '') + self.assertEqual(submujoco_frame.full_identifier, 'submodel/') + free_joint = submujoco_frame.add('freejoint') + self.assertEqual(free_joint.to_xml_string(pretty_print=False), + '') + self.assertEqual(free_joint.full_identifier, 'submodel/') + + # attachments elsewhere cannot have a + subsubmujoco_frame = submujoco.find('attachment_frame', 'subsubmodel') + subsubmujoco_frame_xml = subsubmujoco_frame.to_xml_string( + pretty_print=False, prefix_root=mujoco.namescope) + self.assertStartsWith( + subsubmujoco_frame_xml, + '') + self.assertEqual(subsubmujoco_frame.full_identifier, + 'submodel/subsubmodel/') + with self.assertRaisesRegex(AttributeError, 'not a valid child'): + subsubmujoco_frame.add('freejoint') + hinge_joint = subsubmujoco_frame.add('joint', type='hinge', axis=[1, 2, 3]) + hinge_joint_xml = hinge_joint.to_xml_string( + pretty_print=False, prefix_root=mujoco.namescope) + self.assertEqual( + hinge_joint_xml, + '') + self.assertEqual(hinge_joint.full_identifier, 'submodel/subsubmodel/') + + def testDuplicateAttachmentFrameJointIdentifiers(self): + mujoco = parser.from_path(_TEST_MODEL_XML) + mujoco.model = 'model' + + submujoco_1 = copy.copy(mujoco) + submujoco_1.model = 'submodel_1' + + submujoco_2 = copy.copy(mujoco) + submujoco_2.model = 'submodel_2' + + frame_1 = mujoco.attach(submujoco_1) + frame_2 = mujoco.attach(submujoco_2) + + joint_1 = frame_1.add('joint', type='slide', name='root_x', axis=[1, 0, 0]) + joint_2 = frame_2.add('joint', type='slide', name='root_x', axis=[1, 0, 0]) + + self.assertEqual(joint_1.full_identifier, 'submodel_1/root_x/') + self.assertEqual(joint_2.full_identifier, 'submodel_2/root_x/') + + def testAttachmentFrameReference(self): + root_1 = mjcf.RootElement('model_1') + root_2 = mjcf.RootElement('model_2') + root_2_frame = root_1.attach(root_2) + sensor = root_1.sensor.add( + 'framelinacc', name='root_2', objtype='body', objname=root_2_frame) + self.assertEqual( + sensor.to_xml_string(pretty_print=False), + '') + + def testAttachmentFrameChildReference(self): + root_1 = mjcf.RootElement('model_1') + root_2 = mjcf.RootElement('model_2') + root_2_frame = root_1.attach(root_2) + root_2_joint = root_2_frame.add( + 'joint', name='root_x', type='slide', axis=[1, 0, 0]) + actuator = root_1.actuator.add( + 'position', name='root_x', joint=root_2_joint) + self.assertEqual( + actuator.to_xml_string(pretty_print=False), + '') + + def testDeletion(self): + mujoco = parser.from_path(_TEST_MODEL_XML) + mujoco.model = 'model' + + submujoco = copy.copy(mujoco) + submujoco.model = 'submodel' + + subsubmujoco = copy.copy(mujoco) + subsubmujoco.model = 'subsubmodel' + + submujoco.find('site', 'attachment').attach(subsubmujoco) + mujoco.attach(submujoco) + + with self.assertRaisesRegex( + ValueError, r'use remove\(affect_attachments=True\)'): + del mujoco.option + + mujoco.option.remove(affect_attachments=True) + for root in (mujoco, submujoco, subsubmujoco): + self.assertIsNotNone(root.option.flag) + self.assertEqual( + root.option.to_xml_string(pretty_print=False), ' diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/attribute_test_schema.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/attribute_test_schema.xml new file mode 100644 index 000000000..aa62f3bab --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/attribute_test_schema.xml @@ -0,0 +1,100 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/included.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/included.xml new file mode 100644 index 000000000..40f88427a --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/included.xml @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/included_with_invalid_filenames.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/included_with_invalid_filenames.xml new file mode 100644 index 000000000..a58f6e25d --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/included_with_invalid_filenames.xml @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/lego_brick.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/lego_brick.xml new file mode 100644 index 000000000..ba8561e3d --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/lego_brick.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/meshes/cube.msh b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/meshes/cube.msh new file mode 100644 index 000000000..0a2f278f0 Binary files /dev/null and b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/meshes/cube.msh differ diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/meshes/cube.stl b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/meshes/cube.stl new file mode 100644 index 000000000..a5bc8256d Binary files /dev/null and b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/meshes/cube.stl differ diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/meshes/more_meshes/cube.stl b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/meshes/more_meshes/cube.stl new file mode 100644 index 000000000..a5bc8256d Binary files /dev/null and b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/meshes/more_meshes/cube.stl differ diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_assetdir.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_assetdir.xml new file mode 100644 index 000000000..68bc57346 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_assetdir.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_assets.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_assets.xml new file mode 100644 index 000000000..2521511fc --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_assets.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_include.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_include.xml new file mode 100644 index 000000000..fbe2d4190 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_include.xml @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_invalid_filenames.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_invalid_filenames.xml new file mode 100644 index 000000000..ff972e1b0 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_invalid_filenames.xml @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_nameless_assets.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_nameless_assets.xml new file mode 100644 index 000000000..ad615c6c1 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_nameless_assets.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_separators.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_separators.xml new file mode 100644 index 000000000..27cd4aa4b --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/model_with_separators.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/robot_arm.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/robot_arm.xml new file mode 100644 index 000000000..5b6d6a076 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/robot_arm.xml @@ -0,0 +1,200 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/skins/test_skin.skn b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/skins/test_skin.skn new file mode 100644 index 000000000..651fdc18d Binary files /dev/null and b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/skins/test_skin.skn differ diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/test_model.xml b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/test_model.xml new file mode 100644 index 000000000..117488e88 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/test_model.xml @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/textures/deepmind.png b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/textures/deepmind.png new file mode 100644 index 000000000..1586759c0 Binary files /dev/null and b/rofunc/utils/robolab/formatter/mjcf_parser/test_assets/textures/deepmind.png differ diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/traversal_utils.py b/rofunc/utils/robolab/formatter/mjcf_parser/traversal_utils.py new file mode 100644 index 000000000..ff1bae230 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/traversal_utils.py @@ -0,0 +1,81 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Utility functions that operate on MJCF elements.""" + +_ACTUATOR_TAGS = ('general', 'motor', 'position', + 'velocity', 'cylinder', 'muscle') + + +def get_freejoint(element): + """Retrieves the free joint of a body. Returns `None` if there isn't one.""" + if element.tag != 'body': + return None + elif hasattr(element, 'freejoint') and element.freejoint is not None: + return element.freejoint + else: + joints = element.find_all('joint', immediate_children_only=True) + for joint in joints: + if joint.type == 'free': + return joint + return None + + +def get_attachment_frame(mjcf_model): + return mjcf_model.parent_model.find('attachment_frame', mjcf_model.model) + + +def get_frame_freejoint(mjcf_model): + frame = get_attachment_frame(mjcf_model) + return get_freejoint(frame) + + +def get_frame_joints(mjcf_model): + """Retrieves all joints belonging to the attachment frame of an MJCF model.""" + frame = get_attachment_frame(mjcf_model) + if frame: + return frame.find_all('joint', immediate_children_only=True) + else: + return None + + +def commit_defaults(element, attributes=None): + """Commits default values into attributes of the specified element. + + Args: + element: A PyMJCF element. + attributes: (optional) A list of strings specifying the attributes to be + copied from defaults, or `None` if all attributes should be copied. + """ + dclass = element.dclass + parent = element.parent + while dclass is None and parent != element.root: + dclass = getattr(parent, 'childclass', None) + parent = parent.parent + if dclass is None: + dclass = element.root.default + + while dclass != element.root: + if element.tag in _ACTUATOR_TAGS: + tags = _ACTUATOR_TAGS + else: + tags = (element.tag,) + for tag in tags: + default_element = getattr(dclass, tag) + for name, value in default_element.get_attributes().items(): + if attributes is None or name in attributes: + if hasattr(element, name) and getattr(element, name) is None: + setattr(element, name, value) + dclass = dclass.parent diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/traversal_utils_test.py b/rofunc/utils/robolab/formatter/mjcf_parser/traversal_utils_test.py new file mode 100644 index 000000000..521ef0766 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/traversal_utils_test.py @@ -0,0 +1,218 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for `dm_control.mjcf.traversal_utils`.""" + +import numpy as np +from absl.testing import absltest +from rofunc.utils.robolab.formatter import mjcf_parser as mjcf + + +class TraversalUtilsTest(absltest.TestCase): + + def assert_same_attributes(self, element, expected_attributes): + actual_attributes = element.get_attributes() + self.assertEqual(set(actual_attributes.keys()), + set(expected_attributes.keys())) + for name in actual_attributes: + actual_value = actual_attributes[name] + expected_value = expected_attributes[name] + np.testing.assert_array_equal(actual_value, expected_value, name) + + def test_resolve_root_defaults(self): + root = mjcf.RootElement() + root.default.geom.type = 'box' + root.default.geom.pos = [0, 1, 0] + root.default.joint.ref = 2 + root.default.joint.pos = [0, 0, 1] + + # Own attribute overriding default. + body = root.worldbody.add('body') + geom1 = body.add('geom', type='sphere') + mjcf.commit_defaults(geom1) + self.assert_same_attributes(geom1, dict( + type='sphere', + pos=[0, 1, 0])) + + # No explicit attributes. + geom2 = body.add('geom') + mjcf.commit_defaults(geom2) + self.assert_same_attributes(geom2, dict( + type='box', + pos=[0, 1, 0])) + + # Attributes mutually exclusive with those defined in default. + joint1 = body.add('joint', margin=3) + mjcf.commit_defaults(joint1) + self.assert_same_attributes(joint1, dict( + pos=[0, 0, 1], + ref=2, + margin=3)) + + def test_resolve_defaults_for_some_attributes(self): + root = mjcf.RootElement() + root.default.geom.type = 'box' + root.default.geom.pos = [0, 1, 0] + geom1 = root.worldbody.add('geom') + mjcf.commit_defaults(geom1, attributes=['pos']) + self.assert_same_attributes(geom1, dict( + pos=[0, 1, 0])) + + def test_resolve_hierarchies_of_defaults(self): + root = mjcf.RootElement() + root.default.geom.type = 'box' + root.default.joint.pos = [0, 1, 0] + + top1 = root.default.add('default', dclass='top1') + top1.geom.pos = [0.1, 0, 0] + top1.joint.pos = [1, 0, 0] + top1.joint.axis = [0, 0, 1] + sub1 = top1.add('default', dclass='sub1') + sub1.geom.size = [0.5] + + top2 = root.default.add('default', dclass='top2') + top2.joint.pos = [0, 0, 1] + top2.joint.axis = [0, 1, 0] + top2.geom.type = 'sphere' + + body = root.worldbody.add('body') + geom1 = body.add('geom', dclass=sub1) + mjcf.commit_defaults(geom1) + self.assert_same_attributes(geom1, dict( + dclass=sub1, + type='box', + size=[0.5], + pos=[0.1, 0, 0])) + + geom2 = body.add('geom', dclass=top1) + mjcf.commit_defaults(geom2) + self.assert_same_attributes(geom2, dict( + dclass=top1, + type='box', + pos=[0.1, 0, 0])) + + geom3 = body.add('geom', dclass=top2) + mjcf.commit_defaults(geom3) + self.assert_same_attributes(geom3, dict( + dclass=top2, + type='sphere')) + + geom4 = body.add('geom') + mjcf.commit_defaults(geom4) + self.assert_same_attributes(geom4, dict( + type='box')) + + joint1 = body.add('joint', dclass=sub1) + mjcf.commit_defaults(joint1) + self.assert_same_attributes(joint1, dict( + dclass=sub1, + pos=[1, 0, 0], + axis=[0, 0, 1])) + + joint2 = body.add('joint', dclass=top2) + mjcf.commit_defaults(joint2) + self.assert_same_attributes(joint2, dict( + dclass=top2, + pos=[0, 0, 1], + axis=[0, 1, 0])) + + joint3 = body.add('joint') + mjcf.commit_defaults(joint3) + self.assert_same_attributes(joint3, dict( + pos=[0, 1, 0])) + + def test_resolve_actuator_defaults(self): + root = mjcf.RootElement() + root.default.general.forcelimited = 'true' + root.default.motor.forcerange = [-2, 3] + root.default.position.kp = 0.1 + root.default.velocity.kv = 0.2 + + body = root.worldbody.add('body') + joint = body.add('joint') + + motor = root.actuator.add('motor', joint=joint) + mjcf.commit_defaults(motor) + self.assert_same_attributes(motor, dict( + joint=joint, + forcelimited='true', + forcerange=[-2, 3])) + + position = root.actuator.add('position', joint=joint) + mjcf.commit_defaults(position) + self.assert_same_attributes(position, dict( + joint=joint, + kp=0.1, + kv=0.2, + forcelimited='true', + forcerange=[-2, 3])) + + velocity = root.actuator.add('velocity', joint=joint) + mjcf.commit_defaults(velocity) + self.assert_same_attributes(velocity, dict( + joint=joint, + kv=0.2, + forcelimited='true', + forcerange=[-2, 3])) + + def test_resolve_childclass(self): + root = mjcf.RootElement() + root.default.geom.type = 'capsule' + top = root.default.add('default', dclass='top') + top.geom.pos = [0, 1, 0] + sub = top.add('default', dclass='sub') + sub.geom.pos = [0, 0, 1] + + # Element only affected by the childclass of immediate parent. + body = root.worldbody.add('body', childclass=sub) + geom1 = body.add('geom') + mjcf.commit_defaults(geom1) + self.assert_same_attributes(geom1, dict( + type='capsule', + pos=[0, 0, 1])) + + # Element overrides parent's childclass. + geom2 = body.add('geom', dclass=top) + mjcf.commit_defaults(geom2) + self.assert_same_attributes(geom2, dict( + dclass=top, + type='capsule', + pos=[0, 1, 0])) + + # Element's parent overrides grandparent's childclass. + subbody1 = body.add('body', childclass=top) + geom3 = subbody1.add('geom') + mjcf.commit_defaults(geom3) + self.assert_same_attributes(geom3, dict( + type='capsule', + pos=[0, 1, 0])) + + # Element's grandparent does not specify a childclass, but grandparent does. + subbody2 = body.add('body') + geom4 = subbody2.add('geom') + mjcf.commit_defaults(geom4) + self.assert_same_attributes(geom4, dict( + type='capsule', + pos=[0, 0, 1])) + + # Direct child of worldbody, not affected by any childclass. + geom5 = root.worldbody.add('geom') + mjcf.commit_defaults(geom5) + self.assert_same_attributes(geom5, dict( + type='capsule')) + + +if __name__ == '__main__': + absltest.main() diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/util.py b/rofunc/utils/robolab/formatter/mjcf_parser/util.py new file mode 100644 index 000000000..c03e69cc9 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/util.py @@ -0,0 +1,60 @@ +# Copyright 2017-2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Various helper functions and classes.""" + +import functools +import sys +import mujoco +import numpy as np + +# Environment variable that can be used to override the default path to the +# MuJoCo shared library. +ENV_MJLIB_PATH = "MJLIB_PATH" + +DEFAULT_ENCODING = sys.getdefaultencoding() + + +def to_binary_string(s): + """Convert text string to binary.""" + if isinstance(s, bytes): + return s + return s.encode(DEFAULT_ENCODING) + + +def to_native_string(s): + """Convert a text or binary string to the native string format.""" + if isinstance(s, bytes): + return s.decode(DEFAULT_ENCODING) + else: + return s + + +def get_mjlib(): + return mujoco + + +@functools.wraps(np.ctypeslib.ndpointer) +def ndptr(*args, **kwargs): + """Wraps `np.ctypeslib.ndpointer` to allow passing None for NULL pointers.""" + base = np.ctypeslib.ndpointer(*args, **kwargs) + + def from_param(_, obj): + if obj is None: + return obj + else: + return base.from_param(obj) + + return type(base.__name__, (base,), {"from_param": classmethod(from_param)}) diff --git a/rofunc/utils/robolab/formatter/mjcf_parser/xml_validation_test.py b/rofunc/utils/robolab/formatter/mjcf_parser/xml_validation_test.py new file mode 100644 index 000000000..f2cc072c7 --- /dev/null +++ b/rofunc/utils/robolab/formatter/mjcf_parser/xml_validation_test.py @@ -0,0 +1,67 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests that generated XML string is valid.""" + +import os + +from absl.testing import absltest +from dm_control.mujoco import wrapper + +from rofunc.utils.robolab.formatter.mjcf_parser import parser + +ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'test_assets') +_ARENA_XML = os.path.join(ASSETS_DIR, 'arena.xml') +_LEGO_BRICK_XML = os.path.join(ASSETS_DIR, 'lego_brick.xml') +_ROBOT_XML = os.path.join(ASSETS_DIR, 'robot_arm.xml') + + +def validate(xml_string): + """Validates that an XML string is a valid MJCF. + + Validation is performed by constructing Mujoco model from the string. + The construction process contains compilation and validation phases by Mujoco + engine, the best validation tool we have access to. + + Args: + xml_string: XML string to validate + """ + + mjmodel = wrapper.MjModel.from_xml_string(xml_string) + wrapper.MjData(mjmodel) + + +class XMLValidationTest(absltest.TestCase): + + def testXmlAttach(self): + robot_arm = parser.from_file(_ROBOT_XML) + arena = parser.from_file(_ARENA_XML) + lego = parser.from_file(_LEGO_BRICK_XML) + + # validate MJCF strings before changing them + validate(robot_arm.to_xml_string()) + validate(arena.to_xml_string()) + validate(lego.to_xml_string()) + + # combine objects in complex scene + robot_arm.find('site', 'fingertip1').attach(lego) + arena.worldbody.attach(robot_arm) + + # validate + validate(arena.to_xml_string()) + + +if __name__ == '__main__': + absltest.main() diff --git a/rofunc/utils/robolab/formatter/sdf.py b/rofunc/utils/robolab/formatter/sdf.py new file mode 100644 index 000000000..12e888fae --- /dev/null +++ b/rofunc/utils/robolab/formatter/sdf.py @@ -0,0 +1,108 @@ +import torch +import math +from .urdf_parser.sdf import SDF, Mesh, Cylinder, Box, Sphere +from . import frame +from . import chain +import pytorch_kinematics.transforms as tf + +JOINT_TYPE_MAP = {'revolute': 'revolute', + 'prismatic': 'prismatic', + 'fixed': 'fixed'} + + +def _convert_transform(pose): + if pose is None: + return tf.Transform3d() + else: + return tf.Transform3d(rot=tf.euler_angles_to_matrix(torch.tensor(pose[3:]), "ZYX"), pos=pose[:3]) + + +def _convert_visuals(visuals): + vlist = [] + for v in visuals: + v_tf = _convert_transform(v.pose) + if isinstance(v.geometry, Mesh): + g_type = "mesh" + g_param = v.geometry.filename + elif isinstance(v.geometry, Cylinder): + g_type = "cylinder" + v_tf = v_tf.compose( + tf.Transform3d(rot=tf.euler_angles_to_matrix(torch.tensor([0.5 * math.pi, 0, 0]), "ZYX"))) + g_param = (v.geometry.radius, v.geometry.length) + elif isinstance(v.geometry, Box): + g_type = "box" + g_param = v.geometry.size + elif isinstance(v.geometry, Sphere): + g_type = "sphere" + g_param = v.geometry.radius + else: + g_type = None + g_param = None + vlist.append(frame.Visual(v_tf, g_type, g_param)) + return vlist + + +def _build_chain_recurse(root_frame, lmap, joints): + children = [] + for j in joints: + if j.parent == root_frame.link.name: + child_frame = frame.Frame(j.child) + link_p = lmap[j.parent] + link_c = lmap[j.child] + t_p = _convert_transform(link_p.pose) + t_c = _convert_transform(link_c.pose) + try: + limits = (j.axis.limit.lower, j.axis.limit.upper) + except AttributeError: + limits = None + child_frame.joint = frame.Joint(j.name, offset=t_p.inverse().compose(t_c), + joint_type=JOINT_TYPE_MAP[j.type], axis=j.axis.xyz, limits=limits) + child_frame.link = frame.Link(link_c.name, offset=tf.Transform3d(), + visuals=_convert_visuals(link_c.visuals)) + child_frame.children = _build_chain_recurse(child_frame, lmap, joints) + children.append(child_frame) + return children + + +def build_chain_from_sdf(data): + """ + Build a Chain object from SDF data. + + Parameters + ---------- + data : str + SDF string data. + + Returns + ------- + chain.Chain + Chain object created from SDF. + """ + sdf = SDF.from_xml_string(data) + robot = sdf.model + lmap = robot.link_map + joints = robot.joints + n_joints = len(joints) + has_root = [True for _ in range(len(joints))] + for i in range(n_joints): + for j in range(i + 1, n_joints): + if joints[i].parent == joints[j].child: + has_root[i] = False + elif joints[j].parent == joints[i].child: + has_root[j] = False + for i in range(n_joints): + if has_root[i]: + root_link = lmap[joints[i].parent] + break + root_frame = frame.Frame(root_link.name) + root_frame.joint = frame.Joint(offset=_convert_transform(root_link.pose)) + root_frame.link = frame.Link(root_link.name, tf.Transform3d(), + _convert_visuals(root_link.visuals)) + root_frame.children = _build_chain_recurse(root_frame, lmap, joints) + return chain.Chain(root_frame) + + +def build_serial_chain_from_sdf(data, end_link_name, root_link_name=""): + mjcf_chain = build_chain_from_sdf(data) + serial_chain = chain.SerialChain(mjcf_chain, end_link_name, "" if root_link_name == "" else root_link_name) + return serial_chain diff --git a/rofunc/utils/robolab/formatter/urdf.py b/rofunc/utils/robolab/formatter/urdf.py new file mode 100644 index 000000000..47b08dd99 --- /dev/null +++ b/rofunc/utils/robolab/formatter/urdf.py @@ -0,0 +1,137 @@ +from .urdf_parser.urdf import URDF, Mesh, Cylinder, Box, Sphere +from pytorch_kinematics import chain +from pytorch_kinematics import frame +import torch +import pytorch_kinematics.transforms as tf + +JOINT_TYPE_MAP = {'revolute': 'revolute', + 'continuous': 'revolute', + 'prismatic': 'prismatic', + 'fixed': 'fixed'} + + +def _convert_transform(origin): + if origin is None: + return tf.Transform3d() + else: + rpy = torch.tensor(origin.rpy, dtype=torch.float32) + return tf.Transform3d(rot=tf.quaternion_from_euler(rpy, "sxyz"), pos=origin.xyz) + + +def _convert_visual(visual): + if visual is None or visual.geometry is None: + return frame.Visual() + else: + v_tf = _convert_transform(visual.origin) + if isinstance(visual.geometry, Mesh): + g_type = "mesh" + g_param = (visual.geometry.filename, visual.geometry.scale) + elif isinstance(visual.geometry, Cylinder): + g_type = "cylinder" + g_param = (visual.geometry.radius, visual.geometry.length) + elif isinstance(visual.geometry, Box): + g_type = "box" + g_param = visual.geometry.size + elif isinstance(visual.geometry, Sphere): + g_type = "sphere" + g_param = visual.geometry.radius + else: + g_type = None + g_param = None + return frame.Visual(v_tf, g_type, g_param) + + +def _build_chain_recurse(root_frame, lmap, joints): + children = [] + for j in joints: + if j.parent == root_frame.link.name: + try: + limits = (j.limit.lower, j.limit.upper) + except AttributeError: + limits = None + child_frame = frame.Frame(j.child) + child_frame.joint = frame.Joint(j.name, offset=_convert_transform(j.origin), + joint_type=JOINT_TYPE_MAP[j.type], axis=j.axis, limits=limits) + link = lmap[j.child] + child_frame.link = frame.Link(link.name, offset=_convert_transform(link.origin), + visuals=[_convert_visual(link.visual)]) + child_frame.children = _build_chain_recurse(child_frame, lmap, joints) + children.append(child_frame) + return children + + +def build_chain_from_urdf(data): + """ + Build a Chain object from URDF data. + + Parameters + ---------- + data : str + URDF string data. + + Returns + ------- + chain.Chain + Chain object created from URDF. + + Example + ------- + >>> import pytorch_kinematics as pk + >>> data = ''' + ... + ... + ... + ... + ... + ... + ... ''' + >>> chain = pk.build_chain_from_urdf(data) + >>> print(chain) + link1_frame + link2_frame + + """ + robot = URDF.from_xml_string(data.encode("utf-8")) + lmap = robot.link_map + joints = robot.joints + n_joints = len(joints) + has_root = [True for _ in range(len(joints))] + for i in range(n_joints): + for j in range(i + 1, n_joints): + if joints[i].parent == joints[j].child: + has_root[i] = False + elif joints[j].parent == joints[i].child: + has_root[j] = False + for i in range(n_joints): + if has_root[i]: + root_link = lmap[joints[i].parent] + break + root_frame = frame.Frame(root_link.name) + root_frame.joint = frame.Joint() + root_frame.link = frame.Link(root_link.name, _convert_transform(root_link.origin), + [_convert_visual(root_link.visual)]) + root_frame.children = _build_chain_recurse(root_frame, lmap, joints) + return chain.Chain(root_frame) + + +def build_serial_chain_from_urdf(data, end_link_name, root_link_name=""): + """ + Build a SerialChain object from urdf data. + + Parameters + ---------- + data : str + URDF string data. + end_link_name : str + The name of the link that is the end effector. + root_link_name : str, optional + The name of the root link. + + Returns + ------- + chain.SerialChain + SerialChain object created from URDF. + """ + urdf_chain = build_chain_from_urdf(data) + return chain.SerialChain(urdf_chain, end_link_name, + "" if root_link_name == "" else root_link_name) diff --git a/rofunc/utils/robolab/formatter/urdf_parser/__init__.py b/rofunc/utils/robolab/formatter/urdf_parser/__init__.py new file mode 100644 index 000000000..c3ed66a75 --- /dev/null +++ b/rofunc/utils/robolab/formatter/urdf_parser/__init__.py @@ -0,0 +1 @@ +from . import sdf, urdf diff --git a/rofunc/utils/robolab/formatter/urdf_parser/sdf.py b/rofunc/utils/robolab/formatter/urdf_parser/sdf.py new file mode 100644 index 000000000..c44d9cd39 --- /dev/null +++ b/rofunc/utils/robolab/formatter/urdf_parser/sdf.py @@ -0,0 +1,310 @@ +from . import xml_reflection as xmlr +from .xml_reflection.basics import * + +# What is the scope of plugins? Model, World, Sensor? + +xmlr.start_namespace("sdf") + + +name_attribute = xmlr.Attribute("name", str, False) +pose_element = xmlr.Element("pose", "vector6", False) + + +class Inertia(xmlr.Object): + KEYS = ["ixx", "ixy", "ixz", "iyy", "iyz", "izz"] + + def __init__(self, ixx=0.0, ixy=0.0, ixz=0.0, iyy=0.0, iyz=0.0, izz=0.0): + self.ixx = ixx + self.ixy = ixy + self.ixz = ixz + self.iyy = iyy + self.iyz = iyz + self.izz = izz + + def to_matrix(self): + return [[self.ixx, self.ixy, self.ixz], [self.ixy, self.iyy, self.iyz], [self.ixz, self.iyz, self.izz]] + + +xmlr.reflect(Inertia, params=[xmlr.Element(key, float) for key in Inertia.KEYS]) + +# Pretty much copy-paste... Better method? +# Use multiple inheritance to separate the objects out so they are unique? + + +class Inertial(xmlr.Object): + def __init__(self, mass=0.0, inertia=None, pose=None): + self.mass = mass + self.inertia = inertia + self.pose = pose + + +xmlr.reflect(Inertial, params=[xmlr.Element("mass", float), xmlr.Element("inertia", Inertia), pose_element]) + + +class Box(xmlr.Object): + def __init__(self, size=None): + self.size = size + + +xmlr.reflect(Box, tag="box", params=[xmlr.Element("size", "vector3")]) + + +class Cylinder(xmlr.Object): + def __init__(self, radius=0.0, length=0.0): + self.radius = radius + self.length = length + + +xmlr.reflect(Cylinder, tag="cylinder", params=[xmlr.Element("radius", float), xmlr.Element("length", float)]) + + +class Sphere(xmlr.Object): + def __init__(self, radius=0.0): + self.radius = radius + + +xmlr.reflect(Sphere, tag="sphere", params=[xmlr.Element("radius", float)]) + + +class Mesh(xmlr.Object): + def __init__(self, filename=None, scale=None): + self.filename = filename + self.scale = scale + + +xmlr.reflect( + Mesh, tag="mesh", params=[xmlr.Element("filename", str), xmlr.Element("scale", "vector3", required=False)] +) + + +class GeometricType(xmlr.ValueType): + def __init__(self): + self.factory = xmlr.FactoryType( + "geometric", {"box": Box, "cylinder": Cylinder, "sphere": Sphere, "mesh": Mesh} + ) + + def from_xml(self, node, path): + children = xml_children(node) + assert len(children) == 1, "One element only for geometric" + return self.factory.from_xml(children[0], path=path) + + def write_xml(self, node, obj): + name = self.factory.get_name(obj) + child = node_add(node, name) + obj.write_xml(child) + + +xmlr.add_type("geometric", GeometricType()) + + +class Script(xmlr.Object): + def __init__(self, uri=None, name=None): + self.uri = uri + self.name = name + + +xmlr.reflect(Script, tag="script", params=[xmlr.Element("name", str, False), xmlr.Element("uri", str, False)]) + + +class Material(xmlr.Object): + def __init__(self, name=None, script=None): + self.name = name + self.script = script + + +xmlr.reflect(Material, tag="material", params=[name_attribute, xmlr.Element("script", Script, False)]) + + +class Visual(xmlr.Object): + def __init__(self, name=None, geometry=None, pose=None): + self.name = name + self.geometry = geometry + self.pose = pose + + +xmlr.reflect( + Visual, + tag="visual", + params=[ + name_attribute, + xmlr.Element("geometry", "geometric"), + xmlr.Element("material", Material, False), + pose_element, + ], +) + + +class Collision(xmlr.Object): + def __init__(self, name=None, geometry=None, pose=None): + self.name = name + self.geometry = geometry + self.pose = pose + + +xmlr.reflect(Collision, tag="collision", params=[name_attribute, xmlr.Element("geometry", "geometric"), pose_element]) + + +class Dynamics(xmlr.Object): + def __init__(self, damping=None, friction=None): + self.damping = damping + self.friction = friction + + +xmlr.reflect( + Dynamics, tag="dynamics", params=[xmlr.Element("damping", float, False), xmlr.Element("friction", float, False)] +) + + +class Limit(xmlr.Object): + def __init__(self, lower=None, upper=None): + self.lower = lower + self.upper = upper + + +xmlr.reflect(Limit, tag="limit", params=[xmlr.Element("lower", float, False), xmlr.Element("upper", float, False)]) + + +class Axis(xmlr.Object): + def __init__(self, xyz=None, limit=None, dynamics=None, use_parent_model_frame=None): + self.xyz = xyz + self.limit = limit + self.dynamics = dynamics + self.use_parent_model_frame = use_parent_model_frame + + +xmlr.reflect( + Axis, + tag="axis", + params=[ + xmlr.Element("xyz", "vector3"), + xmlr.Element("limit", Limit, False), + xmlr.Element("dynamics", Dynamics, False), + xmlr.Element("use_parent_model_frame", bool, False), + ], +) + + +class Joint(xmlr.Object): + TYPES = ["unknown", "revolute", "gearbox", "revolute2", "prismatic", "ball", "screw", "universal", "fixed"] + + def __init__(self, name=None, parent=None, child=None, joint_type=None, axis=None, pose=None): + self.aggregate_init() + self.name = name + self.parent = parent + self.child = child + self.type = joint_type + self.axis = axis + self.pose = pose + + # Aliases + @property + def joint_type(self): + return self.type + + @joint_type.setter + def joint_type(self, value): + self.type = value + + +xmlr.reflect( + Joint, + tag="joint", + params=[ + name_attribute, + xmlr.Attribute("type", str, False), + xmlr.Element("axis", Axis), + xmlr.Element("parent", str), + xmlr.Element("child", str), + pose_element, + ], +) + + +class Link(xmlr.Object): + def __init__(self, name=None, pose=None, inertial=None, kinematic=False): + self.aggregate_init() + self.name = name + self.pose = pose + self.inertial = inertial + self.kinematic = kinematic + self.visuals = [] + self.collisions = [] + + +xmlr.reflect( + Link, + tag="link", + params=[ + name_attribute, + xmlr.Element("inertial", Inertial), + xmlr.Attribute("kinematic", bool, False), + xmlr.AggregateElement("visual", Visual, var="visuals"), + xmlr.AggregateElement("collision", Collision, var="collisions"), + pose_element, + ], +) + + +class Model(xmlr.Object): + def __init__(self, name=None, pose=None): + self.aggregate_init() + self.name = name + self.pose = pose + self.links = [] + self.joints = [] + self.joint_map = {} + self.link_map = {} + + self.parent_map = {} + self.child_map = {} + + def add_aggregate(self, typeName, elem): + xmlr.Object.add_aggregate(self, typeName, elem) + + if typeName == "joint": + joint = elem + self.joint_map[joint.name] = joint + self.parent_map[joint.child] = (joint.name, joint.parent) + if joint.parent in self.child_map: + self.child_map[joint.parent].append((joint.name, joint.child)) + else: + self.child_map[joint.parent] = [(joint.name, joint.child)] + elif typeName == "link": + link = elem + self.link_map[link.name] = link + + def add_link(self, link): + self.add_aggregate("link", link) + + def add_joint(self, joint): + self.add_aggregate("joint", joint) + + +xmlr.reflect( + Model, + tag="model", + params=[ + name_attribute, + xmlr.AggregateElement("link", Link, var="links"), + xmlr.AggregateElement("joint", Joint, var="joints"), + pose_element, + ], +) + + +class SDF(xmlr.Object): + def __init__(self, version=None): + self.version = version + + +xmlr.reflect( + SDF, + tag="sdf", + params=[ + xmlr.Attribute("version", str, False), + xmlr.Element("model", Model, False), + ], +) + + +xmlr.end_namespace() diff --git a/rofunc/utils/robolab/formatter/urdf_parser/urdf.py b/rofunc/utils/robolab/formatter/urdf_parser/urdf.py new file mode 100644 index 000000000..a3af5d6cf --- /dev/null +++ b/rofunc/utils/robolab/formatter/urdf_parser/urdf.py @@ -0,0 +1,592 @@ +from . import xml_reflection as xmlr +from .xml_reflection.basics import * + +# Add a 'namespace' for names to avoid a conflict between URDF and SDF? +# A type registry? How to scope that? Just make a 'global' type pointer? +# Or just qualify names? urdf.geometric, sdf.geometric + +xmlr.start_namespace("urdf") + +xmlr.add_type("element_link", xmlr.SimpleElementType("link", str)) +xmlr.add_type("element_xyz", xmlr.SimpleElementType("xyz", "vector3")) + +verbose = True + + +class Pose(xmlr.Object): + def __init__(self, xyz=None, rpy=None): + self.xyz = xyz + self.rpy = rpy + + def check_valid(self): + assert (self.xyz is None or len(self.xyz) == 3) and (self.rpy is None or len(self.rpy) == 3) + + # Aliases for backwards compatibility + @property + def rotation(self): + return self.rpy + + @rotation.setter + def rotation(self, value): + self.rpy = value + + @property + def position(self): + return self.xyz + + @position.setter + def position(self, value): + self.xyz = value + + +xmlr.reflect( + Pose, + tag="origin", + params=[ + xmlr.Attribute("xyz", "vector3", False, default=[0, 0, 0]), + xmlr.Attribute("rpy", "vector3", False, default=[0, 0, 0]), + ], +) + + +# Common stuff +name_attribute = xmlr.Attribute("name", str) +origin_element = xmlr.Element("origin", Pose, False) + + +class Color(xmlr.Object): + def __init__(self, *args): + # What about named colors? + count = len(args) + if count == 4 or count == 3: + self.rgba = args + elif count == 1: + self.rgba = args[0] + elif count == 0: + self.rgba = None + if self.rgba is not None: + if len(self.rgba) == 3: + self.rgba += [1.0] + if len(self.rgba) != 4: + raise Exception("Invalid color argument count") + + +xmlr.reflect(Color, tag="color", params=[xmlr.Attribute("rgba", "vector4")]) + + +class JointDynamics(xmlr.Object): + def __init__(self, damping=None, friction=None): + self.damping = damping + self.friction = friction + + +xmlr.reflect( + JointDynamics, + tag="dynamics", + params=[xmlr.Attribute("damping", float, False), xmlr.Attribute("friction", float, False)], +) + + +class Box(xmlr.Object): + def __init__(self, size=None): + self.size = size + + +xmlr.reflect(Box, tag="box", params=[xmlr.Attribute("size", "vector3")]) + + +class Cylinder(xmlr.Object): + def __init__(self, radius=0.0, length=0.0): + self.radius = radius + self.length = length + + +xmlr.reflect(Cylinder, tag="cylinder", params=[xmlr.Attribute("radius", float), xmlr.Attribute("length", float)]) + + +class Sphere(xmlr.Object): + def __init__(self, radius=0.0): + self.radius = radius + + +xmlr.reflect(Sphere, tag="sphere", params=[xmlr.Attribute("radius", float)]) + + +class Mesh(xmlr.Object): + def __init__(self, filename=None, scale=None): + self.filename = filename + self.scale = scale + + +xmlr.reflect( + Mesh, tag="mesh", params=[xmlr.Attribute("filename", str), xmlr.Attribute("scale", "vector3", required=False)] +) + + +class GeometricType(xmlr.ValueType): + def __init__(self): + self.factory = xmlr.FactoryType( + "geometric", {"box": Box, "cylinder": Cylinder, "sphere": Sphere, "mesh": Mesh} + ) + + def from_xml(self, node, path): + children = xml_children(node) + assert len(children) == 1, "One element only for geometric" + return self.factory.from_xml(children[0], path=path) + + def write_xml(self, node, obj): + name = self.factory.get_name(obj) + child = node_add(node, name) + obj.write_xml(child) + + +xmlr.add_type("geometric", GeometricType()) + + +class Collision(xmlr.Object): + def __init__(self, geometry=None, origin=None): + self.geometry = geometry + self.origin = origin + + +xmlr.reflect(Collision, tag="collision", params=[origin_element, xmlr.Element("geometry", "geometric")]) + + +class Texture(xmlr.Object): + def __init__(self, filename=None): + self.filename = filename + + +xmlr.reflect(Texture, tag="texture", params=[xmlr.Attribute("filename", str)]) + + +class Material(xmlr.Object): + def __init__(self, name=None, color=None, texture=None): + self.name = name + self.color = color + self.texture = texture + + def check_valid(self): + if self.color is None and self.texture is None: + xmlr.on_error("Material has neither a color nor texture.") + + +xmlr.reflect( + Material, + tag="material", + params=[name_attribute, xmlr.Element("color", Color, False), xmlr.Element("texture", Texture, False)], +) + + +class LinkMaterial(Material): + def check_valid(self): + pass + + +class Visual(xmlr.Object): + def __init__(self, geometry=None, material=None, origin=None): + self.geometry = geometry + self.material = material + self.origin = origin + + +xmlr.reflect( + Visual, + tag="visual", + params=[origin_element, xmlr.Element("geometry", "geometric"), xmlr.Element("material", LinkMaterial, False)], +) + + +class Inertia(xmlr.Object): + KEYS = ["ixx", "ixy", "ixz", "iyy", "iyz", "izz"] + + def __init__(self, ixx=0.0, ixy=0.0, ixz=0.0, iyy=0.0, iyz=0.0, izz=0.0): + self.ixx = ixx + self.ixy = ixy + self.ixz = ixz + self.iyy = iyy + self.iyz = iyz + self.izz = izz + + def to_matrix(self): + return [[self.ixx, self.ixy, self.ixz], [self.ixy, self.iyy, self.iyz], [self.ixz, self.iyz, self.izz]] + + +xmlr.reflect(Inertia, tag="inertia", params=[xmlr.Attribute(key, float) for key in Inertia.KEYS]) + + +class Inertial(xmlr.Object): + def __init__(self, mass=0.0, inertia=None, origin=None): + self.mass = mass + self.inertia = inertia + self.origin = origin + + +xmlr.reflect( + Inertial, + tag="inertial", + params=[origin_element, xmlr.Element("mass", "element_value"), xmlr.Element("inertia", Inertia, False)], +) + + +# FIXME: we are missing the reference position here. +class JointCalibration(xmlr.Object): + def __init__(self, rising=None, falling=None): + self.rising = rising + self.falling = falling + + +xmlr.reflect( + JointCalibration, + tag="calibration", + params=[xmlr.Attribute("rising", float, False, 0), xmlr.Attribute("falling", float, False, 0)], +) + + +class JointLimit(xmlr.Object): + def __init__(self, effort=None, velocity=None, lower=None, upper=None): + self.effort = effort + self.velocity = velocity + self.lower = lower + self.upper = upper + + +xmlr.reflect( + JointLimit, + tag="limit", + params=[ + xmlr.Attribute("effort", float), + xmlr.Attribute("lower", float, False, 0), + xmlr.Attribute("upper", float, False, 0), + xmlr.Attribute("velocity", float), + ], +) + +# FIXME: we are missing __str__ here. + + +class JointMimic(xmlr.Object): + def __init__(self, joint_name=None, multiplier=None, offset=None): + self.joint = joint_name + self.multiplier = multiplier + self.offset = offset + + +xmlr.reflect( + JointMimic, + tag="mimic", + params=[ + xmlr.Attribute("joint", str), + xmlr.Attribute("multiplier", float, False), + xmlr.Attribute("offset", float, False), + ], +) + + +class SafetyController(xmlr.Object): + def __init__(self, velocity=None, position=None, lower=None, upper=None): + self.k_velocity = velocity + self.k_position = position + self.soft_lower_limit = lower + self.soft_upper_limit = upper + + +xmlr.reflect( + SafetyController, + tag="safety_controller", + params=[ + xmlr.Attribute("k_velocity", float), + xmlr.Attribute("k_position", float, False, 0), + xmlr.Attribute("soft_lower_limit", float, False, 0), + xmlr.Attribute("soft_upper_limit", float, False, 0), + ], +) + + +class Joint(xmlr.Object): + TYPES = ["unknown", "revolute", "continuous", "prismatic", "floating", "planar", "fixed"] + + def __init__( + self, + name=None, + parent=None, + child=None, + joint_type=None, + axis=None, + origin=None, + limit=None, + dynamics=None, + safety_controller=None, + calibration=None, + mimic=None, + ): + self.name = name + self.parent = parent + self.child = child + self.type = joint_type + self.axis = axis + self.origin = origin + self.limit = limit + self.dynamics = dynamics + self.safety_controller = safety_controller + self.calibration = calibration + self.mimic = mimic + + def check_valid(self): + assert self.type in self.TYPES, "Invalid joint type: {}".format(self.type) # noqa + + # Aliases + @property + def joint_type(self): + return self.type + + @joint_type.setter + def joint_type(self, value): + self.type = value + + +xmlr.reflect( + Joint, + tag="joint", + params=[ + name_attribute, + xmlr.Attribute("type", str), + origin_element, + xmlr.Element("axis", "element_xyz", False), + xmlr.Element("parent", "element_link"), + xmlr.Element("child", "element_link"), + xmlr.Element("limit", JointLimit, False), + xmlr.Element("dynamics", JointDynamics, False), + xmlr.Element("safety_controller", SafetyController, False), + xmlr.Element("calibration", JointCalibration, False), + xmlr.Element("mimic", JointMimic, False), + ], +) + + +class Link(xmlr.Object): + def __init__(self, name=None, visual=None, inertial=None, collision=None, origin=None): + self.aggregate_init() + self.name = name + self.visuals = [] + self.inertial = inertial + self.collisions = [] + self.origin = origin + + def __get_visual(self): + """Return the first visual or None.""" + if self.visuals: + return self.visuals[0] + + def __set_visual(self, visual): + """Set the first visual.""" + if self.visuals: + self.visuals[0] = visual + else: + self.visuals.append(visual) + + def __get_collision(self): + """Return the first collision or None.""" + if self.collisions: + return self.collisions[0] + + def __set_collision(self, collision): + """Set the first collision.""" + if self.collisions: + self.collisions[0] = collision + else: + self.collisions.append(collision) + + # Properties for backwards compatibility + visual = property(__get_visual, __set_visual) + collision = property(__get_collision, __set_collision) + + +xmlr.reflect( + Link, + tag="link", + params=[ + name_attribute, + origin_element, + xmlr.AggregateElement("visual", Visual), + xmlr.AggregateElement("collision", Collision), + xmlr.Element("inertial", Inertial, False), + ], +) + + +class PR2Transmission(xmlr.Object): + def __init__(self, name=None, joint=None, actuator=None, type=None, mechanicalReduction=1): + self.name = name + self.type = type + self.joint = joint + self.actuator = actuator + self.mechanicalReduction = mechanicalReduction + + +xmlr.reflect( + PR2Transmission, + tag="pr2_transmission", + params=[ + name_attribute, + xmlr.Attribute("type", str), + xmlr.Element("joint", "element_name"), + xmlr.Element("actuator", "element_name"), + xmlr.Element("mechanicalReduction", float), + ], +) + + +class Actuator(xmlr.Object): + def __init__(self, name=None, mechanicalReduction=1): + self.name = name + self.mechanicalReduction = None + + +xmlr.reflect( + Actuator, tag="actuator", params=[name_attribute, xmlr.Element("mechanicalReduction", float, required=False)] +) + + +class TransmissionJoint(xmlr.Object): + def __init__(self, name=None): + self.aggregate_init() + self.name = name + self.hardwareInterfaces = [] + + def check_valid(self): + assert len(self.hardwareInterfaces) > 0, "no hardwareInterface defined" + + +xmlr.reflect( + TransmissionJoint, + tag="joint", + params=[ + name_attribute, + xmlr.AggregateElement("hardwareInterface", str), + ], +) + + +class Transmission(xmlr.Object): + """New format: http://wiki.ros.org/urdf/XML/Transmission""" + + def __init__(self, name=None): + self.aggregate_init() + self.name = name + self.joints = [] + self.actuators = [] + + def check_valid(self): + assert len(self.joints) > 0, "no joint defined" + assert len(self.actuators) > 0, "no actuator defined" + + +xmlr.reflect( + Transmission, + tag="new_transmission", + params=[ + name_attribute, + xmlr.Element("type", str), + xmlr.AggregateElement("joint", TransmissionJoint), + xmlr.AggregateElement("actuator", Actuator), + ], +) + +xmlr.add_type("transmission", xmlr.DuckTypedFactory("transmission", [Transmission, PR2Transmission])) + + +class Robot(xmlr.Object): + def __init__(self, name=None): + self.aggregate_init() + + self.name = name + self.joints = [] + self.links = [] + self.materials = [] + self.gazebos = [] + self.transmissions = [] + + self.joint_map = {} + self.link_map = {} + + self.parent_map = {} + self.child_map = {} + + def add_aggregate(self, typeName, elem): + xmlr.Object.add_aggregate(self, typeName, elem) + + if typeName == "joint": + joint = elem + self.joint_map[joint.name] = joint + self.parent_map[joint.child] = (joint.name, joint.parent) + if joint.parent in self.child_map: + self.child_map[joint.parent].append((joint.name, joint.child)) + else: + self.child_map[joint.parent] = [(joint.name, joint.child)] + elif typeName == "link": + link = elem + self.link_map[link.name] = link + + def add_link(self, link): + self.add_aggregate("link", link) + + def add_joint(self, joint): + self.add_aggregate("joint", joint) + + def get_chain(self, root, tip, joints=True, links=True, fixed=True): + chain = [] + if links: + chain.append(tip) + link = tip + while link != root: + (joint, parent) = self.parent_map[link] + if joints: + if fixed or self.joint_map[joint].joint_type != "fixed": + chain.append(joint) + if links: + chain.append(parent) + link = parent + chain.reverse() + return chain + + def get_root(self): + root = None + for link in self.link_map: + if link not in self.parent_map: + assert root is None, "Multiple roots detected, invalid URDF." + root = link + assert root is not None, "No roots detected, invalid URDF." + return root + + @classmethod + def from_parameter_server(cls, key="robot_description"): + """ + Retrieve the robot model on the parameter server + and parse it to create a URDF robot structure. + + Warning: this requires roscore to be running. + """ + # Could move this into xml_reflection + import rospy + + return cls.from_xml_string(rospy.get_param(key)) + + +xmlr.reflect( + Robot, + tag="robot", + params=[ + xmlr.Attribute("name", str, False), # Is 'name' a required attribute? + xmlr.AggregateElement("link", Link), + xmlr.AggregateElement("joint", Joint), + xmlr.AggregateElement("gazebo", xmlr.RawType()), + xmlr.AggregateElement("transmission", "transmission"), + xmlr.AggregateElement("material", Material), + ], +) + +# Make an alias +URDF = Robot + +xmlr.end_namespace() diff --git a/rofunc/utils/robolab/formatter/urdf_parser/xml_reflection/__init__.py b/rofunc/utils/robolab/formatter/urdf_parser/xml_reflection/__init__.py new file mode 100644 index 000000000..bb67a43fa --- /dev/null +++ b/rofunc/utils/robolab/formatter/urdf_parser/xml_reflection/__init__.py @@ -0,0 +1 @@ +from .core import * diff --git a/rofunc/utils/robolab/formatter/urdf_parser/xml_reflection/basics.py b/rofunc/utils/robolab/formatter/urdf_parser/xml_reflection/basics.py new file mode 100644 index 000000000..638ad72b0 --- /dev/null +++ b/rofunc/utils/robolab/formatter/urdf_parser/xml_reflection/basics.py @@ -0,0 +1,91 @@ +import collections +import string + +import yaml +from lxml import etree + + +def xml_string(rootXml, addHeader=True): + # Meh + xmlString = etree.tostring(rootXml, pretty_print=True, encoding="unicode") + if addHeader: + xmlString = '\n' + xmlString + return xmlString + + +def dict_sub(obj, keys): + return dict((key, obj[key]) for key in keys) + + +def node_add(doc, sub): + if sub is None: + return None + if type(sub) == str: + return etree.SubElement(doc, sub) + elif isinstance(sub, etree._Element): + doc.append(sub) # This screws up the rest of the tree for prettyprint + return sub + else: + raise Exception("Invalid sub value") + + +def pfloat(x): + return str(x).rstrip(".") + + +def xml_children(node): + children = node.getchildren() + + def predicate(node): + return not isinstance(node, etree._Comment) + + return list(filter(predicate, children)) + + +def isstring(obj): + try: + return isinstance(obj, basestring) + except NameError: + return isinstance(obj, str) + + +def to_yaml(obj): + """Simplify yaml representation for pretty printing""" + # Is there a better way to do this by adding a representation with + # yaml.Dumper? + # Ordered dict: http://pyyaml.org/ticket/29#comment:11 + if obj is None or isstring(obj): + out = str(obj) + elif type(obj) in [int, float, bool]: + return obj + elif hasattr(obj, "to_yaml"): + out = obj.to_yaml() + elif isinstance(obj, etree._Element): + out = etree.tostring(obj, pretty_print=True) + elif type(obj) == dict: + out = {} + for (var, value) in obj.items(): + out[str(var)] = to_yaml(value) + elif hasattr(obj, "tolist"): + # For numpy objects + out = to_yaml(obj.tolist()) + elif isinstance(obj, collections.Iterable): + out = [to_yaml(item) for item in obj] + else: + out = str(obj) + return out + + +class SelectiveReflection(object): + def get_refl_vars(self): + return list(vars(self).keys()) + + +class YamlReflection(SelectiveReflection): + def to_yaml(self): + raw = dict((var, getattr(self, var)) for var in self.get_refl_vars()) + return to_yaml(raw) + + def __str__(self): + # Good idea? Will it remove other important things? + return yaml.dump(self.to_yaml()).rstrip() diff --git a/rofunc/utils/robolab/formatter/urdf_parser/xml_reflection/core.py b/rofunc/utils/robolab/formatter/urdf_parser/xml_reflection/core.py new file mode 100644 index 000000000..2d9c3fe1a --- /dev/null +++ b/rofunc/utils/robolab/formatter/urdf_parser/xml_reflection/core.py @@ -0,0 +1,680 @@ +import copy +import sys + +from .basics import * + +# @todo Get rid of "import *" +# @todo Make this work with decorators + +# Is this reflection or serialization? I think it's serialization... +# Rename? + +# Do parent operations after, to allow child to 'override' parameters? +# Need to make sure that duplicate entires do not get into the 'unset*' lists + + +def reflect(cls, *args, **kwargs): + """ + Simple wrapper to add XML reflection to an xml_reflection.Object class + """ + cls.XML_REFL = Reflection(*args, **kwargs) + + +# Rename 'write_xml' to 'write_xml' to have paired 'load/dump', and make +# 'pre_dump' and 'post_load'? +# When dumping to yaml, include tag name? + +# How to incorporate line number and all that jazz? +def on_error_stderr(message): + """What to do on an error. This can be changed to raise an exception.""" + sys.stderr.write(message + "\n") + + +on_error = on_error_stderr + + +skip_default = False +# defaultIfMatching = True # Not implemeneted yet + +# Registering Types +value_types = {} +value_type_prefix = "" + + +def start_namespace(namespace): + """ + Basic mechanism to prevent conflicts for string types for URDF and SDF + @note Does not handle nesting! + """ + global value_type_prefix + value_type_prefix = namespace + "." + + +def end_namespace(): + global value_type_prefix + value_type_prefix = "" + + +def add_type(key, value): + if isinstance(key, str): + key = value_type_prefix + key + assert key not in value_types + value_types[key] = value + + +def get_type(cur_type): + """Can wrap value types if needed""" + if value_type_prefix and isinstance(cur_type, str): + # See if it exists in current 'namespace' + curKey = value_type_prefix + cur_type + value_type = value_types.get(curKey) + else: + value_type = None + if value_type is None: + # Try again, in 'global' scope + value_type = value_types.get(cur_type) + if value_type is None: + value_type = make_type(cur_type) + add_type(cur_type, value_type) + return value_type + + +def make_type(cur_type): + if isinstance(cur_type, ValueType): + return cur_type + elif isinstance(cur_type, str): + if cur_type.startswith("vector"): + extra = cur_type[6:] + if extra: + count = float(extra) + else: + count = None + return VectorType(count) + else: + raise Exception("Invalid value type: {}".format(cur_type)) + elif cur_type == list: + return ListType() + elif issubclass(cur_type, Object): + return ObjectType(cur_type) + elif cur_type in [str, float, bool]: + return BasicType(cur_type) + else: + raise Exception("Invalid type: {}".format(cur_type)) + + +class Path(object): + def __init__(self, tag, parent=None, suffix="", tree=None): + self.parent = parent + self.tag = tag + self.suffix = suffix + self.tree = tree # For validating general path (getting true XML path) + + def __str__(self): + if self.parent is not None: + return "{}/{}{}".format(self.parent, self.tag, self.suffix) + else: + if self.tag is not None and len(self.tag) > 0: + return "/{}{}".format(self.tag, self.suffix) + else: + return self.suffix + + +class ParseError(Exception): + def __init__(self, e, path): + self.e = e + self.path = path + message = "ParseError in {}:\n{}".format(self.path, self.e) + super(ParseError, self).__init__(message) + + +class ValueType(object): + """Primitive value type""" + + def from_xml(self, node, path): + return self.from_string(node.text) + + def write_xml(self, node, value): + """ + If type has 'write_xml', this function should expect to have it's own + XML already created i.e., In Axis.to_sdf(self, node), 'node' would be + the 'axis' element. + @todo Add function that makes an XML node completely independently? + """ + node.text = self.to_string(value) + + def equals(self, a, b): + return a == b + + +class BasicType(ValueType): + def __init__(self, cur_type): + self.type = cur_type + + def to_string(self, value): + return str(value) + + def from_string(self, value): + return self.type(value) + + +class ListType(ValueType): + def to_string(self, values): + return " ".join(values) + + def from_string(self, text): + return text.split() + + def equals(self, aValues, bValues): + return len(aValues) == len(bValues) and all(a == b for (a, b) in zip(aValues, bValues)) # noqa + + +class VectorType(ListType): + def __init__(self, count=None): + self.count = count + + def check(self, values): + if self.count is not None: + assert len(values) == self.count, "Invalid vector length" + + def to_string(self, values): + self.check(values) + raw = list(map(str, values)) + return ListType.to_string(self, raw) + + def from_string(self, text): + raw = ListType.from_string(self, text) + self.check(raw) + return list(map(float, raw)) + + +class RawType(ValueType): + """ + Simple, raw XML value. Need to bugfix putting this back into a document + """ + + def from_xml(self, node, path): + return node + + def write_xml(self, node, value): + # @todo rying to insert an element at root level seems to screw up + # pretty printing + children = xml_children(value) + list(map(node.append, children)) + # Copy attributes + for (attrib_key, attrib_value) in value.attrib.items(): + node.set(attrib_key, attrib_value) + + +class SimpleElementType(ValueType): + """ + Extractor that retrieves data from an element, given a + specified attribute, casted to value_type. + """ + + def __init__(self, attribute, value_type): + self.attribute = attribute + self.value_type = get_type(value_type) + + def from_xml(self, node, path): + text = node.get(self.attribute) + return self.value_type.from_string(text) + + def write_xml(self, node, value): + text = self.value_type.to_string(value) + node.set(self.attribute, text) + + +class ObjectType(ValueType): + def __init__(self, cur_type): + self.type = cur_type + + def from_xml(self, node, path): + obj = self.type() + obj.read_xml(node, path) + return obj + + def write_xml(self, node, obj): + obj.write_xml(node) + + +class FactoryType(ValueType): + def __init__(self, name, typeMap): + self.name = name + self.typeMap = typeMap + self.nameMap = {} + for (key, value) in typeMap.items(): + # Reverse lookup + self.nameMap[value] = key + + def from_xml(self, node, path): + cur_type = self.typeMap.get(node.tag) + if cur_type is None: + raise Exception("Invalid {} tag: {}".format(self.name, node.tag)) + value_type = get_type(cur_type) + return value_type.from_xml(node, path) + + def get_name(self, obj): + cur_type = type(obj) + name = self.nameMap.get(cur_type) + if name is None: + raise Exception("Invalid {} type: {}".format(self.name, cur_type)) + return name + + def write_xml(self, node, obj): + obj.write_xml(node) + + +class DuckTypedFactory(ValueType): + def __init__(self, name, typeOrder): + self.name = name + assert len(typeOrder) > 0 + self.type_order = typeOrder + + def from_xml(self, node, path): + error_set = [] + for value_type in self.type_order: + try: + return value_type.from_xml(node, path) + except Exception as e: + error_set.append((value_type, e)) + # Should have returned, we encountered errors + out = "Could not perform duck-typed parsing." + for (value_type, e) in error_set: + out += "\nValue Type: {}\nException: {}\n".format(value_type, e) + raise ParseError(Exception(out), path) + + def write_xml(self, node, obj): + obj.write_xml(node) + + +class Param(object): + """Mirroring Gazebo's SDF api + + @param xml_var: Xml name + @todo If the value_type is an object with a tag defined in it's + reflection, allow it to act as the default tag name? + @param var: Python class variable name. By default it's the same as the + XML name + """ + + def __init__(self, xml_var, value_type, required=True, default=None, var=None): + self.xml_var = xml_var + if var is None: + self.var = xml_var + else: + self.var = var + self.type = None + self.value_type = get_type(value_type) + self.default = default + if required: + assert default is None, "Default does not make sense for a required field" # noqa + self.required = required + self.is_aggregate = False + + def set_default(self, obj): + if self.required: + raise Exception("Required {} not set in XML: {}".format(self.type, self.xml_var)) # noqa + elif not skip_default: + setattr(obj, self.var, self.default) + + +class Attribute(Param): + def __init__(self, xml_var, value_type, required=True, default=None, var=None): + Param.__init__(self, xml_var, value_type, required, default, var) + self.type = "attribute" + + def set_from_string(self, obj, value): + """Node is the parent node in this case""" + # Duplicate attributes cannot occur at this point + setattr(obj, self.var, self.value_type.from_string(value)) + + def get_value(self, obj): + return getattr(obj, self.var) + + def add_to_xml(self, obj, node): + value = getattr(obj, self.var) + # Do not set with default value if value is None + if value is None: + if self.required: + raise Exception("Required attribute not set in object: {}".format(self.var)) # noqa + elif not skip_default: + value = self.default + # Allow value type to handle None? + if value is not None: + node.set(self.xml_var, self.value_type.to_string(value)) + + +# Add option if this requires a header? +# Like .... ??? +# Not really... This would be a specific list type, not really aggregate + + +class Element(Param): + def __init__(self, xml_var, value_type, required=True, default=None, var=None, is_raw=False): + Param.__init__(self, xml_var, value_type, required, default, var) + self.type = "element" + self.is_raw = is_raw + + def set_from_xml(self, obj, node, path): + value = self.value_type.from_xml(node, path) + setattr(obj, self.var, value) + + def add_to_xml(self, obj, parent): + value = getattr(obj, self.xml_var) + if value is None: + if self.required: + raise Exception("Required element not defined in object: {}".format(self.var)) # noqa + elif not skip_default: + value = self.default + if value is not None: + self.add_scalar_to_xml(parent, value) + + def add_scalar_to_xml(self, parent, value): + if self.is_raw: + node = parent + else: + node = node_add(parent, self.xml_var) + self.value_type.write_xml(node, value) + + +class AggregateElement(Element): + def __init__(self, xml_var, value_type, var=None, is_raw=False): + if var is None: + var = xml_var + "s" + Element.__init__(self, xml_var, value_type, required=False, var=var, is_raw=is_raw) + self.is_aggregate = True + + def add_from_xml(self, obj, node, path): + value = self.value_type.from_xml(node, path) + obj.add_aggregate(self.xml_var, value) + + def set_default(self, obj): + pass + + +class Info: + """Small container for keeping track of what's been consumed""" + + def __init__(self, node): + self.attributes = list(node.attrib.keys()) + self.children = xml_children(node) + + +class Reflection(object): + def __init__(self, params=[], parent_cls=None, tag=None): + """Construct a XML reflection thing + @param parent_cls: Parent class, to use it's reflection as well. + @param tag: Only necessary if you intend to use Object.write_xml_doc() + This does not override the name supplied in the reflection + definition thing. + """ + if parent_cls is not None: + self.parent = parent_cls.XML_REFL + else: + self.parent = None + self.tag = tag + + # Laziness for now + attributes = [] + elements = [] + for param in params: + if isinstance(param, Element): + elements.append(param) + else: + attributes.append(param) + + self.vars = [] + self.paramMap = {} + + self.attributes = attributes + self.attribute_map = {} + self.required_attribute_names = [] + for attribute in attributes: + self.attribute_map[attribute.xml_var] = attribute + self.paramMap[attribute.xml_var] = attribute + self.vars.append(attribute.var) + if attribute.required: + self.required_attribute_names.append(attribute.xml_var) + + self.elements = [] + self.element_map = {} + self.required_element_names = [] + self.aggregates = [] + self.scalars = [] + self.scalarNames = [] + for element in elements: + self.element_map[element.xml_var] = element + self.paramMap[element.xml_var] = element + self.vars.append(element.var) + if element.required: + self.required_element_names.append(element.xml_var) + if element.is_aggregate: + self.aggregates.append(element) + else: + self.scalars.append(element) + self.scalarNames.append(element.xml_var) + + def set_from_xml(self, obj, node, path, info=None): + is_final = False + if info is None: + is_final = True + info = Info(node) + + if self.parent: + path = self.parent.set_from_xml(obj, node, path, info) + + # Make this a map instead? Faster access? {name: isSet} ? + unset_attributes = list(self.attribute_map.keys()) + unset_scalars = copy.copy(self.scalarNames) + + def get_attr_path(attribute): + attr_path = copy.copy(path) + attr_path.suffix += "[@{}]".format(attribute.xml_var) + return attr_path + + def get_element_path(element): + element_path = Path(element.xml_var, parent=path) + # Add an index (allow this to be overriden) + if element.is_aggregate: + values = obj.get_aggregate_list(element.xml_var) + index = 1 + len(values) # 1-based indexing for W3C XPath + element_path.suffix = "[{}]".format(index) + return element_path + + id_var = "name" + # Better method? Queues? + for xml_var in copy.copy(info.attributes): + attribute = self.attribute_map.get(xml_var) + if attribute is not None: + value = node.attrib[xml_var] + attr_path = get_attr_path(attribute) + try: + attribute.set_from_string(obj, value) + if attribute.xml_var == id_var: + # Add id_var suffix to current path (do not copy so it propagates) + path.suffix = "[@{}='{}']".format(id_var, attribute.get_value(obj)) + except ParseError: + raise + except Exception as e: + raise ParseError(e, attr_path) + unset_attributes.remove(xml_var) + info.attributes.remove(xml_var) + + # Parse unconsumed nodes + for child in copy.copy(info.children): + tag = child.tag + element = self.element_map.get(tag) + if element is not None: + # Name will have been set + element_path = get_element_path(element) + if element.is_aggregate: + element.add_from_xml(obj, child, element_path) + else: + if tag in unset_scalars: + element.set_from_xml(obj, child, element_path) + unset_scalars.remove(tag) + else: + on_error("Scalar element defined multiple times: {}".format(tag)) # noqa + info.children.remove(child) + + # For unset attributes and scalar elements, we should not pass the attribute + # or element path, as those paths will implicitly not exist. + # If we do supply it, then the user would need to manually prune the XPath to try + # and find where the problematic parent element. + for attribute in map(self.attribute_map.get, unset_attributes): + try: + attribute.set_default(obj) + except ParseError: + raise + except Exception as e: + raise ParseError(e, path) # get_attr_path(attribute.xml_var) + + for element in map(self.element_map.get, unset_scalars): + try: + element.set_default(obj) + except ParseError: + raise + except Exception as e: + raise ParseError(e, path) # get_element_path(element) + + if is_final: + for xml_var in info.attributes: + on_error('Unknown attribute "{}" in {}'.format(xml_var, path)) + for node in info.children: + on_error('Unknown tag "{}" in {}'.format(node.tag, path)) + # Allow children parsers to adopt this current path (if modified with id_var) + return path + + def add_to_xml(self, obj, node): + if self.parent: + self.parent.add_to_xml(obj, node) + for attribute in self.attributes: + attribute.add_to_xml(obj, node) + for element in self.scalars: + element.add_to_xml(obj, node) + # Now add in aggregates + if self.aggregates: + obj.add_aggregates_to_xml(node) + + +class Object(YamlReflection): + """Raw python object for yaml / xml representation""" + + XML_REFL = None + + def get_refl_vars(self): + return self.XML_REFL.vars + + def check_valid(self): + pass + + def pre_write_xml(self): + """If anything needs to be converted prior to dumping to xml + i.e., getting the names of objects and such""" + pass + + def write_xml(self, node): + """Adds contents directly to XML node""" + self.check_valid() + self.pre_write_xml() + self.XML_REFL.add_to_xml(self, node) + + def to_xml(self): + """Creates an overarching tag and adds its contents to the node""" + tag = self.XML_REFL.tag + assert tag is not None, "Must define 'tag' in reflection to use this function" # noqa + doc = etree.Element(tag) + self.write_xml(doc) + return doc + + def to_xml_string(self, addHeader=True): + return xml_string(self.to_xml(), addHeader) + + def post_read_xml(self): + pass + + def read_xml(self, node, path): + self.XML_REFL.set_from_xml(self, node, path) + self.post_read_xml() + try: + self.check_valid() + except ParseError: + raise + except Exception as e: + raise ParseError(e, path) + + @classmethod + def from_xml(cls, node, path): + cur_type = get_type(cls) + return cur_type.from_xml(node, path) + + @classmethod + def from_xml_string(cls, xml_string): + node = etree.fromstring(xml_string) + path = Path(cls.XML_REFL.tag, tree=etree.ElementTree(node)) + return cls.from_xml(node, path) + + @classmethod + def from_xml_file(cls, file_path): + xml_string = open(file_path, "r").read() + return cls.from_xml_string(xml_string.encode('utf-8')) + + # Confusing distinction between loading code in object and reflection + # registry thing... + + def get_aggregate_list(self, xml_var): + var = self.XML_REFL.paramMap[xml_var].var + values = getattr(self, var) + assert isinstance(values, list) + return values + + def aggregate_init(self): + """Must be called in constructor!""" + self.aggregate_order = [] + # Store this info in the loaded object??? Nah + self.aggregate_type = {} + + def add_aggregate(self, xml_var, obj): + """NOTE: One must keep careful track of aggregate types for this system. + Can use 'lump_aggregates()' before writing if you don't care.""" + self.get_aggregate_list(xml_var).append(obj) + self.aggregate_order.append(obj) + self.aggregate_type[obj] = xml_var + + def add_aggregates_to_xml(self, node): + for value in self.aggregate_order: + typeName = self.aggregate_type[value] + element = self.XML_REFL.element_map[typeName] + element.add_scalar_to_xml(node, value) + + def remove_aggregate(self, obj): + self.aggregate_order.remove(obj) + xml_var = self.aggregate_type[obj] + del self.aggregate_type[obj] + self.get_aggregate_list(xml_var).remove(obj) + + def lump_aggregates(self): + """Put all aggregate types together, just because""" + self.aggregate_init() + for param in self.XML_REFL.aggregates: + for obj in self.get_aggregate_list(param.xml_var): + self.add_aggregate(param.var, obj) + + """ Compatibility """ + + def parse(self, xml_string): + node = etree.fromstring(xml_string) + path = Path(self.XML_REFL.tag, tree=etree.ElementTree(node)) + self.read_xml(node, path) + return self + + +# Really common types +# Better name: element_with_name? Attributed element? +add_type("element_name", SimpleElementType("name", str)) +add_type("element_value", SimpleElementType("value", float)) + +# Add in common vector types so they aren't absorbed into the namespaces +get_type("vector3") +get_type("vector4") +get_type("vector6") diff --git a/rofunc/utils/robolab/kinematics/__init__.py b/rofunc/utils/robolab/kinematics/__init__.py index 4daaac452..98b90b69a 100644 --- a/rofunc/utils/robolab/kinematics/__init__.py +++ b/rofunc/utils/robolab/kinematics/__init__.py @@ -1,2 +1,2 @@ from .fk import get_fk_from_chain, get_fk_from_model -from .robot_class import RobotModel \ No newline at end of file +from .robot_class import RobotModel diff --git a/rofunc/utils/robolab/kinematics/fk.py b/rofunc/utils/robolab/kinematics/fk.py index 5eb48b9a5..0db47e8eb 100644 --- a/rofunc/utils/robolab/kinematics/fk.py +++ b/rofunc/utils/robolab/kinematics/fk.py @@ -1,7 +1,6 @@ -def get_fk_from_chain(chain, joint_value, export_link_name): +def get_fk_from_chain(chain, joint_value, export_link_name=None): """ Get the forward kinematics from a serial chain - :param chain: :param joint_value: :param export_link_name: @@ -10,14 +9,16 @@ def get_fk_from_chain(chain, joint_value, export_link_name): # do forward kinematics and get transform objects; end_only=False gives a dictionary of transforms for all links ret = chain.forward_kinematics(joint_value) # look up the transform for a specific link - pose = ret[export_link_name] + if export_link_name is not None: + pose = ret[export_link_name] + else: + pose = None return pose, ret -def get_fk_from_model(model_path: str, joint_value, export_link, verbose=False): +def get_fk_from_model(model_path: str, joint_value, export_link=None, verbose=False): """ Get the forward kinematics from a URDF or MuJoCo XML file - :param model_path: the path of the URDF or MuJoCo XML file :param joint_value: the value of the joints :param export_link: the name of the end effector link @@ -28,4 +29,4 @@ def get_fk_from_model(model_path: str, joint_value, export_link, verbose=False): chain = build_chain_from_model(model_path, verbose) pose, ret = get_fk_from_chain(chain, joint_value, export_link) - return pose, ret + return pose, ret \ No newline at end of file diff --git a/rofunc/utils/robolab/kinematics/ik.py b/rofunc/utils/robolab/kinematics/ik.py index ed5ea61d8..72d01b435 100644 --- a/rofunc/utils/robolab/kinematics/ik.py +++ b/rofunc/utils/robolab/kinematics/ik.py @@ -1,58 +1,59 @@ +from typing import Union, List, Tuple import torch from rofunc.utils.robolab.coord import convert_ori_format, convert_quat_order from rofunc.utils.robolab.kinematics.pytorch_kinematics_utils import build_chain_from_model -def get_ik_from_chain(chain, pos, rot, device): +def get_ik_from_chain(chain, goal_pose: Union[torch.Tensor, None, List, Tuple], device, goal_in_rob_tf: bool = True, + robot_pose: Union[torch.Tensor, None, List, Tuple] = None, cur_configs=None, + num_retries: int = 10): """ Get the inverse kinematics from a serial chain - :param chain: only the serial chain is supported - :param pos: the position of the export_link - :param rot: the rotation of the export_link + :param goal_pose: the pose of the export ee link :param device: the device to run the computation + :param goal_in_rob_tf: whether the goal pose is in the robot base frame + :param robot_pose: the pose of the robot base frame + :param cur_configs: let the ik solver retry from these configurations + :param num_retries: the number of retries :return: """ import pytorch_kinematics as pk - rob_tf = pk.Transform3d(pos=pos, rot=rot, device=device) + goal_pos = goal_pose[:3] + goal_rot = goal_pose[3:] + goal_tf = pk.Transform3d(pos=goal_pos, rot=goal_rot, device=device) + if not goal_in_rob_tf: + assert robot_pose is not None, "The robot pose must be provided if the goal pose is not in the robot base frame" + robot_pos = robot_pose[:3] + robot_rot = robot_pose[3:] + rob_tf = pk.Transform3d(pos=robot_pos, rot=robot_rot, device=device) + goal_tf = rob_tf.inverse().compose(goal_tf) # get robot joint limits lim = torch.tensor(chain.get_joint_limits(), device=device) - cur_q = torch.rand(7, device=device) * (lim[1] - lim[0]) + lim[0] - goal_q = cur_q.unsqueeze(0).repeat(1, 1) - - goal_in_rob_frame_tf = chain.forward_kinematics(goal_q) + if cur_configs is not None: + cur_configs = torch.tensor(cur_configs, device=device) # create the IK object # see the constructor for more options and their explanations, such as convergence tolerances - # ik = PseudoInverseIK(chain, max_iterations=30, num_retries=10, - # joint_limits=lim.T, - # early_stopping_any_converged=True, - # early_stopping_no_improvement="all", - # retry_configs=cur_q.reshape(1, -1), - # # line_search=pk.BacktrackingLineSearch(max_lr=0.2), - # debug=False, - # lr=0.2) - ik = pk.PseudoInverseIK(chain, max_iterations=30, num_retries=10, + ik = pk.PseudoInverseIK(chain, max_iterations=30, retry_configs=cur_configs, num_retries=num_retries, joint_limits=lim.T, early_stopping_any_converged=True, early_stopping_no_improvement="all", - # line_search=pk.BacktrackingLineSearch(max_lr=0.2), debug=False, lr=0.2) # solve IK - sol = ik.solve(goal_in_rob_frame_tf) + sol = ik.solve(goal_tf) return sol def get_ik_from_model(model_path: str, pose: torch.Tensor, device, export_link, verbose=False): """ Get the inverse kinematics from a URDF or MuJoCo XML file - :param model_path: the path of the URDF or MuJoCo XML file :param pose: the pose of the end effector, 7D vector with the first 3 elements as position and the last 4 elements as rotation :param device: the device to run the computation @@ -72,7 +73,7 @@ def get_ik_from_model(model_path: str, pose: torch.Tensor, device, export_link, if __name__ == '__main__': - model_path = "/home/ubuntu/Github/Xianova_Robotics/Rofunc-secret/rofunc/simulator/assets/urdf/franka_description/robots/franka_panda.urdf" + model_path = "./simulator/assets/urdf/franka_description/robots/franka_panda.urdf" # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cpu") @@ -125,4 +126,4 @@ def get_ik_from_model(model_path: str, pose: torch.Tensor, device, export_link, # p.removeBody(goalId) while True: - p.stepSimulation() + p.stepSimulation() \ No newline at end of file diff --git a/rofunc/utils/robolab/kinematics/kinpy_utils.py b/rofunc/utils/robolab/kinematics/kinpy_utils.py index 5a2ab5a74..c5f6d3687 100644 --- a/rofunc/utils/robolab/kinematics/kinpy_utils.py +++ b/rofunc/utils/robolab/kinematics/kinpy_utils.py @@ -9,7 +9,6 @@ def build_chain_from_model(model_path: str, verbose=False): """ Build a serial chain from a URDF or MuJoCo XML file - :param model_path: the path of the URDF or MuJoCo XML file :param verbose: whether to print the chain :return: robot kinematics chain @@ -27,4 +26,4 @@ def build_chain_from_model(model_path: str, verbose=False): print(chain) beauty_print("Robot joints:") print(chain.get_joint_parameter_names()) - return chain + return chain \ No newline at end of file diff --git a/rofunc/utils/robolab/kinematics/pytorch_kinematics_utils.py b/rofunc/utils/robolab/kinematics/pytorch_kinematics_utils.py index 598006c23..14f35b831 100644 --- a/rofunc/utils/robolab/kinematics/pytorch_kinematics_utils.py +++ b/rofunc/utils/robolab/kinematics/pytorch_kinematics_utils.py @@ -3,22 +3,20 @@ from typing import Optional, Callable from typing import Union -from rofunc.utils.oslab.path import check_package_exist from rofunc.utils.logger.beauty_logger import beauty_print +from rofunc.utils.oslab.path import check_package_exist check_package_exist("pytorch_kinematics") import mujoco -import pytorch_kinematics.transforms as tf import torch from matplotlib import pyplot as plt, cm as cm -from mujoco._structs import _MjModelBodyViews as MjModelBodyViews -from pytorch_kinematics import chain -from pytorch_kinematics import frame from pytorch_kinematics.chain import SerialChain from pytorch_kinematics.transforms import Transform3d from pytorch_kinematics.transforms import rotation_conversions +from rofunc.utils.robolab.formatter.urdf import build_chain_from_urdf +from rofunc.utils.robolab.formatter.mjcf import build_chain_from_mjcf # Converts from MuJoCo joint types to pytorch_kinematics joint types JOINT_TYPE_MAP = { mujoco.mjtJoint.mjJNT_HINGE: 'revolute', @@ -26,120 +24,29 @@ } -def body_to_geoms(m: mujoco.MjModel, body: MjModelBodyViews): - # Find all geoms which have body as parent - visuals = [] - for geom_id in range(m.ngeom): - geom = m.geom(geom_id) - if geom.bodyid == body.id: - visuals.append(frame.Visual(offset=tf.Transform3d(rot=geom.quat, pos=geom.pos), geom_type=geom.type, - geom_param=geom.size)) - return visuals - - -def _build_chain_recurse(m, parent_frame, parent_body): - parent_frame.link.visuals = body_to_geoms(m, parent_body) - # iterate through all bodies that are children of parent_body - for body_id in range(m.nbody): - body = m.body(body_id) - if body.parentid == parent_body.id and body_id != parent_body.id: - n_joints = body.jntnum - if n_joints > 1: - raise ValueError("composite joints not supported (could implement this if needed)") - if n_joints == 1: - # Find the joint for this body, again assuming there's only one joint per body. - joint = m.joint(body.jntadr[0]) - joint_offset = tf.Transform3d(pos=joint.pos) - child_joint = frame.Joint(joint.name, offset=joint_offset, axis=joint.axis, - joint_type=JOINT_TYPE_MAP[joint.type[0]], - limits=(joint.range[0], joint.range[1])) - else: - child_joint = frame.Joint(body.name + "_fixed_joint") - child_link = frame.Link(body.name, offset=tf.Transform3d(rot=body.quat, pos=body.pos)) - child_frame = frame.Frame(name=body.name, link=child_link, joint=child_joint) - parent_frame.children = parent_frame.children + [child_frame, ] - _build_chain_recurse(m, child_frame, body) - - # iterate through all sites that are children of parent_body - for site_id in range(m.nsite): - site = m.site(site_id) - if site.bodyid == parent_body.id: - site_link = frame.Link(site.name, offset=tf.Transform3d(rot=site.quat, pos=site.pos)) - site_frame = frame.Frame(name=site.name, link=site_link) - parent_frame.children = parent_frame.children + [site_frame, ] - - -def build_chain_from_mjcf(path, body: Union[None, str, int] = None): - """ - Build a Chain object from MJCF path. - - :param path: the path of the MJCF file - :param body: the name or index of the body to use as the root of the chain. If None, body idx=0 is used. - :return: Chain object created from MJCF - """ - - # import xml.etree.ElementTree as ET - # root = ET.parse(path).getroot() - # - # ASSETS = dict() - # mesh_dir = root.find("compiler").attrib["meshdir"] - # for asset in root.findall("asset"): - # for mesh in asset.findall("mesh"): - # filename = mesh.attrib["file"] - # with open(os.path.join(os.path.dirname(path), mesh_dir, filename), 'rb') as f: - # ASSETS[filename] = f.read() - - # m = mujoco.MjModel.from_xml_string(open(path).read(), assets=ASSETS) - m = mujoco.MjModel.from_xml_path(path) - if body is None: - root_body = m.body(0) - else: - root_body = m.body(body) - root_frame = frame.Frame(root_body.name, - link=frame.Link(root_body.name, - offset=tf.Transform3d(rot=root_body.quat, pos=root_body.pos)), - joint=frame.Joint()) - _build_chain_recurse(m, root_frame, root_body) - return chain.Chain(root_frame) - - -def build_serial_chain_from_mjcf(data, end_link_name, root_link_name=""): - """ - Build a SerialChain object from MJCF data. - - :param data: MJCF string data - :param end_link_name: the name of the link that is the end effector - :param root_link_name: the name of the root link - :return: SerialChain object created from MJCF - """ - mjcf_chain = build_chain_from_mjcf(data) - serial_chain = chain.SerialChain(mjcf_chain, end_link_name, "" if root_link_name == "" else root_link_name) - return serial_chain - - def build_chain_from_model(model_path: str, verbose=False): """ Build a serial chain from a URDF or MuJoCo XML file - :param model_path: the path of the URDF or MuJoCo XML file :param verbose: whether to print the chain :return: robot kinematics chain """ check_package_exist("pytorch_kinematics") - import pytorch_kinematics as pk if model_path.endswith(".urdf"): - chain = pk.build_chain_from_urdf(open(model_path).read()) + chain = build_chain_from_urdf(open(model_path).read()) elif model_path.endswith(".xml"): chain = build_chain_from_mjcf(model_path) else: raise ValueError("Invalid model path") if verbose: - beauty_print("Robot chain:") + beauty_print(f"Robot chain:") print(chain) - beauty_print("Robot joints:") + beauty_print(f"Robot joints: ({len(chain.get_joint_parameter_names())})") print(chain.get_joint_parameter_names()) + beauty_print(f"Robot joints frame name") + print(chain.get_joint_parent_frame_names()) return chain @@ -372,7 +279,6 @@ def solve(self, target_poses: Transform3d) -> IKSolution: def delta_pose(m: torch.tensor, target_pos, target_wxyz): """ Determine the error in position and rotation between the given poses and the target poses - :param m: (N x M x 4 x 4) tensor of homogenous transforms :param target_pos: :param target_wxyz: target orientation represented in unit quaternion @@ -567,4 +473,4 @@ def compute_dq(self, J, dx): # dq = J^T (JJ^T + lambda^2I)^-1 dx dq = total @ dx - return dq + return dq \ No newline at end of file diff --git a/rofunc/utils/robolab/kinematics/robot_class.py b/rofunc/utils/robolab/kinematics/robot_class.py index e1a3facdb..b9c37d233 100644 --- a/rofunc/utils/robolab/kinematics/robot_class.py +++ b/rofunc/utils/robolab/kinematics/robot_class.py @@ -1,13 +1,26 @@ -from rofunc.utils.robolab.kinematics.fk import get_fk_from_chain +import os +import numpy as np +import torch +import trimesh +from typing import Union, List, Tuple + +from rofunc.utils.logger.beauty_logger import beauty_print from rofunc.utils.robolab.coord import convert_ori_format, convert_quat_order, homo_matrix_from_quat_tensor +from rofunc.utils.robolab.formatter.mjcf_parser.mjcf import MJCF +from rofunc.utils.robolab.formatter.urdf_parser.urdf import URDF +from rofunc.utils.robolab.kinematics.fk import get_fk_from_chain +from rofunc.utils.robolab.kinematics.ik import get_ik_from_chain class RobotModel: - def __init__(self, model_path: str, solve_engine: str = "pytorch_kinematics", device="cpu", verbose=False): + def __init__(self, model_path: str, robot_pose=None, + solve_engine: str = "pytorch_kinematics", device="cpu", + verbose=False): """ Initialize a robot model from a URDF or MuJoCo XML file :param model_path: the path of the URDF or MuJoCo XML file + :param robot_pose: initial pose of robot base link, [x, y, z, qx, qy, qz, qw] :param solve_engine: the engine to solve the kinematics, ["pytorch_kinematics", "kinpy", "all"] :param device: the device to run the computation :param verbose: whether to print the chain @@ -15,47 +28,333 @@ def __init__(self, model_path: str, solve_engine: str = "pytorch_kinematics", de assert solve_engine in ["pytorch_kinematics", "kinpy"], "Unsupported solve engine." self.solve_engine = solve_engine self.device = device + self.verbose = verbose + self.robot_pose = robot_pose if robot_pose else [0, 0, 0, 0, 0, 0, 1] + self.robot_pos = self.robot_pose[:3] + self.robot_rot = self.robot_pose[3:] + + self.robot_model_path = model_path + self.mesh_dir = os.path.join(os.path.dirname(model_path), "meshes") + + self._load_model() + self._load_joint_info() + self._load_mesh_info() + self._load_link_info() + def _load_model(self): + """Loads the kinematic chain and robot model (URDF or MJCF).""" if self.solve_engine == "pytorch_kinematics": from rofunc.utils.robolab.kinematics import pytorch_kinematics_utils as pk_utils - self.chain = pk_utils.build_chain_from_model(model_path, verbose) + self.chain = pk_utils.build_chain_from_model(self.robot_model_path, self.verbose) elif self.solve_engine == "kinpy": from rofunc.utils.robolab.kinematics import kinpy_utils as kp_utils - self.chain = kp_utils.build_chain_from_model(model_path, verbose) + self.chain = kp_utils.build_chain_from_model(self.robot_model_path, self.verbose) + + if self.robot_model_path.endswith('.urdf'): + self.robot_model = URDF.from_xml_file(self.robot_model_path) + elif self.robot_model_path.endswith('.xml'): + self.robot_model = MJCF(self.robot_model_path) + else: + raise ValueError("Unsupported model file format.") + + def _load_joint_info(self): + """Loads joint information.""" + self.joint_list = self.get_joint_list() + self.num_joint = len(self.joint_list) + self.joint_limit_max = self.chain.high.to(self.device) + self.joint_limit_min = self.chain.low.to(self.device) + + def _load_mesh_info(self): + """Loads mesh information for the robot.""" + self.link_mesh_map = self.get_link_mesh_map() + self.link_meshname_map = self.get_link_meshname_map() + self.meshes, self.simple_shapes = self.load_meshes() + self.robot_faces = [val[1] for val in self.meshes.values()] + self.num_vertices_per_part = [val[0].shape[0] for val in self.meshes.values()] + self.meshname_mesh = {key: val[0] for key, val in self.meshes.items()} + self.meshname_mesh_normal = {key: val[-1] for key, val in self.meshes.items()} + + def _load_link_info(self): + """Loads link information including virtual and real links.""" + self.link_virtual_map, self.inverse_link_virtual_map = self.get_link_virtual_map() + self.real_link = self.get_real_link_list() + self.all_link = self.get_link_list() + + def show_chain(self): + beauty_print("Robot chain:") + print(self.chain) + + def convert_to_serial_chain(self, export_link): + import pytorch_kinematics as pk + self.serial_chain = pk.SerialChain(self.chain, export_link) + + def set_init_pose(self, robot_pose): + self.robot_pose = robot_pose + self.robot_pos = robot_pose[:3] + self.robot_rot = robot_pose[3:] + + def load_meshes(self): + """ + Load all meshes and store them in a dictionary. Handles both complex meshes and simple shapes. + + :return: A dictionary where keys are mesh names and values are mesh data, and a dictionary for simple shapes. + """ + meshes = {} + simple_shapes = {} # 用于保存简单形状信息 + for link_name, mesh_dict in self.link_mesh_map.items(): + for geom_name, mesh_info in mesh_dict.items(): + if mesh_info['type'] == 'mesh': + # 处理复杂的网格 + mesh_file = mesh_info['params']['mesh_path'] + name = mesh_info['params']['name'] + mesh = trimesh.load(mesh_file) + temp = torch.ones(mesh.vertices.shape[0], 1).float().to(self.device) + + vertices = torch.cat((torch.FloatTensor(np.array(mesh.vertices)), temp), dim=-1).to(self.device) + normals = torch.cat((torch.FloatTensor(np.array(mesh.vertex_normals)), temp), dim=-1).to(self.device) + + meshes[name] = [vertices, mesh.faces, normals] + else: + # 处理简单几何形状,直接保存形状信息 + simple_shapes[geom_name] = mesh_info + + return meshes, simple_shapes + + def get_joint_list(self): + return self.chain.get_joint_parameter_names() + + def get_link_list(self): + if self.solve_engine == "pytorch_kinematics": + return self.chain.get_link_names() + else: + raise ValueError("kinpy does not support get_link_names() method.") + + def get_link_virtual_map(self): + """ + :return: {link_body_name: [virtual_link_0, virtual_link_1, ...]} + """ + all_links = self.get_link_list() + link_virtual_map = {} + for link in all_links: + if "world" in link: + continue + if "_0" in link or "_1" in link or "_2" in link: + link_name = link.split("_")[0] + if link_name not in link_virtual_map: + link_virtual_map[link_name] = [] + link_virtual_map[link_name].append(link) + else: + link_virtual_map[link] = [link] + + inverse_link_virtual_map = {v: k for k, vs in link_virtual_map.items() for v in vs} + return link_virtual_map, inverse_link_virtual_map + + def get_real_link_list(self): + """ + :return: [real_link_0, real_link_1, ...] + """ + return list(self.link_virtual_map.keys()) + + def get_link_mesh_map(self): + """ + Get the map of link and its corresponding geometries from the robot model file (either URDF or MJCF). + + :return: {link_body_name: {geom_name: {'type': geom_type, 'params': geom_specific_parameters}}} + + If the robot model is a URDF file, it will attempt to link the geometry's mesh paths. The URDF format relies on + external mesh files, and this function assumes that any `.obj` mesh files are converted to `.stl` files. + + If the robot model is an MJCF file (which has a `.xml` extension), it uses the MJCF-specific link-mesh mapping + generated by the parser and processes different geometry types, including meshes, spheres, cylinders, boxes, + and capsules. - def get_fk(self, joint_value, export_link_name): + For each geometry type: + - 'mesh': It maps the geometry name to its corresponding mesh file path. + - 'sphere': It maps the geometry name to a dictionary with the sphere radius and position. + - 'cylinder': It maps the geometry name to a dictionary with the cylinder's radius, height, and position. + - 'box': It maps the geometry name to a dictionary with the box's extents (x, y, z) and position. + - 'capsule': It maps the geometry name to a dictionary with the capsule's radius, height, and start/end positions. + """ + if self.robot_model_path.endswith('.urdf'): + # TODO: urdf has some problems + link_mesh_map = {} + for link in link_mesh_map: + mesh_path = link_mesh_map[link].collision.geometry.filename.replace(".obj", ".stl") + link_mesh_map[link] = os.path.join(os.path.dirname(self.robot_model_path), mesh_path) + elif self.robot_model_path.endswith('.xml'): + link_mesh_map = self.robot_model.link_mesh_map + else: + raise ValueError("Unsupported model file.") + return link_mesh_map + + def get_link_meshname_map(self): + """ + :return: {link_body_name: [mesh_name]} + """ + link_meshname_map = {} + for link, geoms in self.link_mesh_map.items(): + link_meshname_map[link] = [] + for geom in geoms: + if self.link_mesh_map[link][geom]['type'] == 'mesh': + link_meshname_map[link].append(self.link_mesh_map[link][geom]['params']['name']) + return link_meshname_map + + def get_robot_mesh(self, vertices_list, faces): + assert len(vertices_list) == len(faces), "The number of vertices and faces should be the same." + robot_mesh = [trimesh.Trimesh(verts, face) for verts, face in zip(vertices_list, faces)] + return robot_mesh + + def get_forward_robot_mesh(self, joint_value, base_trans=None): + """ + Transform the robot mesh according to the joint values and the base pose + :param joint_value: the joint values, [batch_size, num_joint] + :param base_trans: transformation matrix of the base pose, [batch_size, 4, 4] + :return: + """ + batch_size = joint_value.size()[0] + outputs = self.forward(joint_value, base_trans) + vertices_list = [[outputs[i][j].detach().cpu().numpy() for i in range(int(len(outputs) / 2))] for j in + range(batch_size)] + mesh = [self.get_robot_mesh(vertices, self.robot_faces) for vertices in vertices_list] + return mesh + + def forward(self, joint_value, base_trans=None): + """ + Transform the robot mesh according to the joint values and the base pose + + :param joint_value: the joint values, [batch_size, num_joint] + :param base_trans: transformation matrix of the base pose, [batch_size, 4, 4] + :return: + """ + batch_size = joint_value.shape[0] + trans_dict = self.get_trans_dict(joint_value, base_trans) + meshname_link_map = {} + for link, meshnames in self.link_meshname_map.items(): + for meshname in meshnames: + meshname_link_map[meshname] = link + + ret_vertices, ret_normals = [], [] + for mesh_name, mesh in self.meshname_mesh.items(): + link_vertices = self.meshname_mesh[mesh_name].repeat(batch_size, 1, 1) + link_normals = self.meshname_mesh_normal[mesh_name].repeat(batch_size, 1, 1) + + if 'base' not in meshname_link_map[mesh_name]: + link_name = meshname_link_map[mesh_name] + related_link = [key for key in trans_dict.keys() if link_name in key][-1] + link_vertices = torch.matmul(trans_dict[related_link], link_vertices.transpose(2, 1)).transpose(1, 2)[:, :, :3] + link_normals = torch.matmul(trans_dict[related_link], link_normals.transpose(2, 1)).transpose(1, 2)[:, :, :3] + else: + if base_trans is not None: + link_vertices = torch.matmul(base_trans, link_vertices.transpose(2, 1)).transpose(1, 2)[:, :, :3] + link_normals = torch.matmul(base_trans, link_normals.transpose(2, 1)).transpose(1, 2)[:, :, :3] + else: + link_vertices = link_vertices[:, :, :3] + link_normals = link_normals[:, :, :3] + ret_vertices.append(link_vertices) + ret_normals.append(link_normals) + return ret_vertices + ret_normals + + def get_fk(self, joint_value: List, export_link=None): """ Get the forward kinematics from a chain - :param joint_value: - :param export_link_name: + :param joint_value: both single and batch input are supported + :param export_link: the name of the export link :return: the position, rotation of the end effector, and the transformation matrices of all links """ + joint_value = self._prepare_joint_value(joint_value) + batch_size = joint_value.size()[0] + if self.solve_engine == "pytorch_kinematics": - pose, ret = get_fk_from_chain(self.chain, joint_value, export_link_name) + return self._pytorch_fk(joint_value, export_link) + elif self.solve_engine == "kinpy": + return self._kinpy_fk(joint_value, export_link, batch_size) + + def _prepare_joint_value(self, joint_value: List): + """Helper to prepare joint values for FK/IK.""" + joint_value = torch.tensor(joint_value, dtype=torch.float32).to(self.device) + if len(joint_value.size()) == 1: + joint_value = joint_value.unsqueeze(0) + return joint_value + + def _pytorch_fk(self, joint_value, export_link): + """Helper function for PyTorch kinematics FK.""" + joint_value_dict = {joint: joint_value[:, i] for i, joint in enumerate(self.joint_list)} + pose, ret = get_fk_from_chain(self.chain, joint_value_dict, export_link) + + if pose is not None: m = pose.get_matrix() pos = m[:, :3, 3] rot = convert_ori_format(m[:, :3, :3], "mat", "quat") return pos, rot, ret - elif self.solve_engine == "kinpy": - pose, ret = get_fk_from_chain(self.chain, joint_value, export_link_name) - pos = pose.pos - rot = pose.rot - rot = convert_quat_order(rot, "wxyz", "xyzw") - return pos, rot, ret + return None, None, ret + + def _kinpy_fk(self, joint_value, export_link, batch_size): + """Helper function for KinPy kinematics FK.""" + pos_batch, rot_batch = [], [] + for batch in range(batch_size): + joint_value_dict = {joint: joint_value[batch, i] for i, joint in enumerate(self.joint_list)} + pose, ret = get_fk_from_chain(self.chain, joint_value_dict, export_link) + if pose is not None: + pos_batch.append(pose.pos) + rot_batch.append(convert_quat_order(pose.rot, "wxyz", "xyzw")) + return torch.tensor(pos_batch), torch.tensor(rot_batch), ret + + def get_jacobian(self, joint_value: List, export_link: str, locations=None): + """ + Get the jacobian of a chain + + :param joint_value: the joint values, [batch_size, num_joint] + :param export_link: the name of the export link + :param locations: the locations offset from the export link + :return: + """ + self.convert_to_serial_chain(export_link=export_link) + assert self.solve_engine == "pytorch_kinematics", "kinpy does not support get_jacobian() method." + J = self.serial_chain.jacobian(joint_value, locations=locations) + return J + + def get_trans_dict(self, joint_value: List, base_trans: Union[None, torch.Tensor] = None) -> dict: + """ + Get the transformation matrices of all links + + :param joint_value: the joint values, [batch_size, num_joint] + :param base_trans: transformation matrix of the base pose, [batch_size, 4, 4] + :return: A dictionary where the keys are link names and the values are transformation matrices. + """ + _, _, ret = self.get_fk(joint_value) + trans_dict = {} + for link in self.all_link: + if "world" in link: + continue + val = ret[link] + homo_matrix = val.get_matrix().to(self.device) + if base_trans is not None: + homo_matrix = torch.matmul(base_trans, homo_matrix) + + real_link = self.inverse_link_virtual_map[link] + trans_dict[real_link] = homo_matrix + + return trans_dict - def get_ik(self, ee_pose, export_link_name): + def get_ik(self, ee_pose: Union[torch.Tensor, None, List, Tuple], export_link, goal_in_rob_tf: bool = True, + cur_configs=None, num_retries=10): """ Get the inverse kinematics from a chain :param ee_pose: the pose of the end effector, 7D vector with the first 3 elements as position and the last 4 elements as rotation - :param export_link_name: + :param export_link: the name of the export link + :param goal_in_rob_tf: whether the goal pose is in the robot base frame + :param cur_configs: let the ik solver retry from these configurations + :param num_retries: the number of retries :return: the joint values """ + self.convert_to_serial_chain(export_link) if self.solve_engine == "pytorch_kinematics": - return get_ik_from_chain(self.chain, ee_pose[:3], ee_pose[3:], self.device) + return get_ik_from_chain(self.serial_chain, ee_pose, self.device, goal_in_rob_tf=goal_in_rob_tf, + robot_pose=self.robot_pose, cur_configs=cur_configs, num_retries=num_retries) elif self.solve_engine == "kinpy": import kinpy as kp - self.serial_chain = kp.chain.SerialChain(self.chain, export_link_name) + self.serial_chain = kp.chain.SerialChain(self.chain, export_link) homo_matrix = homo_matrix_from_quat_tensor(ee_pose[3:], ee_pose[:3]) - return self.serial_chain.inverse_kinematics(homo_matrix) + return self.serial_chain.inverse_kinematics(homo_matrix) \ No newline at end of file diff --git a/setup.py b/setup.py index 3a2bb2aed..35022ba68 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ from setuptools import setup, find_packages + from pathlib import Path this_directory = Path(__file__).parent @@ -24,7 +25,7 @@ include_package_data=True, extras_require=extras, install_requires=['cython==3.0.0a10', # for mujoco_py - 'setuptools==59.8.0', + 'setuptools', 'pandas', 'tqdm==4.65.0', 'pillow==9.5.0', @@ -47,7 +48,7 @@ 'dgl', 'trimesh==4.0.5', 'wandb==0.16.2'], - python_requires=">=3.7,<3.9", + python_requires=">=3.7,<3.11", keywords=['robotics', 'robot learning', 'learning from demonstration', 'reinforcement learning', 'robot manipulation'], license='MIT',