diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ca0e80a --- /dev/null +++ b/.gitignore @@ -0,0 +1,144 @@ + +*.json + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +out/ +# C extensions +*.so +*.pkl + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +wandb/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +.vscode/* +.vscode/settings.json + + + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +venv/ +env.bak/ +venv.bak/ + + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + + # Pyre type checker +.pyre/ +checkpoints/ +data/* +output/ +log/ + +*.png +*.jpg +stageroom_train/ # tensorboard log +trials/ +output_videos/ +*.pkl +*.pt +new_trials +gh-md-toc \ No newline at end of file diff --git a/README.MD b/README.MD index 814d229..b85d1a9 100644 --- a/README.MD +++ b/README.MD @@ -1,7 +1,11 @@ -[Repo still under construction] + # Universal Humanoid Motion Representations for Physics-Based Control -Official implementation of ICLR 2023 spotlight paper: "Universal Humanoid Motion Representations for Physics-Based Control". +Official implementation of ICLR 2023 spotlight paper: "Universal Humanoid Motion Representations for Physics-Based Control". In this work, we develop a humanoi motion latent space that is low dimensional (32), high coverage (99.8% of AMASS motion), can speed up downstream task for hierarchical RL, and can be randomly sampled as a generative model. + +Our proposed Physics-based Universal motion Latent SpacE (PULSE) is akin to a foundation model for control where downstream tasks ranging from simple locomotion, complex terrain traversal, to free-form motion tracking can all reuse this representation + +This repo has a large amount of code overlap with [PHC](https://github.com/ZhengyiLuo/PHC); PULSE is more focused on the generative tasks. [[paper]](https://arxiv.org/abs/2310.04582) [[website]](https://www.zhengyiluo.com/PULSE/) @@ -9,3 +13,170 @@ Official implementation of ICLR 2023 spotlight paper: "Universal Humanoid Motion +## News 🚩 + +[March 19, 2024] Training and Evaluation code released. + +## TODOs + +- [ ] Add support for smplx/h (fingers!!!). + +- [ ] Add suppport for more downstream tasks. + +### Dependencies + +To create the environment, follow the following instructions: + +1. Create new conda environment and install pytroch: + + +``` +conda create -n isaac python=3.8 +conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia +pip install -r requirement.txt +``` + +2. Download and setup [Isaac Gym](https://developer.nvidia.com/isaac-gym). + + +3. [Optional if only inference] Download SMPL paramters from [SMPL](https://smpl.is.tue.mpg.de/). Put them in the `data/smpl` folder, unzip them into 'data/smpl' folder. Please download the v1.1.0 version, which contains the neutral humanoid. Rename the files `basicmodel_neutral_lbs_10_207_0_v1.1.0`, `basicmodel_m_lbs_10_207_0_v1.1.0.pkl`, `basicmodel_f_lbs_10_207_0_v1.1.0.pkl` to `SMPL_NEUTRAL.pkl`, `SMPL_MALE.pkl` and `SMPL_FEMALE.pkl`. Rename The file structure should look like this: + +``` + +|-- data + |-- smpl + |-- SMPL_FEMALE.pkl + |-- SMPL_NEUTRAL.pkl + |-- SMPL_MALE.pkl +``` + +4. Use the following script to download trained models and sample data. + +``` +bash download_data.sh +``` + +this wil download amass_isaac_standing_upright_slim.pkl, which is a standing still pose for testing. + +## Evaluation + +### Viewer Shortcuts + +| Keyboard | Function | +| ---- | --- | +| f | focus on humanoid | +| Right click + WASD | change view port | +| Shift + Right click + WASD | change view port fast | +| r | reset episode | +| i | start random sampling PULSE | +| j | apply large force to the humanoid | +| l | record screenshot, press again to stop recording| +| ; | cancel screen shot| +| m | cancel termination based on imitation | + +... more shortcut can be found in `phc/env/tasks/base_task.py` + +Notes on rendering: I am using pyvirtualdisplay to record the video such that you can see all humanoids at the same time (default function will only capture the first environment). You can disable it using the flag `no_virtual_display=True`. + + +### Run PULSE + +``` +python phc/run_hydra.py env.task=HumanoidImDistillGetup env=env_im_vae exp_name=pulse_vae_iclr robot.real_weight_porpotion_boxes=False learning=im_z_fit env.models=['output/HumanoidIm/phc_3/Humanoid_00258000.pth','output/HumanoidIm/phc_comp_3/Humanoid_00023501.pth'] env.motion_file=sample_data//amass_isaac_standing_upright_slim.pkl test=True env.num_envs=1 headless=False epoch=-1 +``` + +Press M (disable termination), and press I (start sampling), to see ramdomly sampled motion. + + +### Train Downstream Tasks + +(more coming soon!) + +Speed: +``` +python phc/run_hydra.py env.task=HumanoidSpeedZ env=env_im_vae exp_name=pulse_vae_iclr robot.real_weight_porpotion_boxes=False learning=im_z_fit env.models=['output/HumanoidIm/phc_3/Humanoid_00258000.pth','output/HumanoidIm/phc_comp_3/Humanoid_00023501.pth'] env.motion_file=sample_data//amass_isaac_standing_upright_slim.pkl test=True env.num_envs=1 headless=False epoch=-1 +``` + + + + + +## Training + + + +### Data Processing AMASS + +We train on a subset of the [AMASS](https://amass.is.tue.mpg.de/) dataset. + +For processing the AMASS, first, download the AMASS dataset from [AMASS](https://amass.is.tue.mpg.de/). Then, run the following script on the unzipped data: + + +After downloading AMASS, use the script `python scripts/data_process/convert_amass_data.py` + +### Training PULSE + +``` +python phc/run_hydra.py env.task=HumanoidImDistillGetup env=env_im_vae exp_name=pulse_vae robot.real_weight_porpotion_boxes=False learning=im_z_fit env.models=['output/HumanoidIm/phc_3/Humanoid_00258000.pth','output/HumanoidIm/phc_comp_3/Humanoid_00023501.pth'] env.motion_file=[insert data pkl] +``` + + +## Citation +If you find this work useful for your research, please cite our paper: +``` +@inproceedings{ +luo2024universal, +title={Universal Humanoid Motion Representations for Physics-Based Control}, +author={Zhengyi Luo and Jinkun Cao and Josh Merel and Alexander Winkler and Jing Huang and Kris M. Kitani and Weipeng Xu}, +booktitle={The Twelfth International Conference on Learning Representations}, +year={2024}, +url={https://openreview.net/forum?id=OrOd8PxOO2} +} +``` + +Also consider citing these prior works that are used in this project: + +``` + +@inproceedings{Luo2023PerpetualHC, + author={Zhengyi Luo and Jinkun Cao and Alexander W. Winkler and Kris Kitani and Weipeng Xu}, + title={Perpetual Humanoid Control for Real-time Simulated Avatars}, + booktitle={International Conference on Computer Vision (ICCV)}, + year={2023} +} + +@inproceedings{rempeluo2023tracepace, + author={Rempe, Davis and Luo, Zhengyi and Peng, Xue Bin and Yuan, Ye and Kitani, Kris and Kreis, Karsten and Fidler, Sanja and Litany, Or}, + title={Trace and Pace: Controllable Pedestrian Animation via Guided Trajectory Diffusion}, + booktitle={Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2023} +} + +@inproceedings{Luo2022EmbodiedSH, + title={Embodied Scene-aware Human Pose Estimation}, + author={Zhengyi Luo and Shun Iwase and Ye Yuan and Kris Kitani}, + booktitle={Advances in Neural Information Processing Systems}, + year={2022} +} + +@inproceedings{Luo2021DynamicsRegulatedKP, + title={Dynamics-Regulated Kinematic Policy for Egocentric Pose Estimation}, + author={Zhengyi Luo and Ryo Hachiuma and Ye Yuan and Kris Kitani}, + booktitle={Advances in Neural Information Processing Systems}, + year={2021} +} + +``` + +## References +This repository is built on top of the following amazing repositories: +* Main code framework is from: [IsaacGymEnvs](https://github.com/NVIDIA-Omniverse/IsaacGymEnvs) +* Large amount of code is from: [PHC](https://github.com/ZhengyiLuo/PHC) +* SMPL models and layer is from: [SMPL-X model](https://github.com/vchoutas/smplx) + +Please follow the lisence of the above repositories for usage. + + + + + diff --git a/download_data.sh b/download_data.sh new file mode 100644 index 0000000..1fbda90 --- /dev/null +++ b/download_data.sh @@ -0,0 +1,9 @@ +mkdir sample_data +mkdir -p output output/HumanoidIm/ output/HumanoidIm/phc_3 output/HumanoidIm/phc_comp_3 output/HumanoidIm/pulse_vae_iclr +gdown https://drive.google.com/uc?id=1bLp4SNIZROMB7Sxgt0Mh4-4BLOPGV9_U -O sample_data/ # filtered shapes from AMASS +gdown https://drive.google.com/uc?id=1arpCsue3Knqttj75Nt9Mwo32TKC4TYDx -O sample_data/ # all shapes from AMASS +gdown https://drive.google.com/uc?id=1fFauJE0W0nJfihUvjViq9OzmFfHo_rq0 -O sample_data/ # sample standing neutral data. +gdown https://drive.google.com/uc?id=1uzFkT2s_zVdnAohPWHOLFcyRDq372Fmc -O sample_data/ # amass_occlusion_v3 +gdown https://drive.google.com/uc?id=1ztyljPCzeRwQEJqtlME90gZwMXLhGTOQ -O output/HumanoidIm/pulse_vae_iclr/ +gdown https://drive.google.com/uc?id=1JbK9Vzo1bEY8Pig6D92yAUv8l-1rKWo3 -O output/HumanoidIm/phc_comp_3/ +gdown https://drive.google.com/uc?id=1pS1bRUbKFDp6o6ZJ9XSFaBlXv6_PrhNc -O output/HumanoidIm/phc_3/ \ No newline at end of file diff --git a/phc/__init__.py b/phc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Chest.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Chest.stl new file mode 100644 index 0000000..27d555a Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Chest.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Head.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Head.stl new file mode 100644 index 0000000..7ab243a Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Head.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Ankle.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Ankle.stl new file mode 100644 index 0000000..5d12a7a Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Ankle.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Elbow.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Elbow.stl new file mode 100644 index 0000000..e738d62 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Elbow.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Hand.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Hand.stl new file mode 100644 index 0000000..83fa389 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Hand.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Hip.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Hip.stl new file mode 100644 index 0000000..a2ca9b2 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Hip.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Knee.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Knee.stl new file mode 100644 index 0000000..4a1df90 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Knee.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Shoulder.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Shoulder.stl new file mode 100644 index 0000000..594e7e9 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Shoulder.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Thorax.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Thorax.stl new file mode 100644 index 0000000..2189867 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Thorax.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Toe.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Toe.stl new file mode 100644 index 0000000..40b8e2c Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Toe.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Wrist.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Wrist.stl new file mode 100644 index 0000000..b99f3fb Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/L_Wrist.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Neck.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Neck.stl new file mode 100644 index 0000000..9913d33 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Neck.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Pelvis.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Pelvis.stl new file mode 100644 index 0000000..d8d46ae Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Pelvis.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Ankle.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Ankle.stl new file mode 100644 index 0000000..8964b31 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Ankle.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Elbow.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Elbow.stl new file mode 100644 index 0000000..b869bbe Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Elbow.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Hand.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Hand.stl new file mode 100644 index 0000000..5252c25 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Hand.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Hip.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Hip.stl new file mode 100644 index 0000000..81e9f7b Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Hip.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Knee.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Knee.stl new file mode 100644 index 0000000..72539b8 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Knee.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Shoulder.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Shoulder.stl new file mode 100644 index 0000000..282b29e Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Shoulder.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Thorax.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Thorax.stl new file mode 100644 index 0000000..a44abd4 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Thorax.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Toe.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Toe.stl new file mode 100644 index 0000000..8260e72 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Toe.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Wrist.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Wrist.stl new file mode 100644 index 0000000..ab01a1e Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/R_Wrist.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Spine.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Spine.stl new file mode 100644 index 0000000..3f16323 Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Spine.stl differ diff --git a/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Torso.stl b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Torso.stl new file mode 100644 index 0000000..6c8d90d Binary files /dev/null and b/phc/data/assets/mesh/smpl/1c00fde5-abea-4340-b528-921965f3a020/geom/Torso.stl differ diff --git a/phc/data/assets/mjcf/humanoid_template_local.xml b/phc/data/assets/mjcf/humanoid_template_local.xml new file mode 100644 index 0000000..4fde895 --- /dev/null +++ b/phc/data/assets/mjcf/humanoid_template_local.xml @@ -0,0 +1,39 @@ + + + + + + diff --git a/phc/data/assets/mjcf/mesh_humanoid.xml b/phc/data/assets/mjcf/mesh_humanoid.xml new file mode 100644 index 0000000..559406c --- /dev/null +++ b/phc/data/assets/mjcf/mesh_humanoid.xml @@ -0,0 +1,295 @@ + + + + diff --git a/phc/data/assets/mjcf/smpl_0_humanoid.xml b/phc/data/assets/mjcf/smpl_0_humanoid.xml new file mode 100644 index 0000000..e336b8c --- /dev/null +++ b/phc/data/assets/mjcf/smpl_0_humanoid.xml @@ -0,0 +1,243 @@ + + + + diff --git a/phc/data/assets/mjcf/smpl_1_humanoid.xml b/phc/data/assets/mjcf/smpl_1_humanoid.xml new file mode 100644 index 0000000..7cf924f --- /dev/null +++ b/phc/data/assets/mjcf/smpl_1_humanoid.xml @@ -0,0 +1,243 @@ + + + + diff --git a/phc/data/assets/mjcf/smpl_2_humanoid.xml b/phc/data/assets/mjcf/smpl_2_humanoid.xml new file mode 100644 index 0000000..1d8faea --- /dev/null +++ b/phc/data/assets/mjcf/smpl_2_humanoid.xml @@ -0,0 +1,243 @@ + + + + diff --git a/phc/data/assets/mjcf/smpl_humanoid.xml b/phc/data/assets/mjcf/smpl_humanoid.xml new file mode 100644 index 0000000..95ad59b --- /dev/null +++ b/phc/data/assets/mjcf/smpl_humanoid.xml @@ -0,0 +1,243 @@ + + + + diff --git a/phc/data/assets/mjcf/smpl_humanoid_1.xml b/phc/data/assets/mjcf/smpl_humanoid_1.xml new file mode 100644 index 0000000..753e6fd --- /dev/null +++ b/phc/data/assets/mjcf/smpl_humanoid_1.xml @@ -0,0 +1,243 @@ + + + + diff --git a/phc/data/assets/urdf/ball_medium.urdf b/phc/data/assets/urdf/ball_medium.urdf new file mode 100644 index 0000000..a69e315 --- /dev/null +++ b/phc/data/assets/urdf/ball_medium.urdf @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/phc/data/assets/urdf/block_projectile.urdf b/phc/data/assets/urdf/block_projectile.urdf new file mode 100644 index 0000000..69596a4 --- /dev/null +++ b/phc/data/assets/urdf/block_projectile.urdf @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/phc/data/assets/urdf/block_projectile_large.urdf b/phc/data/assets/urdf/block_projectile_large.urdf new file mode 100644 index 0000000..4991422 --- /dev/null +++ b/phc/data/assets/urdf/block_projectile_large.urdf @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/phc/data/assets/urdf/capsule.urdf b/phc/data/assets/urdf/capsule.urdf new file mode 100644 index 0000000..adb734c --- /dev/null +++ b/phc/data/assets/urdf/capsule.urdf @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/phc/data/assets/urdf/heading_marker.urdf b/phc/data/assets/urdf/heading_marker.urdf new file mode 100644 index 0000000..baa24f8 --- /dev/null +++ b/phc/data/assets/urdf/heading_marker.urdf @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/phc/data/assets/urdf/location_marker.urdf b/phc/data/assets/urdf/location_marker.urdf new file mode 100644 index 0000000..a34215c --- /dev/null +++ b/phc/data/assets/urdf/location_marker.urdf @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/phc/data/assets/urdf/traj_marker.urdf b/phc/data/assets/urdf/traj_marker.urdf new file mode 100644 index 0000000..0f51fd6 --- /dev/null +++ b/phc/data/assets/urdf/traj_marker.urdf @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/phc/data/assets/urdf/traj_marker_small.urdf b/phc/data/assets/urdf/traj_marker_small.urdf new file mode 100644 index 0000000..241a374 --- /dev/null +++ b/phc/data/assets/urdf/traj_marker_small.urdf @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/phc/data/cfg/config.yaml b/phc/data/cfg/config.yaml new file mode 100644 index 0000000..5207481 --- /dev/null +++ b/phc/data/cfg/config.yaml @@ -0,0 +1,47 @@ +defaults: + - _self_ + - env: env_im + - robot: smpl_humanoid + - learning: im + - sim: default_sim + +project_name: "PULSE" +notes: "Default Notes" +exp_name: &exp_name humanoid_smpl +headless: True +seed: 0 +no_log: False +resume_str: null +num_threads: 64 +test: False +output_path: output/HumanoidIm/${exp_name} +torch_deterministic: False +epoch: 0 +im_eval: False +horovod: False # Use horovod for multi-gpu training, have effect only with rl_games RL library +rl_device: "cuda:0" +device: "cuda" +device_id: 0 +metadata: false +play: ${test} +train: True + + +####### Testing Configs. ######## +server_mode: False +has_eval: True +no_virtual_display: False +render_o3d: False +debug: False +follow: False +add_proj: False +real_traj: False + +hydra: + job: + name: ${exp_name} + env_set: + OMP_NUM_THREADS: 1 + run: + dir: output/HumanoidIm/${exp_name} + diff --git a/phc/data/cfg/env/env_im.yaml b/phc/data/cfg/env/env_im.yaml new file mode 100644 index 0000000..68073f1 --- /dev/null +++ b/phc/data/cfg/env/env_im.yaml @@ -0,0 +1,52 @@ +# if given, will override the device setting in gym. +task: HumanoidIm +project_name: "PHC" +notes: "" +motion_file: "" +num_envs: 3072 +env_spacing: 5 +episode_length: 300 +is_flag_run: False +enable_debug_vis: False + +fut_tracks: False +self_obs_v: 1 +obs_v: 6 +auto_pmcp: False +auto_pmcp_soft: True + +cycle_motion: False +hard_negative: False +min_length: 5 + +kp_scale: 1 +power_reward: True + +shape_resampling_interval: 500 + +control_mode: "isaac_pd" +power_scale: 1.0 +controlFrequencyInv: 2 # 30 Hz +stateInit: "Random" +hybridInitProb: 0.5 +numAMPObsSteps: 10 + +local_root_obs: True +root_height_obs: True +key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] +contact_bodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] +reset_bodies: ['Pelvis', 'L_Hip', 'L_Knee', 'R_Hip', 'R_Knee', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] +terminationHeight: 0.15 +enableEarlyTermination: True +terminationDistance: 0.25 + +### Fut config +numTrajSamples: 3 +trajSampleTimestepInv: 3 +enableTaskObs: True + + +plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 diff --git a/phc/data/cfg/env/env_im_getup_mcp.yaml b/phc/data/cfg/env/env_im_getup_mcp.yaml new file mode 100644 index 0000000..930c604 --- /dev/null +++ b/phc/data/cfg/env/env_im_getup_mcp.yaml @@ -0,0 +1,96 @@ +# if given, will override the device setting in gym. +task: HumanoidImMCPGetup +project_name: "PHC" +notes: "Progressive MCP without softmax, zero out far" + +motion_file: "" +num_envs: 1024 +env_spacing: 2 +episode_length: 300 +is_flag_run: False +enable_debug_vis: False + +sym_loss_coef: 1 +fut_tracks: False +obs_v: 6 + +######## PNN Configs ######## +has_pnn: True +fitting: True +num_prim: 4 +training_prim: 0 +actors_to_load: 4 +has_lateral: False +models: [] + +######## Getup Configs ######## +zero_out_far: True +zero_out_far_train: False +cycle_motion: False +getup_udpate_epoch: 78750 + +getup_schedule: True +recoverySteps: 90 +zero_out_far_steps: 90 +recoveryEpisodeProb: 0.5 +fallInitProb: 0.3 +hard_negative: False + +z_activation: "silu" + +kp_scale: 1 + +power_reward: True +power_coefficient: 0.00005 + +has_shape_obs: True +has_shape_obs_disc: True +has_shape_variation: True +shape_resampling_interval: 500 + +control_mode: "isaac_pd" +power_scale: 1.0 +controlFrequencyInv: 2 # 30 Hz +stateInit: "Random" +hybridInitProb: 0.5 +numAMPObsSteps: 10 + +local_root_obs: True +root_height_obs: True +key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] +contact_bodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] +reset_bodies: ['Pelvis', 'L_Hip', 'L_Knee', 'R_Hip', 'R_Knee', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] +terminationHeight: 0.15 +enableEarlyTermination: True +terminationDistance: 0.25 + +### Fut config +numTrajSamples: 3 +trajSampleTimestepInv: 3 +enableTaskObs: True + +asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" + +plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 + +sim: +substeps: 2 +physx: + num_threads: 4 + solver_type: 1 # 0: pgs, 1: tgs + num_position_iterations: 4 + num_velocity_iterations: 0 + contact_offset: 0.02 + rest_offset: 0.0 + bounce_threshold_velocity: 0.2 + max_depenetration_velocity: 10.0 + default_buffer_size_multiplier: 10.0 + +flex: + num_inner_iterations: 10 + warm_start: 0.25 diff --git a/phc/data/cfg/env/env_im_pnn.yaml b/phc/data/cfg/env/env_im_pnn.yaml new file mode 100644 index 0000000..bc1d7a9 --- /dev/null +++ b/phc/data/cfg/env/env_im_pnn.yaml @@ -0,0 +1,65 @@ +# if given, will override the device setting in gym. +task: HumanoidIm +project_name: "PHC" +notes: " " +motion_file: "" +num_envs: 3072 +env_spacing: 5 +episode_length: 300 +is_flag_run: False +enable_debug_vis: False + +fut_tracks: False +self_obs_v: 1 +obs_v: 6 +auto_pmcp: False +auto_pmcp_soft: True + + +has_pnn: True +fitting: False +num_prim: 3 +training_prim: 0 +actors_to_load: 0 +has_lateral: False +models: [] + +######## Getup Configs ######## +# zero_out_far: True +# zero_out_far_train: False +# getup_udpate_epoch: 78750 +# cycle_motion: True +# hard_negative: False +# min_length: 5 + +kp_scale: 1 +power_reward: True + +shape_resampling_interval: 500 + +control_mode: "isaac_pd" +power_scale: 1.0 +controlFrequencyInv: 2 # 30 Hz +stateInit: "Random" +hybridInitProb: 0.5 +numAMPObsSteps: 10 + +local_root_obs: True +root_height_obs: True +key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] +contact_bodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] +reset_bodies: ['Pelvis', 'L_Hip', 'L_Knee', 'R_Hip', 'R_Knee', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] +terminationHeight: 0.15 +enableEarlyTermination: True +terminationDistance: 0.25 + +### Fut config +numTrajSamples: 3 +trajSampleTimestepInv: 3 +enableTaskObs: True + + +plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 diff --git a/phc/data/cfg/env/env_im_vae.yaml b/phc/data/cfg/env/env_im_vae.yaml new file mode 100644 index 0000000..5dd6a83 --- /dev/null +++ b/phc/data/cfg/env/env_im_vae.yaml @@ -0,0 +1,112 @@ +# if given, will override the device setting in gym. +task: HumanoidImDistillGetup +motion_file: "" +num_envs: 3072 +env_spacing: 5 + +# env_spacing: 0.5 +# divide_group: True + +episode_length: 300 +isFlagrun: False +enable_debug_vis: False + +fut_tracks: False +fut_tracks_dropout: False +obs_v: 6 +auto_pmcp: False +auto_pmcp_soft: True +eval_full: False + +embedding_norm: 1 +embedding_size: 32 +z_type: vae +use_vae_prior: True +use_ar1_prior: True +use_vae_clamped_prior: True +vae_var_clamp_max: 2 +kld_coefficient: 0.01 +kld_coefficient_min: 0.001 +kld_anneal: True +ar1_coefficient: 0.005 + + +######## Getup Features. ######## +getup_schedule: False +recoverySteps: 90 +# recoveryEpisodeProb: 0.5 +# fallInitProb: 0.3 +recoveryEpisodeProb: 0.3 +fallInitProb: 0.1 +getup_udpate_epoch: 0 +zero_out_far: False +zero_out_far_train: False +zero_out_far_steps: 90 +####### + +cycle_motion: True +hard_negative: False +fitting: True + +only_kin_loss: True +distill: True +save_kin_info: True +distill_z_model: False +z_read: False +distill_model_config: + has_pnn: True + num_prim: 3 + training_prim: 0 + actors_to_load: 0 + has_lateral: False + z_activation: "silu" + + fut_tracks_dropout: False + fut_tracks: False + trajSampleTimestepInv: 5 + numTrajSamples: 1 + + +models: ['output/dgx/phc_3/Humanoid_00258000.pth', 'output/dgx/phc_comp_3/Humanoid_00023501.pth'] + +control_mode: "isaac_pd" + +real_weight: True +kp_scale: 1 +remove_toe_im: False # For imitation +power_reward: True + +has_shape_obs: false +has_shape_obs_disc: false +has_shape_variation: False +shape_resampling_interval: 250 + +pdControl: True +power_scale: 1.0 +controlFrequencyInv: 2 # 30 Hz +stateInit: "Random" +hybridInitProb: 0.5 +numAMPObsSteps: 10 + +local_root_obs: True +rootHeightObs: True +key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] +contact_bodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] +reset_bodies: ['Pelvis', 'L_Hip', 'L_Knee', 'R_Hip', 'R_Knee', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] +terminationHeight: 0.15 +enableEarlyTermination: True +terminationDistance: 0.25 + +### Fut config +numTrajSamples: 10 +trajSampleTimestepInv: 5 +enableTaskObs: True + +asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" + +plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 diff --git a/phc/data/cfg/env/env_vr.yaml b/phc/data/cfg/env/env_vr.yaml new file mode 100644 index 0000000..6859d7e --- /dev/null +++ b/phc/data/cfg/env/env_vr.yaml @@ -0,0 +1,53 @@ +# if given, will override the device setting in gym. +task: HumanoidIm +project_name: "PHC" +notes: "VR modell, three point tracking" +motion_file: "" +num_envs: 3072 +env_spacing: 5 +episode_length: 300 +is_flag_run: False +enable_debug_vis: False + +fut_tracks: False +self_obs_v: 1 +obs_v: 6 +auto_pmcp: False +auto_pmcp_soft: True + +cycle_motion: False +hard_negative: False +min_length: 5 + +kp_scale: 1 +power_reward: True + +shape_resampling_interval: 500 + +control_mode: "isaac_pd" +power_scale: 1.0 +controlFrequencyInv: 2 # 30 Hz +stateInit: "Random" +hybridInitProb: 0.5 +numAMPObsSteps: 10 + +local_root_obs: True +root_height_obs: True +key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] +contact_bodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] +reset_bodies: ["Head", "L_Hand", "R_Hand"] +trackBodies: ["Head", "L_Hand", "R_Hand"] +terminationHeight: 0.15 +enableEarlyTermination: True +terminationDistance: 0.25 + +### Fut config +numTrajSamples: 3 +trajSampleTimestepInv: 3 +enableTaskObs: True + + +plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 diff --git a/phc/data/cfg/env/phc_kp_mcp_iccv.yaml b/phc/data/cfg/env/phc_kp_mcp_iccv.yaml new file mode 100644 index 0000000..a89013f --- /dev/null +++ b/phc/data/cfg/env/phc_kp_mcp_iccv.yaml @@ -0,0 +1,106 @@ +# if given, will override the device setting in gym. +project_name: "PHC" +notes: " obs v7, sorry for the confusing name!! This is from im_pnn_1" +env: + numEnvs: 1536 + envSpacing: 5 + episodeLength: 300 + isFlagrun: False + enableDebugVis: False + + bias_offset: False + has_self_collision: True + has_mesh: False + has_jt_limit: False + has_dof_subset: True + has_upright_start: True + has_smpl_pd_offset: False + remove_toe: False # For humanoid's geom toe + real_weight_porpotion_capsules: True + + sym_loss_coef: 1 + big_ankle: True + fut_tracks: False + obs_v: 7 + + has_pnn: True + fitting: True + num_prim: 4 + training_prim: 2 + actors_to_load: 4 + has_lateral: False + models: ['output/phc_kp_pnn_iccv/Humanoid.pth'] + + zero_out_far: True + zero_out_far_train: False + cycle_motion: False + + getup_udpate_epoch: 95000 + getup_schedule: True + recoverySteps: 90 + zero_out_far_steps: 90 + recoveryEpisodeProb: 0.5 + fallInitProb: 0.3 + + hard_negative: False + + masterfoot: False + freeze_toe: false + + real_weight: True + kp_scale: 1 + remove_toe_im: False # For imitation + power_reward: True + power_coefficient: 0.00005 + + has_shape_obs: False + has_shape_obs_disc: False + has_shape_variation: False + shape_resampling_interval: 500 + + pdControl: True + powerScale: 1.0 + controlFrequencyInv: 2 # 30 Hz + stateInit: "Random" + hybridInitProb: 0.5 + numAMPObsSteps: 10 + + localRootObs: True + rootHeightObs: True + key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] + contactBodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] + reset_bodies: ['Pelvis', 'L_Hip', 'L_Knee', 'R_Hip', 'R_Knee', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + terminationHeight: 0.15 + enableEarlyTermination: True + terminationDistance: 0.25 + + ### Fut config + numTrajSamples: 3 + trajSampleTimestepInv: 3 + enableTaskObs: True + + asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" + + plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 + +sim: + substeps: 2 + physx: + num_threads: 4 + solver_type: 1 # 0: pgs, 1: tgs + num_position_iterations: 4 + num_velocity_iterations: 0 + contact_offset: 0.02 + rest_offset: 0.0 + bounce_threshold_velocity: 0.2 + max_depenetration_velocity: 10.0 + default_buffer_size_multiplier: 10.0 + + flex: + num_inner_iterations: 10 + warm_start: 0.25 diff --git a/phc/data/cfg/env/phc_kp_pnn_iccv.yaml b/phc/data/cfg/env/phc_kp_pnn_iccv.yaml new file mode 100644 index 0000000..756cff7 --- /dev/null +++ b/phc/data/cfg/env/phc_kp_pnn_iccv.yaml @@ -0,0 +1,102 @@ +# if given, will override the device setting in gym. +project_name: "PHC" +notes: "PNN, no Laternal connection " +env: + numEnvs: 1536 + envSpacing: 5 + episodeLength: 300 + isFlagrun: False + enableDebugVis: False + + bias_offset: False + has_self_collision: True + has_mesh: False + has_jt_limit: False + has_dof_subset: True + has_upright_start: True + has_smpl_pd_offset: False + remove_toe: False # For humanoid's geom toe + real_weight_porpotion_capsules: True + + sym_loss_coef: 1 + big_ankle: True + fut_tracks: False + obs_v: 7 + + + has_pnn: True + fitting: True + num_prim: 4 + training_prim: 2 + actors_to_load: 4 + has_lateral: False + models: ['output/phc_kp_pnn_iccv/Humanoid.pth'] + + ######## Getup Configs ######## + zero_out_far: True + zero_out_far_train: False + cycle_motion: False + getup_udpate_epoch: 78750 + + cycle_motion: True + hard_negative: False + + masterfoot: False + freeze_toe: false + + real_weight: True + kp_scale: 1 + remove_toe_im: False # For imitation + power_reward: True + + has_shape_obs: False + has_shape_obs_disc: False + has_shape_variation: False + shape_resampling_interval: 500 + + pdControl: True + powerScale: 1.0 + controlFrequencyInv: 2 # 30 Hz + stateInit: "Random" + hybridInitProb: 0.5 + numAMPObsSteps: 10 + + localRootObs: True + rootHeightObs: True + key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] + contactBodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] + reset_bodies: ['Pelvis', 'L_Hip', 'L_Knee', 'R_Hip', 'R_Knee', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + terminationHeight: 0.15 + enableEarlyTermination: True + terminationDistance: 0.25 + + ### Fut config + numTrajSamples: 3 + trajSampleTimestepInv: 3 + enableTaskObs: True + + asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" + + plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 + +sim: + substeps: 2 + physx: + num_threads: 4 + solver_type: 1 # 0: pgs, 1: tgs + num_position_iterations: 4 + num_velocity_iterations: 0 + contact_offset: 0.02 + rest_offset: 0.0 + bounce_threshold_velocity: 0.2 + max_depenetration_velocity: 10.0 + default_buffer_size_multiplier: 10.0 + + flex: + num_inner_iterations: 10 + warm_start: 0.25 diff --git a/phc/data/cfg/env/phc_prim_iccv.yaml b/phc/data/cfg/env/phc_prim_iccv.yaml new file mode 100644 index 0000000..a084186 --- /dev/null +++ b/phc/data/cfg/env/phc_prim_iccv.yaml @@ -0,0 +1,89 @@ +# if given, will override the device setting in gym. +project_name: "PHC" +notes: "PNN, no Laternal connection " +env: + numEnvs: 1536 + envSpacing: 5 + episodeLength: 300 + isFlagrun: False + enableDebugVis: False + + bias_offset: False + has_self_collision: True + has_mesh: False + has_jt_limit: False + has_dof_subset: True + has_upright_start: True + has_smpl_pd_offset: False + remove_toe: False # For humanoid's geom toe + real_weight_porpotion_capsules: True + + sym_loss_coef: 1 + big_ankle: True + fut_tracks: False + obs_v: 6 + + + + cycle_motion: False + hard_negative: False + + masterfoot: False + freeze_toe: false + + real_weight: True + kp_scale: 1 + remove_toe_im: False # For imitation + power_reward: True + + has_shape_obs: True + has_shape_obs_disc: True + has_shape_variation: True + shape_resampling_interval: 500 + + pdControl: True + powerScale: 1.0 + controlFrequencyInv: 2 # 30 Hz + stateInit: "Random" + hybridInitProb: 0.5 + numAMPObsSteps: 10 + + localRootObs: True + rootHeightObs: True + key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] + contactBodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] + reset_bodies: ['Pelvis', 'L_Hip', 'L_Knee', 'R_Hip', 'R_Knee', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + terminationHeight: 0.15 + enableEarlyTermination: True + terminationDistance: 0.25 + + ### Fut config + numTrajSamples: 3 + trajSampleTimestepInv: 3 + enableTaskObs: True + + asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" + + plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 + +sim: + substeps: 2 + physx: + num_threads: 4 + solver_type: 1 # 0: pgs, 1: tgs + num_position_iterations: 4 + num_velocity_iterations: 0 + contact_offset: 0.02 + rest_offset: 0.0 + bounce_threshold_velocity: 0.2 + max_depenetration_velocity: 10.0 + default_buffer_size_multiplier: 10.0 + + flex: + num_inner_iterations: 10 + warm_start: 0.25 diff --git a/phc/data/cfg/env/phc_prim_vr.yaml b/phc/data/cfg/env/phc_prim_vr.yaml new file mode 100644 index 0000000..bc2e277 --- /dev/null +++ b/phc/data/cfg/env/phc_prim_vr.yaml @@ -0,0 +1,93 @@ +# if given, will override the device setting in gym. +project_name: "PHC" +notes: "obs_v6 No z. Bigger no distilliation. Direct policy" +env: + numEnvs: 3072 + envSpacing: 5 + episodeLength: 300 + isFlagrun: False + enableDebugVis: False + + bias_offset: False + has_self_collision: True + has_mesh: False + has_jt_limit: False + has_dof_subset: True + has_upright_start: True + has_smpl_pd_offset: False + remove_toe: False # For humanoid's geom toe + real_weight_porpotion_capsules: True + motion_sym_loss: False + sym_loss_coef: 1 + big_ankle: True + fut_tracks: False + obs_v: 6 + auto_pmcp: False + auto_pmcp_soft: True + fitting: False + eval_full: False + + cycle_motion: False + hard_negative: False + + masterfoot: False + freeze_toe: false + freeze_hand: False + + real_weight: True + kp_scale: 1 + remove_toe_im: False # For imitation + power_reward: True + + has_shape_obs: false + has_shape_obs_disc: false + has_shape_variation: False + shape_resampling_interval: 250 + + pdControl: True + powerScale: 1.0 + controlFrequencyInv: 2 # 30 Hz + stateInit: "Random" + hybridInitProb: 0.5 + numAMPObsSteps: 10 + + localRootObs: True + rootHeightObs: True + key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] + contactBodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] + reset_bodies: ["Head", "L_Hand", "R_Hand"] + trackBodies: ["Head", "L_Hand", "R_Hand"] + terminationHeight: 0.15 + enableEarlyTermination: True + terminationDistance: 0.25 + + ### Fut config + numTrajSamples: 10 + trajSampleTimestepInv: 10 + enableTaskObs: True + + asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" + + plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 + +sim: + substeps: 2 + physx: + num_threads: 4 + solver_type: 1 # 0: pgs, 1: tgs + num_position_iterations: 4 + num_velocity_iterations: 0 + contact_offset: 0.02 + rest_offset: 0.0 + bounce_threshold_velocity: 0.2 + max_depenetration_velocity: 10.0 + default_buffer_size_multiplier: 10.0 + + flex: + num_inner_iterations: 10 + warm_start: 0.25 diff --git a/phc/data/cfg/env/phc_shape_mcp_iccv.yaml b/phc/data/cfg/env/phc_shape_mcp_iccv.yaml new file mode 100644 index 0000000..31c9f5a --- /dev/null +++ b/phc/data/cfg/env/phc_shape_mcp_iccv.yaml @@ -0,0 +1,108 @@ +# if given, will override the device setting in gym. +project_name: "PHC" +notes: "Progressive MCP without softmax, zero out far" +env: + numEnvs: 1024 + envSpacing: 2 + episodeLength: 300 + isFlagrun: False + enableDebugVis: False + + bias_offset: False + has_self_collision: True + has_mesh: False + has_jt_limit: False + has_dof_subset: True + has_upright_start: True + has_smpl_pd_offset: False + remove_toe: False # For humanoid's geom toe + real_weight_porpotion_capsules: True + + sym_loss_coef: 1 + big_ankle: True + fut_tracks: False + obs_v: 6 + + ######## PNN Configs ######## + has_pnn: True + fitting: True + num_prim: 4 + training_prim: 0 + actors_to_load: 4 + has_lateral: False + models: ['output/phc_shape_pnn_iccv/Humanoid.pth'] + + ######## Getup Configs ######## + zero_out_far: True + zero_out_far_train: False + cycle_motion: False + getup_udpate_epoch: 78750 + + getup_schedule: True + recoverySteps: 90 + zero_out_far_steps: 90 + recoveryEpisodeProb: 0.5 + fallInitProb: 0.3 + + hard_negative: False + + masterfoot: False + freeze_toe: false + + real_weight: True + kp_scale: 1 + remove_toe_im: False # For imitation + power_reward: True + power_coefficient: 0.00005 + + has_shape_obs: True + has_shape_obs_disc: True + has_shape_variation: True + shape_resampling_interval: 500 + + pdControl: True + powerScale: 1.0 + controlFrequencyInv: 2 # 30 Hz + stateInit: "Random" + hybridInitProb: 0.5 + numAMPObsSteps: 10 + + localRootObs: True + rootHeightObs: True + key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] + contactBodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] + reset_bodies: ['Pelvis', 'L_Hip', 'L_Knee', 'R_Hip', 'R_Knee', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + terminationHeight: 0.15 + enableEarlyTermination: True + terminationDistance: 0.25 + + ### Fut config + numTrajSamples: 3 + trajSampleTimestepInv: 3 + enableTaskObs: True + + asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" + + plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 + +sim: + substeps: 2 + physx: + num_threads: 4 + solver_type: 1 # 0: pgs, 1: tgs + num_position_iterations: 4 + num_velocity_iterations: 0 + contact_offset: 0.02 + rest_offset: 0.0 + bounce_threshold_velocity: 0.2 + max_depenetration_velocity: 10.0 + default_buffer_size_multiplier: 10.0 + + flex: + num_inner_iterations: 10 + warm_start: 0.25 diff --git a/phc/data/cfg/env/phc_shape_pnn_iccv.yaml b/phc/data/cfg/env/phc_shape_pnn_iccv.yaml new file mode 100644 index 0000000..30b4b41 --- /dev/null +++ b/phc/data/cfg/env/phc_shape_pnn_iccv.yaml @@ -0,0 +1,102 @@ +# if given, will override the device setting in gym. +project_name: "PHC" +notes: "PNN, no Laternal connection " +env: + numEnvs: 1536 + envSpacing: 5 + episodeLength: 300 + isFlagrun: False + enableDebugVis: False + + bias_offset: False + has_self_collision: True + has_mesh: False + has_jt_limit: False + has_dof_subset: True + has_upright_start: True + has_smpl_pd_offset: False + remove_toe: False # For humanoid's geom toe + real_weight_porpotion_capsules: True + + sym_loss_coef: 1 + big_ankle: True + fut_tracks: False + obs_v: 6 + + has_pnn: True + fitting: True + num_prim: 4 + training_prim: 0 + actors_to_load: 0 + has_lateral: False + models: ['output/phc_shape_pnn_iccv/Humanoid.pth'] + + cycle_motion: True + hard_negative: False + + + ######## Getup Configs ######## + zero_out_far: True + zero_out_far_train: False + cycle_motion: False + getup_udpate_epoch: 78750 + + masterfoot: False + freeze_toe: false + + real_weight: True + kp_scale: 1 + remove_toe_im: False # For imitation + power_reward: True + + has_shape_obs: True + has_shape_obs_disc: True + has_shape_variation: True + shape_resampling_interval: 500 + + pdControl: True + powerScale: 1.0 + controlFrequencyInv: 2 # 30 Hz + stateInit: "Random" + hybridInitProb: 0.5 + numAMPObsSteps: 10 + + localRootObs: True + rootHeightObs: True + key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] + contactBodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] + reset_bodies: ['Pelvis', 'L_Hip', 'L_Knee', 'R_Hip', 'R_Knee', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + terminationHeight: 0.15 + enableEarlyTermination: True + terminationDistance: 0.25 + + ### Fut config + numTrajSamples: 3 + trajSampleTimestepInv: 3 + enableTaskObs: True + + asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" + + plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 + +sim: + substeps: 2 + physx: + num_threads: 4 + solver_type: 1 # 0: pgs, 1: tgs + num_position_iterations: 4 + num_velocity_iterations: 0 + contact_offset: 0.02 + rest_offset: 0.0 + bounce_threshold_velocity: 0.2 + max_depenetration_velocity: 10.0 + default_buffer_size_multiplier: 10.0 + + flex: + num_inner_iterations: 10 + warm_start: 0.25 diff --git a/phc/data/cfg/env/phc_shape_pnn_train_iccv.yaml b/phc/data/cfg/env/phc_shape_pnn_train_iccv.yaml new file mode 100644 index 0000000..f35bfbe --- /dev/null +++ b/phc/data/cfg/env/phc_shape_pnn_train_iccv.yaml @@ -0,0 +1,100 @@ +# if given, will override the device setting in gym. +project_name: "PHC" +notes: "PNN, no Laternal connection " +env: + numEnvs: 1536 + envSpacing: 5 + episodeLength: 300 + isFlagrun: False + enableDebugVis: False + + bias_offset: False + has_self_collision: True + has_mesh: False + has_jt_limit: False + has_dof_subset: True + has_upright_start: True + has_smpl_pd_offset: False + remove_toe: False # For humanoid's geom toe + real_weight_porpotion_capsules: True + + sym_loss_coef: 1 + big_ankle: True + fut_tracks: False + obs_v: 6 + + has_pnn: True + fitting: False + num_prim: 4 + training_prim: 0 + actors_to_load: 0 + has_lateral: False + + cycle_motion: False + hard_negative: False + + ######## Getup Configs ######## + zero_out_far: False + zero_out_far_train: False + cycle_motion: False + getup_udpate_epoch: 78750 + + masterfoot: False + freeze_toe: false + + real_weight: True + kp_scale: 1 + remove_toe_im: False # For imitation + power_reward: True + + has_shape_obs: True + has_shape_obs_disc: True + has_shape_variation: True + shape_resampling_interval: 500 + + pdControl: True + powerScale: 1.0 + controlFrequencyInv: 2 # 30 Hz + stateInit: "Random" + hybridInitProb: 0.5 + numAMPObsSteps: 10 + + localRootObs: True + rootHeightObs: True + key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] + contactBodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] + reset_bodies: ['Pelvis', 'L_Hip', 'L_Knee', 'R_Hip', 'R_Knee', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + terminationHeight: 0.15 + enableEarlyTermination: True + terminationDistance: 0.25 + + ### Fut config + numTrajSamples: 3 + trajSampleTimestepInv: 3 + enableTaskObs: True + + asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" + + plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 + +sim: + substeps: 2 + physx: + num_threads: 4 + solver_type: 1 # 0: pgs, 1: tgs + num_position_iterations: 4 + num_velocity_iterations: 0 + contact_offset: 0.02 + rest_offset: 0.0 + bounce_threshold_velocity: 0.2 + max_depenetration_velocity: 10.0 + default_buffer_size_multiplier: 10.0 + + flex: + num_inner_iterations: 10 + warm_start: 0.25 diff --git a/phc/data/cfg/env/pulse_amp.yaml b/phc/data/cfg/env/pulse_amp.yaml new file mode 100644 index 0000000..73131b0 --- /dev/null +++ b/phc/data/cfg/env/pulse_amp.yaml @@ -0,0 +1,90 @@ +motion_file: "" +num_envs: 1536 +env_spacing: 5 +episode_length: 300 +isFlagrun: False +enable_debug_vis: False + +embedding_norm: 1 +embedding_size: 32 + +z_readout: False +fitting: False +z_model: True # For motion symm loss +freeze_hand: False +distill: false +save_kin_info: False +distill_z_model: false +z_read: False + +use_vae_prior: True +use_vae_sphere_posterior: False +use_vae_fixed_prior: False +use_vae_sphere_prior: False +use_vae_prior_loss: False + + +distill: false +save_kin_info: False +distill_z_model: false +z_read: False +distill_model_config: + embedding_norm: 1 + embedding_size: 32 + fut_tracks_dropout: False + fut_tracks: False + trajSampleTimestepInv: 5 + numTrajSamples: 10 + z_activation: "silu" + z_type: "vae" + +models: ['output/HumanoidIm/pulse_vae_iclr/Humanoid.pth'] + +real_weight: True +box_body: True +kp_scale: 1 +real_weight: True +freeze_hand: False +freeze_toe: False + +power_reward: False +power_usage_reward: False +power_usage_coefficient: 0.01 + +has_shape_obs: false +has_shape_obs_disc: false +has_shape_variation: False +shape_resampling_interval: 250 + +# Task specific parameters +tarSpeedMin: 0.0 +tarSpeedMax: 5 +speedChangeStepsMin: 100 +speedChangeStepsMax: 200 +enableTaskObs: True + +pdControl: True +power_scale: 1.0 +controlFrequencyInv: 2 # 30 Hz +stateInit: "Random" +hybridInitProb: 0.5 +numAMPObsSteps: 10 +enableTaskObs: True + +local_root_obs: True +root_height_obs: True +ampRootHeightObs: False +key_bodies: ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] +contact_bodies: ["R_Ankle", "L_Ankle", "R_Toe", "L_Toe"] +terminationHeight: 0.15 +enableEarlyTermination: True + +asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" + +plane: + staticFriction: 1.0 + dynamicFriction: 1.0 + restitution: 0.0 + diff --git a/phc/data/cfg/learning/im.yaml b/phc/data/cfg/learning/im.yaml new file mode 100644 index 0000000..5df093e --- /dev/null +++ b/phc/data/cfg/learning/im.yaml @@ -0,0 +1,94 @@ +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp + separate: True + discrete: False + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [1024, 512] + activation: relu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + units: [1024, 512] + activation: relu + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 2500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + player: + games_num: 50000000 \ No newline at end of file diff --git a/phc/data/cfg/learning/im_big.yaml b/phc/data/cfg/learning/im_big.yaml new file mode 100644 index 0000000..f9eba07 --- /dev/null +++ b/phc/data/cfg/learning/im_big.yaml @@ -0,0 +1,97 @@ +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp + separate: True + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [2048, 1536, 1024, 1024, 512, 512] # comparable paramter to z_big_task + activation: silu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + # units: [2048, 1024, 512] + # activation: silu + + units: [1024, 512] + activation: relu + + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 1500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + player: + games_num: 50000000 \ No newline at end of file diff --git a/phc/data/cfg/learning/im_mcp.yaml b/phc/data/cfg/learning/im_mcp.yaml new file mode 100644 index 0000000..66887cb --- /dev/null +++ b/phc/data/cfg/learning/im_mcp.yaml @@ -0,0 +1,97 @@ +# No softmax +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp_mcp + separate: True + discrete: false + has_softmax: False + ending_act: True + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [1024, 512] + activation: relu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + units: [1024, 512] + activation: relu + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 2500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + player: + games_num: 999999999999999999999999 \ No newline at end of file diff --git a/phc/data/cfg/learning/im_mcp_big.yaml b/phc/data/cfg/learning/im_mcp_big.yaml new file mode 100644 index 0000000..2045847 --- /dev/null +++ b/phc/data/cfg/learning/im_mcp_big.yaml @@ -0,0 +1,99 @@ +# No softmax has ending MLP. +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp_mcp + separate: True + discrete: false + has_softmax: False + ending_act: False + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [2048, 1536, 1024, 1024, 512, 512] # comparable paramter to z_big_task + activation: silu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + units: [1024, 512] + activation: relu + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + amp_dropout: True + + player: + games_num: 999999999999999999999999 \ No newline at end of file diff --git a/phc/data/cfg/learning/im_pnn.yaml b/phc/data/cfg/learning/im_pnn.yaml new file mode 100644 index 0000000..d1fcfb1 --- /dev/null +++ b/phc/data/cfg/learning/im_pnn.yaml @@ -0,0 +1,94 @@ +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp_pnn + separate: True + discrete: False + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [1024, 512] + activation: relu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + units: [1024, 512] + activation: relu + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 2500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + player: + games_num: 50000000 \ No newline at end of file diff --git a/phc/data/cfg/learning/im_pnn_big.yaml b/phc/data/cfg/learning/im_pnn_big.yaml new file mode 100644 index 0000000..de07f0b --- /dev/null +++ b/phc/data/cfg/learning/im_pnn_big.yaml @@ -0,0 +1,96 @@ +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp_pnn + separate: True + discrete: False + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [2048, 1536, 1024, 1024, 512, 512] # comparable paramter to z_big_task + activation: silu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + units: [1024, 512] + activation: relu + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 1500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + amp_dropout: False + + player: + games_num: 50000000 \ No newline at end of file diff --git a/phc/data/cfg/learning/im_z_fit.yaml b/phc/data/cfg/learning/im_z_fit.yaml new file mode 100644 index 0000000..9d6dd6b --- /dev/null +++ b/phc/data/cfg/learning/im_z_fit.yaml @@ -0,0 +1,106 @@ +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp_z + separate: True + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [3096, 2048, 1024] + activation: silu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + task_mlp: + units: [1536, 1024, 512] + activation: silu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + units: [1024, 512] + activation: relu + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + use_seq_rl: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + ppo: True + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + amp_dropout: False + + player: + games_num: 50000000 \ No newline at end of file diff --git a/phc/data/cfg/robot/smpl_humanoid.yaml b/phc/data/cfg/robot/smpl_humanoid.yaml new file mode 100644 index 0000000..349d076 --- /dev/null +++ b/phc/data/cfg/robot/smpl_humanoid.yaml @@ -0,0 +1,28 @@ +humanoid_type: smpl +bias_offset: False +has_self_collision: True +has_mesh: False +has_jt_limit: False +has_dof_subset: True +has_upright_start: True +has_smpl_pd_offset: False +remove_toe: False # For humanoid's geom toe +motion_sym_loss: False +sym_loss_coef: 1 +big_ankle: True + +has_shape_obs: false +has_shape_obs_disc: false +has_shape_variation: False + +masterfoot: False +freeze_toe: false +freeze_hand: False +box_body: True +real_weight: True +real_weight_porpotion_capsules: True +real_weight_porpotion_boxes: True + +asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" \ No newline at end of file diff --git a/phc/data/cfg/robot/smpl_humanoid_shape.yaml b/phc/data/cfg/robot/smpl_humanoid_shape.yaml new file mode 100644 index 0000000..a77232c --- /dev/null +++ b/phc/data/cfg/robot/smpl_humanoid_shape.yaml @@ -0,0 +1,28 @@ +humanoid_type: smpl +bias_offset: False +has_self_collision: True +has_mesh: False +has_jt_limit: False +has_dof_subset: True +has_upright_start: True +has_smpl_pd_offset: False +remove_toe: False # For humanoid's geom toe +motion_sym_loss: False +sym_loss_coef: 1 +big_ankle: True + +has_shape_obs: True +has_shape_obs_disc: True +has_shape_variation: True + +masterfoot: False +freeze_toe: false +freeze_hand: False +box_body: True +real_weight: True +real_weight_porpotion_capsules: True +real_weight_porpotion_boxes: True + +asset: + assetRoot: "/" + assetFileName: "mjcf/smpl_humanoid.xml" \ No newline at end of file diff --git a/phc/data/cfg/robot/smplx_humanoid.yaml b/phc/data/cfg/robot/smplx_humanoid.yaml new file mode 100644 index 0000000..9a555f6 --- /dev/null +++ b/phc/data/cfg/robot/smplx_humanoid.yaml @@ -0,0 +1,28 @@ +humanoid_type: smplx +bias_offset: False +has_self_collision: True +has_mesh: False +has_jt_limit: False +has_dof_subset: True +has_upright_start: False +has_smpl_pd_offset: False +remove_toe: False # For humanoid's geom toe +motion_sym_loss: False +sym_loss_coef: 1 +big_ankle: True + +has_shape_obs: false +has_shape_obs_disc: false +has_shape_variation: False + +masterfoot: False +freeze_toe: false +freeze_hand: False +box_body: True +real_weight: True +real_weight_porpotion_capsules: True +real_weight_porpotion_boxes: True + +asset: + assetRoot: "/" + assetFileName: "mjcf/smplx_humanoid.xml" \ No newline at end of file diff --git a/phc/data/cfg/sim/default_sim.yaml b/phc/data/cfg/sim/default_sim.yaml new file mode 100644 index 0000000..a94a139 --- /dev/null +++ b/phc/data/cfg/sim/default_sim.yaml @@ -0,0 +1,22 @@ +sim_device: "cuda:0" +pipeline: "gpu" +graphics_device_id: 0 +subscenes: 0 # Number of PhysX subscenes to simulate in parallel +slices: 0 # Number of client threads that process env slices +use_flex: False + +substeps: 2 +physx: + num_threads: 4 + solver_type: 1 # 0: pgs, 1: tgs + num_position_iterations: 4 + num_velocity_iterations: 0 + contact_offset: 0.02 + rest_offset: 0.0 + bounce_threshold_velocity: 0.2 + max_depenetration_velocity: 10.0 + default_buffer_size_multiplier: 10.0 + +flex: + num_inner_iterations: 10 + warm_start: 0.25 diff --git a/phc/data/cfg/train/rlg/im.yaml b/phc/data/cfg/train/rlg/im.yaml new file mode 100644 index 0000000..93314b1 --- /dev/null +++ b/phc/data/cfg/train/rlg/im.yaml @@ -0,0 +1,95 @@ +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp + separate: True + discrete: False + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [1024, 512] + activation: relu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + units: [1024, 512] + activation: relu + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 2500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + ppo: True + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + player: + games_num: 50000000 \ No newline at end of file diff --git a/phc/data/cfg/train/rlg/im_big.yaml b/phc/data/cfg/train/rlg/im_big.yaml new file mode 100644 index 0000000..89cfa66 --- /dev/null +++ b/phc/data/cfg/train/rlg/im_big.yaml @@ -0,0 +1,98 @@ +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp + separate: True + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [2048, 1536, 1024, 1024, 512, 512] # comparable paramter to z_big_task + activation: silu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + # units: [2048, 1024, 512] + # activation: silu + + units: [1024, 512] + activation: relu + + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 1500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + ppo: True + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + player: + games_num: 50000000 \ No newline at end of file diff --git a/phc/data/cfg/train/rlg/im_mcp.yaml b/phc/data/cfg/train/rlg/im_mcp.yaml new file mode 100644 index 0000000..63b6cb4 --- /dev/null +++ b/phc/data/cfg/train/rlg/im_mcp.yaml @@ -0,0 +1,97 @@ +# No softmax +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp_mcp + separate: True + discrete: false + has_softmax: False + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [1024, 512] + activation: relu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + units: [1024, 512] + activation: relu + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 2500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + ppo: True + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + player: + games_num: 999999999999999999999999 \ No newline at end of file diff --git a/phc/data/cfg/train/rlg/im_pnn.yaml b/phc/data/cfg/train/rlg/im_pnn.yaml new file mode 100644 index 0000000..1f953da --- /dev/null +++ b/phc/data/cfg/train/rlg/im_pnn.yaml @@ -0,0 +1,95 @@ +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp_pnn + separate: True + discrete: False + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [1024, 512] + activation: relu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + units: [1024, 512] + activation: relu + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 2500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + ppo: True + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + player: + games_num: 50000000 \ No newline at end of file diff --git a/phc/data/cfg/train/rlg/im_pnn_big.yaml b/phc/data/cfg/train/rlg/im_pnn_big.yaml new file mode 100644 index 0000000..f049c8c --- /dev/null +++ b/phc/data/cfg/train/rlg/im_pnn_big.yaml @@ -0,0 +1,97 @@ +params: + seed: 0 + + algo: + name: im_amp + + model: + name: amp + + network: + name: amp_pnn + separate: True + discrete: False + + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: -2.9 + fixed_sigma: True + learn_sigma: False + + mlp: + units: [2048, 1536, 1024, 1024, 512, 512] # comparable paramter to z_big_task + activation: silu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + disc: + units: [1024, 512] + activation: relu + + initializer: + name: default + + load_checkpoint: False + + config: + name: Humanoid + env_name: rlgpu + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-5 + lr_schedule: constant + score_to_win: 20000 + max_epochs: 10000000 + save_best_after: 100 + save_frequency: 1500 + print_stats: False + save_intermediate: True + entropy_coef: 0.0 + truncate_grads: True + grad_norm: 50.0 + ppo: True + e_clip: 0.2 + horizon_length: 32 + minibatch_size: 16384 + mini_epochs: 6 + critic_coef: 5 + clip_value: False + + bounds_loss_coef: 10 + amp_obs_demo_buffer_size: 200000 + amp_replay_buffer_size: 200000 + amp_replay_keep_prob: 0.01 + amp_batch_size: 512 + amp_minibatch_size: 4096 + disc_coef: 5 + disc_logit_reg: 0.01 + disc_grad_penalty: 5 + disc_reward_scale: 2 + disc_weight_decay: 0.0001 + normalize_amp_input: True + + task_reward_w: 0.5 + disc_reward_w: 0.5 + + amp_dropout: False + + player: + games_num: 50000000 \ No newline at end of file diff --git a/phc/env/__init__.py b/phc/env/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phc/env/tasks/__init__.py b/phc/env/tasks/__init__.py new file mode 100644 index 0000000..d79b55c --- /dev/null +++ b/phc/env/tasks/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/phc/env/tasks/base_task.py b/phc/env/tasks/base_task.py new file mode 100644 index 0000000..db167ae --- /dev/null +++ b/phc/env/tasks/base_task.py @@ -0,0 +1,722 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import enum +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +import operator +from copy import deepcopy +import random + +from isaacgym import gymapi +from isaacgym.gymutil import get_property_setter_map, get_property_getter_map, get_default_setter_args, apply_random_samples, check_buckets, generate_random_samples +from isaacgym import gymtorch + +import numpy as np +import torch + +import imageio +from datetime import datetime +from phc.utils.flags import flags +from collections import defaultdict +import aiohttp, cv2, asyncio +import json +from collections import deque +import threading +from tqdm import tqdm + +# Base class for RL tasks +class BaseTask(): + + def __init__(self, cfg, enable_camera_sensors=False): + self.headless = cfg["headless"] + if self.headless == False and not flags.no_virtual_display: + from pyvirtualdisplay.smartdisplay import SmartDisplay + self.virtual_display = SmartDisplay(size=(1800, 990), visible=True) + self.virtual_display.start() + + self.gym = gymapi.acquire_gym() + self.paused = False + self.device_type = cfg.get("device_type", "cuda") + self.device_id = cfg.get("device_id", 0) + self.state_record = defaultdict(list) + + self.device = "cpu" + if self.device_type == "cuda" or self.device_type == "GPU": + self.device = "cuda" + ":" + str(self.device_id) + + # double check! + self.graphics_device_id = self.device_id + if enable_camera_sensors == False and self.headless == True: + self.graphics_device_id = -1 + # if flags.server_mode: + # self.graphics_device_id = self.device_id + + self.num_envs = cfg["env"]["num_envs"] + self.num_obs = cfg["env"]["numObservations"] + self.num_states = cfg["env"].get("numStates", 0) + self.num_actions = cfg["env"]["numActions"] + self.is_discrete = cfg["env"].get("is_discrete", False) + + self.control_freq_inv = cfg["env"].get("controlFrequencyInv", 1) + + # optimization flags for pytorch JIT + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + + # allocate buffers + self.obs_buf = torch.zeros((self.num_envs, self.num_obs), device=self.device, dtype=torch.float) + self.states_buf = torch.zeros((self.num_envs, self.num_states), device=self.device, dtype=torch.float) + self.rew_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.float) + self.reset_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long) + self.progress_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long) + self.randomize_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long) + self.extras = {} + + self.original_props = {} + self.dr_randomizations = {} + self.first_randomization = True + self.actor_params_generator = None + self.extern_actor_params = {} + for env_id in range(self.num_envs): + self.extern_actor_params[env_id] = None + + self.last_step = -1 + self.last_rand_step = -1 + + # create envs, sim and viewer + self.create_sim() + self.gym.prepare_sim(self.sim) + + # todo: read from config + self.enable_viewer_sync = True + self.viewer = None + + # if running with a viewer, set up keyboard shortcuts and camera + self.create_viewer() + if flags.server_mode: + # bgsk = threading.Thread(target=self.setup_video_client, daemon=True).start() + bgsk = threading.Thread(target=self.setup_talk_client, daemon=False).start() + + def create_viewer(self): + if self.headless == False: + # headless server mode will use the smart display + + # subscribe to keyboard shortcuts + camera_props = gymapi.CameraProperties() + camera_props.width = 1920 + camera_props.height = 1000 + self.viewer = self.gym.create_viewer(self.sim, camera_props) + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_ESCAPE, "QUIT") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_V, "toggle_viewer_sync") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_L, "toggle_video_record") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_SEMICOLON, "cancel_video_record") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_R, "reset") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_F, "follow") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_G, "fixed") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_H, "divide_group") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_C, "print_cam") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_M, "disable_collision_reset") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_B, "fixed_path") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_N, "real_path") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_K, "show_traj") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_J, "apply_force") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_LEFT, "prev_env") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_RIGHT, "next_env") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_T, "resample_motion") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_Y, "slow_traj") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_I, "trigger_input") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_P, "show_progress") + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_O, "change_color") + + self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_SPACE, "PAUSE") + + # set the camera position based on up axis + sim_params = self.gym.get_sim_params(self.sim) + if sim_params.up_axis == gymapi.UP_AXIS_Z: + cam_pos = gymapi.Vec3(20.0, 25.0, 3.0) + cam_target = gymapi.Vec3(10.0, 15.0, 0.0) + else: + cam_pos = gymapi.Vec3(20.0, 3.0, 25.0) + cam_target = gymapi.Vec3(10.0, 0.0, 15.0) + + self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target) + + ###### Custom Camera Sensors ###### + self.recorder_camera_handles = [] + self.max_num_camera = 10 + self.viewing_env_idx = 0 + for idx, env in enumerate(self.envs): + self.recorder_camera_handles.append(self.gym.create_camera_sensor(env, gymapi.CameraProperties())) + if idx > self.max_num_camera: + break + + self.recorder_camera_handle = self.recorder_camera_handles[0] + self.recording, self.recording_state_change = False, False + self.max_video_queue_size = 100000 + self._video_queue = deque(maxlen=self.max_video_queue_size) + rendering_out = osp.join("output", "renderings") + states_out = osp.join("output", "states") + os.makedirs(rendering_out, exist_ok=True) + os.makedirs(states_out, exist_ok=True) + self.cfg_name = self.cfg.exp_name + self._video_path = osp.join(rendering_out, f"{self.cfg_name}-%s.mp4") + self._states_path = osp.join(states_out, f"{self.cfg_name}-%s.pkl") + # self.gym.draw_env_rigid_contacts(self.viewer, self.envs[1], gymapi.Vec3(0.9, 0.3, 0.3), 1.0, True) + + # set gravity based on up axis and return axis index + def set_sim_params_up_axis(self, sim_params, axis): + if axis == 'z': + sim_params.up_axis = gymapi.UP_AXIS_Z + sim_params.gravity.x = 0 + sim_params.gravity.y = 0 + sim_params.gravity.z = -9.81 + return 2 + return 1 + + def create_sim(self, compute_device, graphics_device, physics_engine, sim_params): + sim = self.gym.create_sim(compute_device, graphics_device, physics_engine, sim_params) + if sim is None: + print("*** Failed to create sim") + quit() + + return sim + + def step(self, actions): + if self.dr_randomizations.get('actions', None): + actions = self.dr_randomizations['actions']['noise_lambda'](actions) + # apply actions + self.pre_physics_step(actions) + + # step physics and render each frame + self._physics_step() + + # to fix! + if self.device == 'cpu': + self.gym.fetch_results(self.sim, True) + + # compute observations, rewards, resets, ... + self.post_physics_step() + + + if self.dr_randomizations.get('observations', None): + self.obs_buf = self.dr_randomizations['observations']['noise_lambda'](self.obs_buf) + + def get_states(self): + return self.states_buf + + def _clear_recorded_states(self): + pass + + def _record_states(self): + pass + + def _write_states_to_file(self, file_name): + pass + + def setup_video_client(self): + loop = asyncio.new_event_loop() # <-- create new loop in this thread here + asyncio.set_event_loop(loop) + loop.run_until_complete(self.video_stream()) + loop.run_forever() + + def setup_talk_client(self): + loop = asyncio.new_event_loop() # <-- create new loop in this thread here + asyncio.set_event_loop(loop) + loop.run_until_complete(self.talk()) + loop.run_forever() + + #print(URL) + async def talk(self): + URL = 'http://klab-cereal.pc.cs.cmu.edu:8080/ws' + print("Starting websocket client") + session = aiohttp.ClientSession() + async with session.ws_connect(URL) as ws: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if msg.data == 'close cmd': + await ws.close() + break + else: + print(msg.data) + try: + msg = json.loads(msg.data) + if msg['action'] == 'reset': + self.reset() + elif msg['action'] == 'start_record': + if self.recording: + print("Already recording") + else: + self.recording = True + self.recording_state_change = True + elif msg['action'] == 'end_record': + if not self.recording: + print("Not recording") + else: + self.recording = False + self.recording_state_change = True + elif msg['action'] == 'set_env': + query = msg['query'] + env_id = query['env'] + self.viewing_env_idx = int(env_id) + print("view env idx: ", self.viewing_env_idx) + except: + import ipdb + ipdb.set_trace() + print("error parsing server message") + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + #print(URL) + async def video_stream(self): + URL = 'http://klab-cereal.pc.cs.cmu.edu:8080/ws' + print("Starting websocket client") + session = aiohttp.ClientSession() + async with session.ws_connect(URL) as ws: + await ws.send_str("Start") + while True: + if "color_image" in self.__dict__ and not self.color_image is None and len(self.color_image.shape) == 3: + image = cv2.resize(self.color_image, (800, 450), interpolation=cv2.INTER_AREA) + await ws.send_bytes(image.tobytes()) + else: + print("no image yet") + await asyncio.sleep(1) + + def render(self, sync_frame_time=False): + if self.viewer: + # check for window closed + if self.gym.query_viewer_has_closed(self.viewer): + sys.exit() + + # check for keyboard events + for evt in self.gym.query_viewer_action_events(self.viewer): + + if evt.action == "QUIT" and evt.value > 0: + sys.exit() + if evt.action == "PAUSE" and evt.value > 0: + self.paused = not self.paused + + elif evt.action == "toggle_viewer_sync" and evt.value > 0: + self.enable_viewer_sync = not self.enable_viewer_sync + elif evt.action == "toggle_video_record" and evt.value > 0: + self.recording = not self.recording + self.recording_state_change = True + elif evt.action == "cancel_video_record" and evt.value > 0: + self.recording = False + self.recording_state_change = False + self._video_queue = deque(maxlen=self.max_video_queue_size) + self._clear_recorded_states() + elif evt.action == "reset" and evt.value > 0: + self.reset() + elif evt.action == "follow" and evt.value > 0: + flags.follow = not flags.follow + elif evt.action == "fixed" and evt.value > 0: + flags.fixed = not flags.fixed + elif evt.action == "divide_group" and evt.value > 0: + flags.divide_group = not flags.divide_group + elif evt.action == "print_cam" and evt.value > 0: + cam_trans = self.gym.get_viewer_camera_transform(self.viewer, None) + cam_pos = np.array([cam_trans.p.x, cam_trans.p.y, cam_trans.p.z]) + print("Print camera", cam_pos) + elif evt.action == "disable_collision_reset" and evt.value > 0: + flags.no_collision_check = not flags.no_collision_check + print("collision_reset: ", flags.no_collision_check) + elif evt.action == "fixed_path" and evt.value > 0: + flags.fixed_path = not flags.fixed_path + print("fixed_path: ", flags.fixed_path) + elif evt.action == "real_path" and evt.value > 0: + flags.real_path = not flags.real_path + print("real_path: ", flags.real_path) + elif evt.action == "show_traj" and evt.value > 0: + flags.show_traj = not flags.show_traj + print("show_traj: ", flags.show_traj) + elif evt.action == "trigger_input" and evt.value > 0: + flags.trigger_input = not flags.trigger_input + self.change_char_color() + print("show_traj: ", flags.show_traj) + elif evt.action == "show_progress" and evt.value > 0: + print("Progress ", self.progress_buf) + elif evt.action == "apply_force" and evt.value > 0: + forces = torch.zeros((1, self._rigid_body_state.shape[0], 3), device=self.device, dtype=torch.float) + torques = torch.zeros((1, self._rigid_body_state.shape[0], 3), device=self.device, dtype=torch.float) + # forces[:, 8, :] = -800 + for i in range(self._rigid_body_state.shape[0] // self.num_bodies): + forces[:, i * self.num_bodies + 3, :] = -3500 + forces[:, i * self.num_bodies + 7, :] = -3500 + # torques[:, 1, :] = 500 + + self.gym.apply_rigid_body_force_tensors(self.sim, gymtorch.unwrap_tensor(forces), gymtorch.unwrap_tensor(torques), gymapi.ENV_SPACE) + + elif evt.action == "prev_env" and evt.value > 0: + self.viewing_env_idx = (self.viewing_env_idx - 1) % self.num_envs + flags.idx -= 1; print(flags.idx) + + # self.recorder_camera_handle = self.recorder_camera_handles[self.viewing_env_idx] + print("\nShowing env: ", self.viewing_env_idx, flags.idx) + elif evt.action == "next_env" and evt.value > 0: + self.viewing_env_idx = (self.viewing_env_idx + 1) % self.num_envs + flags.idx += 1; + # self.recorder_camera_handle = self.recorder_camera_handles[self.viewing_env_idx] + print("\nShowing env: ", self.viewing_env_idx, flags.idx) + elif evt.action == "resample_motion" and evt.value > 0: + self.resample_motions() + + elif evt.action == "slow_traj" and evt.value > 0: + flags.slow = not flags.slow + print("slow_traj: ", flags.slow) + + elif evt.action == "change_color" and evt.value > 0: + self.change_char_color() + print("Change character color") + + if self.recording_state_change: + if not self.recording: + if not flags.server_mode: + self.writer.close() + del self.writer + + self._write_states_to_file(self.curr_states_file_name) + print(f"============ Video finished writing {self.curr_states_file_name}============") + + else: + print(f"============ Writing video ============") + self.recording_state_change = False + + if self.recording: + if not flags.server_mode: + if flags.no_virtual_display: + self.gym.render_all_camera_sensors(self.sim) + color_image = self.gym.get_camera_image(self.sim, self.envs[self.viewing_env_idx], self.recorder_camera_handles[self.viewing_env_idx], gymapi.IMAGE_COLOR) + self.color_image = color_image.reshape(color_image.shape[0], -1, 4) + else: + img = self.virtual_display.grab() + self.color_image = np.array(img) + if not "H" in self.__dict__: + H, W, C = self.color_image.shape + self.H = (H - H % 2) - 10 + self.W = (W - W % 2) - 10 + + self.color_image = self.color_image[:self.H, :self.W, :] + + if not flags.server_mode: + if not "writer" in self.__dict__: + curr_date_time = datetime.now().strftime('%Y-%m-%d-%H:%M:%S') + self.curr_video_file_name = self._video_path % curr_date_time + self.curr_states_file_name = self._states_path % curr_date_time + if not flags.server_mode: + self.writer = imageio.get_writer(self.curr_video_file_name, fps=60, macro_block_size=None) + self.writer.append_data(self.color_image) + + + self._record_states() + + # fetch results + if self.device != 'cpu': + self.gym.fetch_results(self.sim, True) + + # step graphics + if self.enable_viewer_sync: + self.gym.step_graphics(self.sim) + self.gym.draw_viewer(self.viewer, self.sim, True) + # self.gym.sync_frame_time(self.sim) + + else: + self.gym.poll_viewer_events(self.viewer) + + # else: + # if flags.server_mode: + # # headless server model only support rendering from one env + # self.gym.fetch_results(self.sim, True) + # self.gym.step_graphics(self.sim) + # self.gym.render_all_camera_sensors(self.sim) + # self.gym.start_access_image_tensors(self.sim) + + # # self.gym.get_viewer_camera_handle(self.viewer) + # color_image = self.gym.get_camera_image(self.sim, self.envs[self.viewing_env_idx], self.recorder_camera_handles[self.viewing_env_idx], gymapi.IMAGE_COLOR) + + # self.color_image = color_image.reshape(color_image.shape[0], -1, 4)[..., :3] + + # if self.recording: + # self._video_queue.append(self.color_image) + # self._record_states() + + + + + + def get_actor_params_info(self, dr_params, env): + """Returns a flat array of actor params, their names and ranges.""" + if "actor_params" not in dr_params: + return None + params = [] + names = [] + lows = [] + highs = [] + param_getters_map = get_property_getter_map(self.gym) + for actor, actor_properties in dr_params["actor_params"].items(): + handle = self.gym.find_actor_handle(env, actor) + for prop_name, prop_attrs in actor_properties.items(): + if prop_name == 'color': + continue # this is set randomly + props = param_getters_map[prop_name](env, handle) + if not isinstance(props, list): + props = [props] + for prop_idx, prop in enumerate(props): + for attr, attr_randomization_params in prop_attrs.items(): + name = prop_name + '_' + str(prop_idx) + '_' + attr + lo_hi = attr_randomization_params['range'] + distr = attr_randomization_params['distribution'] + if 'uniform' not in distr: + lo_hi = (-1.0 * float('Inf'), float('Inf')) + if isinstance(prop, np.ndarray): + for attr_idx in range(prop[attr].shape[0]): + params.append(prop[attr][attr_idx]) + names.append(name + '_' + str(attr_idx)) + lows.append(lo_hi[0]) + highs.append(lo_hi[1]) + else: + params.append(getattr(prop, attr)) + names.append(name) + lows.append(lo_hi[0]) + highs.append(lo_hi[1]) + return params, names, lows, highs + + # Apply randomizations only on resets, due to current PhysX limitations + def apply_randomizations(self, dr_params): + # If we don't have a randomization frequency, randomize every step + rand_freq = dr_params.get("frequency", 1) + + # First, determine what to randomize: + # - non-environment parameters when > frequency steps have passed since the last non-environment + # - physical environments in the reset buffer, which have exceeded the randomization frequency threshold + # - on the first call, randomize everything + self.last_step = self.gym.get_frame_count(self.sim) + if self.first_randomization: + do_nonenv_randomize = True + env_ids = list(range(self.num_envs)) + else: + do_nonenv_randomize = (self.last_step - self.last_rand_step) >= rand_freq + rand_envs = torch.where(self.randomize_buf >= rand_freq, torch.ones_like(self.randomize_buf), torch.zeros_like(self.randomize_buf)) + rand_envs = torch.logical_and(rand_envs, self.reset_buf) + env_ids = torch.nonzero(rand_envs, as_tuple=False).squeeze(-1).tolist() + self.randomize_buf[rand_envs] = 0 + + if do_nonenv_randomize: + self.last_rand_step = self.last_step + + param_setters_map = get_property_setter_map(self.gym) + param_setter_defaults_map = get_default_setter_args(self.gym) + param_getters_map = get_property_getter_map(self.gym) + + # On first iteration, check the number of buckets + if self.first_randomization: + check_buckets(self.gym, self.envs, dr_params) + + for nonphysical_param in ["observations", "actions"]: + if nonphysical_param in dr_params and do_nonenv_randomize: + dist = dr_params[nonphysical_param]["distribution"] + op_type = dr_params[nonphysical_param]["operation"] + sched_type = dr_params[nonphysical_param]["schedule"] if "schedule" in dr_params[nonphysical_param] else None + sched_step = dr_params[nonphysical_param]["schedule_steps"] if "schedule" in dr_params[nonphysical_param] else None + op = operator.add if op_type == 'additive' else operator.mul + + if sched_type == 'linear': + sched_scaling = 1.0 / sched_step * \ + min(self.last_step, sched_step) + elif sched_type == 'constant': + sched_scaling = 0 if self.last_step < sched_step else 1 + else: + sched_scaling = 1 + + if dist == 'gaussian': + mu, var = dr_params[nonphysical_param]["range"] + mu_corr, var_corr = dr_params[nonphysical_param].get("range_correlated", [0., 0.]) + + if op_type == 'additive': + mu *= sched_scaling + var *= sched_scaling + mu_corr *= sched_scaling + var_corr *= sched_scaling + elif op_type == 'scaling': + var = var * sched_scaling # scale up var over time + mu = mu * sched_scaling + 1.0 * \ + (1.0 - sched_scaling) # linearly interpolate + + var_corr = var_corr * sched_scaling # scale up var over time + mu_corr = mu_corr * sched_scaling + 1.0 * \ + (1.0 - sched_scaling) # linearly interpolate + + def noise_lambda(tensor, param_name=nonphysical_param): + params = self.dr_randomizations[param_name] + corr = params.get('corr', None) + if corr is None: + corr = torch.randn_like(tensor) + params['corr'] = corr + corr = corr * params['var_corr'] + params['mu_corr'] + return op(tensor, corr + torch.randn_like(tensor) * params['var'] + params['mu']) + + self.dr_randomizations[nonphysical_param] = {'mu': mu, 'var': var, 'mu_corr': mu_corr, 'var_corr': var_corr, 'noise_lambda': noise_lambda} + + elif dist == 'uniform': + lo, hi = dr_params[nonphysical_param]["range"] + lo_corr, hi_corr = dr_params[nonphysical_param].get("range_correlated", [0., 0.]) + + if op_type == 'additive': + lo *= sched_scaling + hi *= sched_scaling + lo_corr *= sched_scaling + hi_corr *= sched_scaling + elif op_type == 'scaling': + lo = lo * sched_scaling + 1.0 * (1.0 - sched_scaling) + hi = hi * sched_scaling + 1.0 * (1.0 - sched_scaling) + lo_corr = lo_corr * sched_scaling + 1.0 * (1.0 - sched_scaling) + hi_corr = hi_corr * sched_scaling + 1.0 * (1.0 - sched_scaling) + + def noise_lambda(tensor, param_name=nonphysical_param): + params = self.dr_randomizations[param_name] + corr = params.get('corr', None) + if corr is None: + corr = torch.randn_like(tensor) + params['corr'] = corr + corr = corr * (params['hi_corr'] - params['lo_corr']) + params['lo_corr'] + return op(tensor, corr + torch.rand_like(tensor) * (params['hi'] - params['lo']) + params['lo']) + + self.dr_randomizations[nonphysical_param] = {'lo': lo, 'hi': hi, 'lo_corr': lo_corr, 'hi_corr': hi_corr, 'noise_lambda': noise_lambda} + + if "sim_params" in dr_params and do_nonenv_randomize: + prop_attrs = dr_params["sim_params"] + prop = self.gym.get_sim_params(self.sim) + + if self.first_randomization: + self.original_props["sim_params"] = {attr: getattr(prop, attr) for attr in dir(prop)} + + for attr, attr_randomization_params in prop_attrs.items(): + apply_random_samples(prop, self.original_props["sim_params"], attr, attr_randomization_params, self.last_step) + + self.gym.set_sim_params(self.sim, prop) + + # If self.actor_params_generator is initialized: use it to + # sample actor simulation params. This gives users the + # freedom to generate samples from arbitrary distributions, + # e.g. use full-covariance distributions instead of the DR's + # default of treating each simulation parameter independently. + extern_offsets = {} + if self.actor_params_generator is not None: + for env_id in env_ids: + self.extern_actor_params[env_id] = \ + self.actor_params_generator.sample() + extern_offsets[env_id] = 0 + + for actor, actor_properties in dr_params["actor_params"].items(): + for env_id in env_ids: + env = self.envs[env_id] + handle = self.gym.find_actor_handle(env, actor) + extern_sample = self.extern_actor_params[env_id] + + for prop_name, prop_attrs in actor_properties.items(): + if prop_name == 'color': + num_bodies = self.gym.get_actor_rigid_body_count(env, handle) + for n in range(num_bodies): + self.gym.set_rigid_body_color(env, handle, n, gymapi.MESH_VISUAL, gymapi.Vec3(random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))) + continue + if prop_name == 'scale': + attr_randomization_params = prop_attrs + sample = generate_random_samples(attr_randomization_params, 1, self.last_step, None) + og_scale = 1 + if attr_randomization_params['operation'] == 'scaling': + new_scale = og_scale * sample + elif attr_randomization_params['operation'] == 'additive': + new_scale = og_scale + sample + self.gym.set_actor_scale(env, handle, new_scale) + continue + + prop = param_getters_map[prop_name](env, handle) + if isinstance(prop, list): + if self.first_randomization: + self.original_props[prop_name] = [{attr: getattr(p, attr) for attr in dir(p)} for p in prop] + for p, og_p in zip(prop, self.original_props[prop_name]): + for attr, attr_randomization_params in prop_attrs.items(): + smpl = None + if self.actor_params_generator is not None: + smpl, extern_offsets[env_id] = get_attr_val_from_sample(extern_sample, extern_offsets[env_id], p, attr) + apply_random_samples(p, og_p, attr, attr_randomization_params, self.last_step, smpl) + else: + if self.first_randomization: + self.original_props[prop_name] = deepcopy(prop) + for attr, attr_randomization_params in prop_attrs.items(): + smpl = None + if self.actor_params_generator is not None: + smpl, extern_offsets[env_id] = get_attr_val_from_sample(extern_sample, extern_offsets[env_id], prop, attr) + apply_random_samples(prop, self.original_props[prop_name], attr, attr_randomization_params, self.last_step, smpl) + + setter = param_setters_map[prop_name] + default_args = param_setter_defaults_map[prop_name] + setter(env, handle, prop, *default_args) + + if self.actor_params_generator is not None: + for env_id in env_ids: # check that we used all dims in sample + if extern_offsets[env_id] > 0: + extern_sample = self.extern_actor_params[env_id] + if extern_offsets[env_id] != extern_sample.shape[0]: + print('env_id', env_id, 'extern_offset', extern_offsets[env_id], 'vs extern_sample.shape', extern_sample.shape) + raise Exception("Invalid extern_sample size") + + self.first_randomization = False + + def pre_physics_step(self, actions): + raise NotImplementedError + + def _physics_step(self): + for i in range(self.control_freq_inv): + self.render() + + if not self.paused and self.enable_viewer_sync: + self.gym.simulate(self.sim) + return + + def post_physics_step(self): + raise NotImplementedError + + +def get_attr_val_from_sample(sample, offset, prop, attr): + """Retrieves param value for the given prop and attr from the sample.""" + if sample is None: + return None, 0 + if isinstance(prop, np.ndarray): + smpl = sample[offset:offset + prop[attr].shape[0]] + return smpl, offset + prop[attr].shape[0] + else: + return sample[offset], offset + 1 diff --git a/phc/env/tasks/humanoid.py b/phc/env/tasks/humanoid.py new file mode 100644 index 0000000..cde30e2 --- /dev/null +++ b/phc/env/tasks/humanoid.py @@ -0,0 +1,1849 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from uuid import uuid4 +import numpy as np +import os + +import torch +import multiprocessing + +from isaacgym import gymtorch +from isaacgym import gymapi +from isaacgym.torch_utils import * +import joblib +from phc.utils import torch_utils + +from smpl_sim.smpllib.smpl_joint_names import SMPL_MUJOCO_NAMES, SMPLH_MUJOCO_NAMES +from smpl_sim.smpllib.smpl_local_robot import SMPL_Robot + +from phc.utils.flags import flags +from phc.env.tasks.base_task import BaseTask +from tqdm import tqdm +from poselib.poselib.skeleton.skeleton3d import SkeletonTree +from collections import defaultdict +from poselib.poselib.skeleton.skeleton3d import SkeletonMotion, SkeletonState +from scipy.spatial.transform import Rotation as sRot +import gc +import torch.multiprocessing as mp +from phc.utils.draw_utils import agt_color, get_color_gradient + + +ENABLE_MAX_COORD_OBS = True +# PERTURB_OBJS = [ +# ["small", 60], +# ["small", 7], +# ["small", 10], +# ["small", 35], +# ["small", 2], +# ["small", 2], +# ["small", 3], +# ["small", 2], +# ["small", 2], +# ["small", 3], +# ["small", 2], +# ["large", 60], +# ["small", 300], +# ] +PERTURB_OBJS = [ + ["small", 60], + # ["large", 60], +] + + +class Humanoid(BaseTask): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + self.cfg = cfg + self.sim_params = sim_params + self.physics_engine = physics_engine + self.has_task = False + + self.load_humanoid_configs(cfg) + + self.control_mode = self.cfg["env"]["control_mode"] + if self.control_mode in ['isaac_pd']: + self._pd_control = True + else: + self._pd_control = False + self.power_scale = self.cfg["env"]["power_scale"] + + self.debug_viz = self.cfg["env"]["enable_debug_vis"] + self.plane_static_friction = self.cfg["env"]["plane"]["staticFriction"] + self.plane_dynamic_friction = self.cfg["env"]["plane"]["dynamicFriction"] + self.plane_restitution = self.cfg["env"]["plane"]["restitution"] + + self.max_episode_length = self.cfg["env"]["episode_length"] + self._local_root_obs = self.cfg["env"]["local_root_obs"] + self._root_height_obs = self.cfg["env"].get("root_height_obs", True) + self._enable_early_termination = self.cfg["env"]["enableEarlyTermination"] + self.temp_running_mean = self.cfg["env"].get("temp_running_mean", True) + self.partial_running_mean = self.cfg["env"].get("partial_running_mean", False) + self.self_obs_v = self.cfg["env"].get("self_obs_v", 1) + + self.key_bodies = self.cfg["env"]["key_bodies"] + + self._setup_character_props(self.key_bodies) + + self.cfg["env"]["numObservations"] = self.get_obs_size() + self.cfg["env"]["numActions"] = self.get_action_size() + self.cfg["device_type"] = device_type + self.cfg["device_id"] = device_id + self.cfg["headless"] = headless + + + super().__init__(cfg=self.cfg) + + self.dt = self.control_freq_inv * sim_params.dt + self._setup_tensors() + self.self_obs_buf = torch.zeros((self.num_envs, self.get_self_obs_size()), device=self.device, dtype=torch.float) + self.reward_raw = torch.zeros((self.num_envs, 1)).to(self.device) + + return + + def _load_proj_asset(self): + asset_root = "phc/data/assets/urdf/" + + small_asset_file = "block_projectile.urdf" + # small_asset_file = "ball_medium.urdf" + small_asset_options = gymapi.AssetOptions() + small_asset_options.angular_damping = 0.01 + small_asset_options.linear_damping = 0.01 + small_asset_options.max_angular_velocity = 100.0 + small_asset_options.density = 10000000.0 + # small_asset_options.fix_base_link = True + small_asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + self._small_proj_asset = self.gym.load_asset(self.sim, asset_root, small_asset_file, small_asset_options) + + large_asset_file = "block_projectile_large.urdf" + large_asset_options = gymapi.AssetOptions() + large_asset_options.angular_damping = 0.01 + large_asset_options.linear_damping = 0.01 + large_asset_options.max_angular_velocity = 100.0 + large_asset_options.density = 10000000.0 + # large_asset_options.fix_base_link = True + large_asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + self._large_proj_asset = self.gym.load_asset(self.sim, asset_root, large_asset_file, large_asset_options) + return + + def _build_proj(self, env_id, env_ptr): + pos = [ + [-0.01, 0.3, 0.4], + # [ 0.0890016, -0.40830246, 0.25] + ] + for i, obj in enumerate(PERTURB_OBJS): + default_pose = gymapi.Transform() + default_pose.p.x = pos[i][0] + default_pose.p.y = pos[i][1] + default_pose.p.z = pos[i][2] + obj_type = obj[0] + if (obj_type == "small"): + proj_asset = self._small_proj_asset + elif (obj_type == "large"): + proj_asset = self._large_proj_asset + + proj_handle = self.gym.create_actor(env_ptr, proj_asset, default_pose, "proj{:d}".format(i), env_id, 2) + self._proj_handles.append(proj_handle) + + return + + def _setup_tensors(self): + # get gym GPU state tensors + actor_root_state = self.gym.acquire_actor_root_state_tensor(self.sim) + dof_state_tensor = self.gym.acquire_dof_state_tensor(self.sim) + sensor_tensor = self.gym.acquire_force_sensor_tensor(self.sim) + rigid_body_state = self.gym.acquire_rigid_body_state_tensor(self.sim) + contact_force_tensor = self.gym.acquire_net_contact_force_tensor(self.sim) + + # ZL: needs to put this back + if self.self_obs_v == 3: + sensors_per_env = len(self.force_sensor_joints) + self.vec_sensor_tensor = gymtorch.wrap_tensor(sensor_tensor).view(self.num_envs, sensors_per_env * 6) + + + dof_force_tensor = self.gym.acquire_dof_force_tensor(self.sim) + self.dof_force_tensor = gymtorch.wrap_tensor(dof_force_tensor).view(self.num_envs, self.num_dof) + + self.gym.refresh_dof_state_tensor(self.sim) + self.gym.refresh_actor_root_state_tensor(self.sim) + self.gym.refresh_rigid_body_state_tensor(self.sim) + self.gym.refresh_net_contact_force_tensor(self.sim) + + self._root_states = gymtorch.wrap_tensor(actor_root_state) + num_actors = self.get_num_actors_per_env() + + self._humanoid_root_states = self._root_states.view(self.num_envs, num_actors, actor_root_state.shape[-1])[..., 0, :] + self._initial_humanoid_root_states = self._humanoid_root_states.clone() + self._initial_humanoid_root_states[:, 7:13] = 0 + + self._humanoid_actor_ids = num_actors * torch.arange(self.num_envs, device=self.device, dtype=torch.int32) + + # create some wrapper tensors for different slices + self._dof_state = gymtorch.wrap_tensor(dof_state_tensor) + dofs_per_env = self._dof_state.shape[0] // self.num_envs + self._dof_pos = self._dof_state.view(self.num_envs, dofs_per_env, 2)[..., :self.num_dof, 0] + self._dof_vel = self._dof_state.view(self.num_envs, dofs_per_env, 2)[..., :self.num_dof, 1] + + self._initial_dof_pos = torch.zeros_like(self._dof_pos, device=self.device, dtype=torch.float) + self._initial_dof_vel = torch.zeros_like(self._dof_vel, device=self.device, dtype=torch.float) + + self._rigid_body_state = gymtorch.wrap_tensor(rigid_body_state) + bodies_per_env = self._rigid_body_state.shape[0] // self.num_envs + self._rigid_body_state_reshaped = self._rigid_body_state.view(self.num_envs, bodies_per_env, 13) + + self._rigid_body_pos = self._rigid_body_state_reshaped[..., :self.num_bodies, 0:3] + self._rigid_body_rot = self._rigid_body_state_reshaped[..., :self.num_bodies, 3:7] + self._rigid_body_vel = self._rigid_body_state_reshaped[..., :self.num_bodies, 7:10] + self._rigid_body_ang_vel = self._rigid_body_state_reshaped[..., :self.num_bodies, 10:13] + + if self.self_obs_v == 2: + self._rigid_body_pos_hist = torch.zeros((self.num_envs, self.past_track_steps, self.num_bodies, 3), device=self.device, dtype=torch.float) + self._rigid_body_rot_hist = torch.zeros((self.num_envs, self.past_track_steps, self.num_bodies, 4), device=self.device, dtype=torch.float) + self._rigid_body_vel_hist = torch.zeros((self.num_envs, self.past_track_steps, self.num_bodies, 3), device=self.device, dtype=torch.float) + self._rigid_body_ang_vel_hist = torch.zeros((self.num_envs, self.past_track_steps, self.num_bodies, 3), device=self.device, dtype=torch.float) + + contact_force_tensor = gymtorch.wrap_tensor(contact_force_tensor) + self._contact_forces = contact_force_tensor.view(self.num_envs, bodies_per_env, 3)[..., :self.num_bodies, :] + + self._terminate_buf = torch.ones(self.num_envs, device=self.device, dtype=torch.long) + + self._build_termination_heights() + + contact_bodies = self.cfg["env"]["contact_bodies"] + self._key_body_ids = self._build_key_body_ids_tensor(self.key_bodies) + + self._contact_body_ids = self._build_contact_body_ids_tensor(contact_bodies) + + if self.viewer != None or flags.server_mode: + self._init_camera() + + + def load_humanoid_configs(self, cfg): + self.humanoid_type = cfg.robot.humanoid_type + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + self.load_smpl_configs(cfg) + else: + raise NotImplementedError + + + def load_common_humanoid_configs(self, cfg): + self._divide_group = cfg["env"].get("divide_group", False) + self._group_obs = cfg["env"].get("group_obs", False) + self._disable_group_obs = cfg["env"].get("disable_group_obs", False) + if self._divide_group: + self._group_num_people = group_num_people = min(cfg['env'].get("num_env_group", 128), cfg['env']['num_envs']) + self._group_ids = torch.tensor(np.arange(cfg["env"]["num_envs"] / group_num_people).repeat(group_num_people).astype(int)) + + self.force_sensor_joints = cfg["env"].get("force_sensor_joints", ["L_Ankle", "R_Ankle"]) # force tensor joints + + ##### Robot Configs ##### + self._has_shape_obs = cfg.robot.get("has_shape_obs", False) + self._has_shape_obs_disc = cfg.robot.get("has_shape_obs_disc", False) + self._has_limb_weight_obs = cfg.robot.get("has_weight_obs", False) + self._has_limb_weight_obs_disc = cfg.robot.get("has_weight_obs_disc", False) + self.has_shape_variation = cfg.robot.get("has_shape_variation", False) + self._bias_offset = cfg.robot.get("bias_offset", False) + self._has_self_collision = cfg.robot.get("has_self_collision", False) + self._has_mesh = cfg.robot.get("has_mesh", True) + self._replace_feet = cfg.robot.get("replace_feet", True) # replace feet or not + self._has_jt_limit = cfg.robot.get("has_jt_limit", True) + self._has_dof_subset = cfg.robot.get("has_dof_subset", False) + self._has_smpl_pd_offset = cfg.robot.get("has_smpl_pd_offset", False) + self._masterfoot = cfg.robot.get("masterfoot", False) + self._freeze_toe = cfg.robot.get("freeze_toe", True) + ##### Robot Configs ##### + + + self.shape_resampling_interval = cfg["env"].get("shape_resampling_interval", 100) + self.getup_schedule = cfg["env"].get("getup_schedule", False) + self._kp_scale = cfg["env"].get("kp_scale", 1.0) + self._kd_scale = cfg["env"].get("kd_scale", self._kp_scale) + + self.hard_negative = cfg["env"].get("hard_negative", False) # hard negative sampling for im + self.cycle_motion = cfg["env"].get("cycle_motion", False) # Cycle motion to reach 300 + self.power_reward = cfg["env"].get("power_reward", False) + self.obs_v = cfg["env"].get("obs_v", 1) + self.amp_obs_v = cfg["env"].get("amp_obs_v", 1) + + + ## Kin stuff + self.save_kin_info = cfg["env"].get("save_kin_info", False) + self.only_kin_loss = cfg["env"].get("only_kin_loss", False) + self.kin_policy = cfg["env"].get("kin_policy", False) + self.kin_lr = cfg["env"].get("kin_lr", 5e-4) + self.z_readout = cfg["env"].get("z_readout", False) + self.z_read = cfg["env"].get("z_read", False) + self.z_uniform = cfg["env"].get("z_uniform", False) + self.z_model = cfg["env"].get("z_model", False) + self.distill = cfg["env"].get("distill", False) + + self.remove_disc_rot = cfg["env"].get("remove_disc_rot", False) + + ## ZL Devs + #################### Devs #################### + self.fitting = cfg["env"].get("fitting", False) + self.zero_out_far = cfg["env"].get("zero_out_far", False) + self.zero_out_far_train = cfg["env"].get("zero_out_far_train", True) + self.max_len = cfg["env"].get("max_len", -1) + self.cycle_motion_xp = cfg["env"].get("cycle_motion_xp", False) # Cycle motion, but cycle farrrrr. + self.models_path = cfg["env"].get("models", ['output/dgx/smpl_im_fit_3_1/Humanoid_00185000.pth', 'output/dgx/smpl_im_fit_3_2/Humanoid_00198750.pth']) + + self.eval_full = cfg["env"].get("eval_full", False) + self.auto_pmcp = cfg["env"].get("auto_pmcp", False) + self.auto_pmcp_soft = cfg["env"].get("auto_pmcp_soft", False) + self.strict_eval = cfg["env"].get("strict_eval", False) + self.add_obs_noise = cfg["env"].get("add_obs_noise", False) + + self._occl_training = cfg["env"].get("occl_training", False) # Cycle motion, but cycle farrrrr. + self._occl_training_prob = cfg["env"].get("occl_training_prob", 0.1) # Cycle motion, but cycle farrrrr. + self._sim_occlu = False + self._res_action = cfg["env"].get("res_action", False) + self.close_distance = cfg["env"].get("close_distance", 0.25) + self.far_distance = cfg["env"].get("far_distance", 3) + self._zero_out_far_steps = cfg["env"].get("zero_out_far_steps", 90) + self.past_track_steps = cfg["env"].get("past_track_steps", 5) + #################### Devs #################### + + ######################################################################## + # Z reader + self.vae_reader = cfg["env"].get("vae_reader", False) + self.z_type = cfg["env"].get("z_type", None) + self.kld_coefficient = cfg["env"].get("kld_coefficient", 0.01) + self.kld_coefficient_min = cfg["env"].get("kld_coefficient_min", 0.001) + self.ar1_coefficient = cfg["env"].get("ar1_coefficient", 0.005) + self.kld_anneal = cfg["env"].get("kld_anneal", True) + self.use_ar1_prior = cfg["env"].get("use_ar1_prior", False) + self.use_vae_prior = cfg["env"].get("use_vae_prior", False) + self.vae_prior_policy = cfg["env"].get("vae_prior_policy", False) + self.use_vae_prior_regu = cfg["env"].get("use_vae_prior_regu", False) + self.use_vae_fixed_prior = cfg["env"].get("use_vae_fixed_prior", False) + self.use_vae_prior_loss = cfg['env'].get("use_vae_prior_loss", False) + self.use_vae_sphere_prior = cfg['env'].get("use_vae_sphere_prior", False) + self.use_vae_sphere_posterior = cfg['env'].get("use_vae_sphere_posterior", False) + ######################################################################## + + def load_smpl_configs(self, cfg): + self.load_common_humanoid_configs(cfg) + + ##### Robot Configs ##### + self._has_upright_start = cfg.robot.get("has_upright_start", True) + self.remove_toe = cfg.robot.get("remove_toe", False) + self.big_ankle = cfg.robot.get("big_ankle", False) + self._real_weight_porpotion_capsules = cfg.robot.get("real_weight_porpotion_capsules", False) + self._real_weight_porpotion_boxes = cfg.robot.get("real_weight_porpotion_boxes", False) + self._real_weight = cfg.robot.get("real_weight", False) + self._master_range = cfg.robot.get("master_range", 30) + self._freeze_toe = cfg.robot.get("freeze_toe", True) + self._freeze_hand = cfg.robot.get("freeze_hand", True) + self._box_body = cfg.robot.get("box_body", False) + self.reduce_action = cfg.robot.get("reduce_action", False) + + + if self._masterfoot: + self.action_idx = [0, 1, 2, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 54, 55, 60, 61, 62, 65, 66, 67, 68, 75, 76, 77, 80, 81, 82, 83] + else: + self.action_idx = [0, 1, 2, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 42, 43, 44, 47, 48, 49, 50, 57, 58, 59, 62, 63, 64, 65] + + disc_idxes = [] + if self.humanoid_type == "smpl": + self._body_names_orig = SMPL_MUJOCO_NAMES + elif self.humanoid_type in ["smplh", "smplx"]: + self._body_names_orig = SMPLH_MUJOCO_NAMES + + self._full_track_bodies = self._body_names_orig.copy() + + _body_names_orig_copy = self._body_names_orig.copy() + _body_names_orig_copy.remove('L_Toe') # Following UHC as hand and toes does not have realiable data. + _body_names_orig_copy.remove('R_Toe') + if self.humanoid_type == "smpl": + _body_names_orig_copy.remove('L_Hand') + _body_names_orig_copy.remove('R_Hand') + + self._eval_bodies = _body_names_orig_copy # default eval bodies + + self._body_names = self._body_names_orig + self._masterfoot_config = None + + self._dof_names = self._body_names[1:] + + + if self.humanoid_type == "smpl": + remove_names = ["L_Hand", "R_Hand", "L_Toe", "R_Toe"] + self.limb_weight_group = [ + ['L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe'], \ + ['R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe'], \ + ['Pelvis', 'Torso', 'Spine', 'Chest', 'Neck', 'Head'], \ + [ 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand'], \ + ['R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand']] + elif self.humanoid_type in ["smplh", "smplx"]: + remove_names = ["L_Toe", "R_Toe"] + self.limb_weight_group = [ + ['L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe'], \ + ['R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe'], \ + ['Pelvis', 'Torso', 'Spine', 'Chest', 'Neck', 'Head'], \ + [ 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Index1', 'L_Index2', 'L_Index3', 'L_Middle1', 'L_Middle2', 'L_Middle3', 'L_Pinky1', 'L_Pinky2', 'L_Pinky3', 'L_Ring1', 'L_Ring2', 'L_Ring3', 'L_Thumb1', 'L_Thumb2', 'L_Thumb3'], \ + ['R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Index1', 'R_Index2', 'R_Index3', 'R_Middle1', 'R_Middle2', 'R_Middle3', 'R_Pinky1', 'R_Pinky2', 'R_Pinky3', 'R_Ring1', 'R_Ring2', 'R_Ring3', 'R_Thumb1', 'R_Thumb2', 'R_Thumb3']] + + if self.remove_disc_rot: + remove_names = self._body_names_orig # NO AMP Rotation + self.limb_weight_group = [[self._body_names.index(g) for g in group] for group in self.limb_weight_group] + + for idx, name in enumerate(self._dof_names): + if not name in remove_names: + disc_idxes.append(np.arange(idx * 3, (idx + 1) * 3)) + + self.dof_subset = torch.from_numpy(np.concatenate(disc_idxes)) if len(disc_idxes) > 0 else torch.tensor([]).long() + self.left_indexes = [idx for idx , name in enumerate(self._dof_names) if name.startswith("L")] + self.right_indexes = [idx for idx , name in enumerate(self._dof_names) if name.startswith("R")] + + self.left_lower_indexes = [idx for idx , name in enumerate(self._dof_names) if name.startswith("L") and name[2:] in ["Hip", "Knee", "Ankle", "Toe"]] + self.right_lower_indexes = [idx for idx , name in enumerate(self._dof_names) if name.startswith("R") and name[2:] in ["Hip", "Knee", "Ankle", "Toe"]] + + self._load_amass_gender_betas() + + def _clear_recorded_states(self): + del self.state_record + self.state_record = defaultdict(list) + + def _record_states(self): + self.state_record['dof_pos'].append(self._dof_pos.cpu().clone()) + self.state_record['root_states'].append(self._humanoid_root_states.cpu().clone()) + self.state_record['progress'].append(self.progress_buf.cpu().clone()) + + def _write_states_to_file(self, file_name): + self.state_record['skeleton_trees'] = self.skeleton_trees + self.state_record['humanoid_betas'] = self.humanoid_shapes + print(f"Dumping states into {file_name}") + + progress = torch.stack(self.state_record['progress'], dim=1) + progress_diff = torch.cat([progress, -10 * torch.ones(progress.shape[0], 1).to(progress)], dim=-1) + + diff = torch.abs(progress_diff[:, :-1] - progress_diff[:, 1:]) + split_idx = torch.nonzero(diff > 1) + split_idx[:, 1] += 1 + dof_pos_all = torch.stack(self.state_record['dof_pos']) + root_states_all = torch.stack(self.state_record['root_states']) + fps = 60 + motion_dict_dump = {} + num_for_this_humanoid = 0 + curr_humanoid_index = 0 + + for idx in range(len(split_idx)): + split_info = split_idx[idx] + humanoid_index = split_info[0] + + if humanoid_index != curr_humanoid_index: + num_for_this_humanoid = 0 + curr_humanoid_index = humanoid_index + + if num_for_this_humanoid == 0: + start = 0 + else: + start = split_idx[idx - 1][-1] + + end = split_idx[idx][-1] + + dof_pos_seg = dof_pos_all[start:end, humanoid_index] + B, H = dof_pos_seg.shape + + root_states_seg = root_states_all[start:end, humanoid_index] + body_quat = torch.cat([root_states_seg[:, None, 3:7], torch_utils.exp_map_to_quat(dof_pos_seg.reshape(B, -1, 3))], dim=1) + + motion_dump = { + "skeleton_tree": self.state_record['skeleton_trees'][humanoid_index].to_dict(), + "body_quat": body_quat, + "trans": root_states_seg[:, :3], + "root_states_seg": root_states_seg, + "dof_pos": dof_pos_seg, + } + motion_dump['fps'] = fps + motion_dump['betas'] = self.humanoid_shapes[humanoid_index].detach().cpu().numpy() + motion_dict_dump[f"{humanoid_index}_{num_for_this_humanoid}"] = motion_dump + num_for_this_humanoid += 1 + + joblib.dump(motion_dict_dump, file_name) + self.state_record = defaultdict(list) + + def get_obs_size(self): + return self.get_self_obs_size() + + def get_running_mean_size(self): + return (self.get_obs_size(), ) + + def get_self_obs_size(self): + if self.self_obs_v == 1: + return self._num_self_obs + elif self.self_obs_v == 2: + return self._num_self_obs * (self.past_track_steps + 1) + elif self.self_obs_v == 3: + return self._num_self_obs + + def get_action_size(self): + return self._num_actions + + def get_dof_action_size(self): + return self._dof_size + + def get_num_actors_per_env(self): + num_actors = self._root_states.shape[0] // self.num_envs + return num_actors + + def create_sim(self): + self.up_axis_idx = self.set_sim_params_up_axis(self.sim_params, 'z') + + self.sim = super().create_sim(self.device_id, self.graphics_device_id, self.physics_engine, self.sim_params) + + self._create_ground_plane() + self._create_envs(self.num_envs, self.cfg["env"]['env_spacing'], int(np.sqrt(self.num_envs))) + return + + def reset(self, env_ids=None): + safe_reset = (env_ids is None) or len(env_ids) == self.num_envs + if (env_ids is None): + env_ids = to_torch(np.arange(self.num_envs), device=self.device, dtype=torch.long) + + self._reset_envs(env_ids) + + if safe_reset: + # import ipdb; ipdb.set_trace() + # print("3resetting here!!!!", self._humanoid_root_states[0, :3] - self._rigid_body_pos[0, 0]) + # ZL: This way it will simuate one step, then get reset again, squashing any remaining wiredness. Temporary fix + self.gym.simulate(self.sim) + self._reset_envs(env_ids) + torch.cuda.empty_cache() + + return + + def change_char_color(self): + colors = [] + offset = np.random.randint(0, 10) + for env_id in range(self.num_envs): + rand_cols = agt_color(env_id + offset) + colors.append(rand_cols) + + self.sample_char_color(torch.tensor(colors), torch.arange(self.num_envs)) + + + def sample_char_color(self, cols, env_ids): + for env_id in env_ids: + env_ptr = self.envs[env_id] + handle = self.humanoid_handles[env_id] + + for j in range(self.num_bodies): + self.gym.set_rigid_body_color(env_ptr, handle, j, gymapi.MESH_VISUAL, gymapi.Vec3(cols[env_id, 0], cols[env_id, 1], cols[env_id, 2])) + return + + def set_char_color(self, col, env_ids): + for env_id in env_ids: + env_ptr = self.envs[env_id] + handle = self.humanoid_handles[env_id] + + for j in range(self.num_bodies): + self.gym.set_rigid_body_color(env_ptr, handle, j, gymapi.MESH_VISUAL, + gymapi.Vec3(col[0], col[1], col[2])) + + return + + + def _reset_envs(self, env_ids): + if (len(env_ids) > 0): + + self._reset_actors(env_ids) # this funciton calle _set_env_state, and should set all state vectors + self._reset_env_tensors(env_ids) + + self._refresh_sim_tensors() + if self.self_obs_v == 2: + self._init_tensor_history(env_ids) + + self._compute_observations(env_ids) + + + return + + def _reset_env_tensors(self, env_ids): + env_ids_int32 = self._humanoid_actor_ids[env_ids] + + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + self.gym.set_dof_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._dof_state), gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + + + # print("#################### refreshing ####################") + # print("rb", (self._rigid_body_state_reshaped[None, :] - self._rigid_body_state_reshaped[:, None]).abs().sum()) + # print("contact", (self._contact_forces[None, :] - self._contact_forces[:, None]).abs().sum()) + # print('dof_pos', (self._dof_pos[None, :] - self._dof_pos[:, None]).abs().sum()) + # print("dof_vel", (self._dof_vel[None, :] - self._dof_vel[:, None]).abs().sum()) + # print("root_states", (self._humanoid_root_states[None, :] - self._humanoid_root_states[:, None]).abs().sum()) + # print("#################### refreshing ####################") + + self.progress_buf[env_ids] = 0 + self.reset_buf[env_ids] = 0 + self._terminate_buf[env_ids] = 0 + self._contact_forces[env_ids] = 0 + + return + + def _create_ground_plane(self): + plane_params = gymapi.PlaneParams() + plane_params.normal = gymapi.Vec3(0.0, 0.0, 1.0) + plane_params.static_friction = self.plane_static_friction + plane_params.dynamic_friction = self.plane_dynamic_friction + + # plane_params.static_friction = 50 + # plane_params.dynamic_friction = 50 + + plane_params.restitution = self.plane_restitution + self.gym.add_ground(self.sim, plane_params) + return + + def _setup_character_props(self, key_bodies): + + asset_file = self.cfg.robot.asset.assetFileName + num_key_bodies = len(key_bodies) + + if (asset_file == "mjcf/amp_humanoid.xml"): + ### ZL: changes + self._dof_body_ids = [1, 2, 3, 4, 6, 7, 9, 10, 11, 12, 13, 14] + self._dof_offsets = [0, 3, 6, 9, 10, 13, 14, 17, 18, 21, 24, 25, 28] + self._dof_obs_size = 72 + self._num_actions = 28 + + if (ENABLE_MAX_COORD_OBS): + self._num_self_obs = 1 + 15 * (3 + 6 + 3 + 3) - 3 + else: + self._num_self_obs = 13 + self._dof_obs_size + 28 + 3 * num_key_bodies # [root_h, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos] + + elif self.humanoid_type in ["smpl", "smplh", "smplx"]: + # import ipdb; ipdb.set_trace() + self._dof_body_ids = np.arange(1, len(self._body_names)) + self._dof_offsets = np.linspace(0, len(self._dof_names) * 3, len(self._body_names)).astype(int) + self._dof_obs_size = len(self._dof_names) * 6 + self._dof_size = len(self._dof_names) * 3 + if self.reduce_action: + self._num_actions = len(self.action_idx) + else: + self._num_actions = len(self._dof_names) * 3 + + if (ENABLE_MAX_COORD_OBS): + self._num_self_obs = 1 + len(self._body_names) * (3 + 6 + 3 + 3) - 3 # height + num_bodies * 15 (pos + vel + rot + ang_vel) - root_pos + else: + raise NotImplementedError() + + if self._has_shape_obs: + self._num_self_obs += 11 + # if self._has_limb_weight_obs: self._num_self_obs += 23 + 24 if not self._masterfoot else 29 + 30 # 23 + 24 (length + weight) + if self._has_limb_weight_obs: + self._num_self_obs += 10 + + if not self._root_height_obs: + self._num_self_obs -= 1 + + if self.self_obs_v == 3: + self._num_self_obs += 6 * len(self.force_sensor_joints) + + else: + print("Unsupported character config file: {s}".format(asset_file)) + assert (False) + + return + + def _build_termination_heights(self): + head_term_height = 0.3 + shield_term_height = 0.32 + + termination_height = self.cfg["env"]["terminationHeight"] + self._termination_heights = np.array([termination_height] * self.num_bodies) + + head_id = self.gym.find_actor_rigid_body_handle(self.envs[0], self.humanoid_handles[0], "head") + self._termination_heights[head_id] = max(head_term_height, self._termination_heights[head_id]) + + asset_file = self.cfg.robot.asset["assetFileName"] + if (asset_file == "mjcf/amp_humanoid_sword_shield.xml"): + left_arm_id = self.gym.find_actor_rigid_body_handle(self.envs[0], self.humanoid_handles[0], "left_lower_arm") + self._termination_heights[left_arm_id] = max(shield_term_height, self._termination_heights[left_arm_id]) + + self._termination_heights = to_torch(self._termination_heights, device=self.device) + return + + def _create_smpl_humanoid_xml(self, num_humanoids, smpl_robot, queue, pid): + np.random.seed(np.random.randint(5002) * (pid + 1)) + res = {} + for idx in num_humanoids: + if self.has_shape_variation: + gender_beta = self._amass_gender_betas[idx % self._amass_gender_betas.shape[0]] + else: + gender_beta = np.zeros(17) + + if flags.im_eval: + gender_beta = np.zeros(17) + + asset_id = uuid4() + + if not smpl_robot is None: + asset_id = uuid4() + asset_file_real = f"/tmp/smpl/smpl_humanoid_{asset_id}.xml" + smpl_robot.load_from_skeleton(betas=torch.from_numpy(gender_beta[None, 1:]), gender=gender_beta[0:1], objs_info=None) + smpl_robot.write_xml(asset_file_real) + else: + asset_file_real = f"phc/data/assets/mjcf/smpl_{int(gender_beta[0])}_humanoid.xml" + + res[idx] = (gender_beta, asset_file_real) + + if not queue is None: + queue.put(res) + else: + return res + + def _load_amass_gender_betas(self): + if self._has_mesh: + gender_betas_data = joblib.load("sample_data/amass_isaac_gender_betas.pkl") + self._amass_gender_betas = np.array(list(gender_betas_data.values())) + else: + gender_betas_data = joblib.load("sample_data/amass_isaac_gender_betas_unique.pkl") + self._amass_gender_betas = np.array(gender_betas_data) + + def _create_envs(self, num_envs, spacing, num_per_row): + lower = gymapi.Vec3(-spacing, -spacing, 0.0) + upper = gymapi.Vec3(spacing, spacing, spacing) + + asset_root = self.cfg.robot.asset["assetRoot"] + asset_file = self.cfg.robot.asset["assetFileName"] + self.humanoid_masses = [] + + if (self.humanoid_type in ["smpl", "smplh", "smplx"]): + self.humanoid_shapes = [] + self.humanoid_assets = [] + self.humanoid_limb_and_weights = [] + self.skeleton_trees = [] + robot_cfg = { + "mesh": self._has_mesh, + "replace_feet": self._replace_feet, + "rel_joint_lm": self._has_jt_limit, + "upright_start": self._has_upright_start, + "remove_toe": self.remove_toe, + "freeze_hand": self._freeze_hand, + "real_weight_porpotion_capsules": self._real_weight_porpotion_capsules, + "real_weight_porpotion_boxes": self._real_weight_porpotion_boxes, + "real_weight": self._real_weight, + "masterfoot": self._masterfoot, + "master_range": self._master_range, + "big_ankle": self.big_ankle, + "box_body": self._box_body, + "body_params": {}, + "joint_params": {}, + "geom_params": {}, + "actuator_params": {}, + "model": self.humanoid_type, + "sim": "isaacgym", + } + if os.path.exists("data/smpl"): + robot = SMPL_Robot( + robot_cfg, + data_dir="data/smpl", + ) + else: + print("!!!!!!! SMPL files not found, loading pre-computed humanoid assets, only for demo purposes !!!!!!!") + print("!!!!!!! SMPL files not found, loading pre-computed humanoid assets, only for demo purposes !!!!!!!") + print("!!!!!!! SMPL files not found, loading pre-computed humanoid assets, only for demo purposes !!!!!!!") + asset_root = "./" + robot = None + + + + asset_options = gymapi.AssetOptions() + asset_options.angular_damping = 0.01 + asset_options.max_angular_velocity = 100.0 + asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + + if self.has_shape_variation: + queue = mp.Queue() + num_jobs = min(mp.cpu_count(), 64) + if num_jobs <= 8: + num_jobs = 1 + if flags.debug: + num_jobs = 1 + res_acc = {} + jobs = np.arange(num_envs) + chunk = np.ceil(len(jobs) / num_jobs).astype(int) + jobs = [jobs[i:i + chunk] for i in range(0, len(jobs), chunk)] + job_args = [jobs[i] for i in range(len(jobs))] + + for i in range(1, len(jobs)): + worker_args = (job_args[i], robot, queue, i) + worker = multiprocessing.Process(target=self._create_smpl_humanoid_xml, args=worker_args) + worker.start() + res_acc.update(self._create_smpl_humanoid_xml(jobs[0], robot, None, 0)) + for i in tqdm(range(len(jobs) - 1)): + res = queue.get() + res_acc.update(res) + + for idx in np.arange(num_envs): + gender_beta, asset_file_real = res_acc[idx] + humanoid_asset = self.gym.load_asset(self.sim, asset_root, asset_file_real, asset_options) + actuator_props = self.gym.get_asset_actuator_properties(humanoid_asset) + motor_efforts = [prop.motor_effort for prop in actuator_props] + + sk_tree = SkeletonTree.from_mjcf(asset_file_real) + + # create force sensors at the feet + if self.self_obs_v == 3: + self.create_humanoid_force_sensors(humanoid_asset, self.force_sensor_joints) + + self.humanoid_shapes.append(torch.from_numpy(gender_beta).float()) + self.humanoid_assets.append(humanoid_asset) + self.skeleton_trees.append(sk_tree) + + if not robot is None: + robot.remove_geoms() # Clean up the geoms + + self.humanoid_shapes = torch.vstack(self.humanoid_shapes).to(self.device) + else: + gender_beta, asset_file_real = self._create_smpl_humanoid_xml([0], robot, None, 0)[0] + sk_tree = SkeletonTree.from_mjcf(asset_file_real) + + humanoid_asset = self.gym.load_asset(self.sim, asset_root, asset_file_real, asset_options) + actuator_props = self.gym.get_asset_actuator_properties(humanoid_asset) + motor_efforts = [prop.motor_effort for prop in actuator_props] + + # create force sensors at the feet + if self.self_obs_v == 3: + self.create_humanoid_force_sensors(humanoid_asset, self.force_sensor_joints) + + + self.humanoid_shapes = torch.tensor(np.array([gender_beta] * num_envs)).float().to(self.device) + self.humanoid_assets = [humanoid_asset] * num_envs + self.skeleton_trees = [sk_tree] * num_envs + + else: + + asset_path = os.path.join(asset_root, asset_file) + asset_root = os.path.dirname(asset_path) + asset_file = os.path.basename(asset_path) + + asset_options = gymapi.AssetOptions() + asset_options.angular_damping = 0.01 + asset_options.max_angular_velocity = 100.0 + asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + #asset_options.fix_base_link = True + humanoid_asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options) + + actuator_props = self.gym.get_asset_actuator_properties(humanoid_asset) + motor_efforts = [prop.motor_effort for prop in actuator_props] + + # create force sensors at the feet + self.create_humanoid_force_sensors(humanoid_asset, ["right_foot", "left_foot"]) + self.humanoid_assets = [humanoid_asset] * num_envs + + self.max_motor_effort = max(motor_efforts) + self.motor_efforts = to_torch(motor_efforts, device=self.device) + self.torso_index = 0 + self.num_bodies = self.gym.get_asset_rigid_body_count(humanoid_asset) + self.num_dof = self.gym.get_asset_dof_count(humanoid_asset) + self.num_asset_joints = self.gym.get_asset_joint_count(humanoid_asset) + self.humanoid_handles = [] + self.envs = [] + self.dof_limits_lower = [] + self.dof_limits_upper = [] + + for i in range(self.num_envs): + # create env instance + env_ptr = self.gym.create_env(self.sim, lower, upper, num_per_row) + self._build_env(i, env_ptr, self.humanoid_assets[i]) + self.envs.append(env_ptr) + self.humanoid_limb_and_weights = torch.stack(self.humanoid_limb_and_weights).to(self.device) + print("Humanoid Weights", self.humanoid_masses[:10]) + + dof_prop = self.gym.get_actor_dof_properties(self.envs[0], self.humanoid_handles[0]) + + ######################################## Joint frictino + # dof_prop['friction'][:] = 10 + # self.gym.set_actor_dof_properties(self.envs[0], self.humanoid_handles[0], dof_prop) + + for j in range(self.num_dof): + if dof_prop['lower'][j] > dof_prop['upper'][j]: + self.dof_limits_lower.append(dof_prop['upper'][j]) + self.dof_limits_upper.append(dof_prop['lower'][j]) + else: + self.dof_limits_lower.append(dof_prop['lower'][j]) + self.dof_limits_upper.append(dof_prop['upper'][j]) + + self.dof_limits_lower = to_torch(self.dof_limits_lower, device=self.device) + self.dof_limits_upper = to_torch(self.dof_limits_upper, device=self.device) + + if self.control_mode == "pd": + self.torque_limits = torch.ones_like(self.dof_limits_upper) * 1000 # ZL: hacking + + if self.control_mode in ["pd", "isaac_pd"]: + self._build_pd_action_offset_scale() + return + + def create_humanoid_force_sensors(self, humanoid_asset, sensor_joint_names): + for jt in sensor_joint_names: + right_foot_idx = self.gym.find_asset_rigid_body_index(humanoid_asset, jt) + sensor_pose = gymapi.Transform() + sensor_options = gymapi.ForceSensorProperties() + sensor_options.enable_constraint_solver_forces = True # for example contacts + sensor_options.use_world_frame = False # Local frame so we can directly send it to computation. + # These are the default values. + + self.gym.create_asset_force_sensor(humanoid_asset, right_foot_idx, sensor_pose, sensor_options) + + return + + def _build_env(self, env_id, env_ptr, humanoid_asset): + if self._divide_group or flags.divide_group: + col_group = self._group_ids[env_id] + else: + col_group = env_id # no inter-environment collision + + col_filter = 0 + if (self.humanoid_type in ["smpl", "smplh", "smplx"] ) and (not self._has_self_collision): + col_filter = 1 + + start_pose = gymapi.Transform() + asset_file = self.cfg.robot.asset["assetFileName"] + if (asset_file == "mjcf/ov_humanoid.xml" or asset_file == "mjcf/ov_humanoid_sword_shield.xml"): + char_h = 0.927 + else: + char_h = 0.89 + + pos = torch.tensor(get_axis_params(char_h, self.up_axis_idx)).to(self.device) + pos[:2] += torch_rand_float(-1., 1., (2, 1), device=self.device).squeeze(1) # ZL: segfault if we do not randomize the position + + start_pose.p = gymapi.Vec3(*pos) + start_pose.r = gymapi.Quat(0.0, 0.0, 0.0, 1.0) + + humanoid_handle = self.gym.create_actor(env_ptr, humanoid_asset, start_pose, "humanoid", col_group, col_filter, 0) + self.gym.enable_actor_dof_force_sensors(env_ptr, humanoid_handle) + mass_ind = [prop.mass for prop in self.gym.get_actor_rigid_body_properties(env_ptr, humanoid_handle)] + humanoid_mass = np.sum(mass_ind) + self.humanoid_masses.append(humanoid_mass) + + curr_skeleton_tree = self.skeleton_trees[env_id] + limb_lengths = torch.norm(curr_skeleton_tree.local_translation, dim=-1) + masses = torch.tensor(mass_ind) + + # humanoid_limb_weight = torch.cat([limb_lengths[1:], masses]) + + limb_lengths = [limb_lengths[group].sum() for group in self.limb_weight_group] + masses = [masses[group].sum() for group in self.limb_weight_group] + humanoid_limb_weight = torch.tensor(limb_lengths + masses) + self.humanoid_limb_and_weights.append(humanoid_limb_weight) # ZL: attach limb lengths and full body weight. + + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + gender = self.humanoid_shapes[env_id, 0].long() + percentage = 1 - np.clip((humanoid_mass - 70) / 70, 0, 1) + if gender == 0: + gender = 1 + color_vec = gymapi.Vec3(*get_color_gradient(percentage, "Greens")) + elif gender == 1: + gender = 2 + color_vec = gymapi.Vec3(*get_color_gradient(percentage, "Blues")) + elif gender == 2: + gender = 0 + color_vec = gymapi.Vec3(*get_color_gradient(percentage, "Reds")) + + # color = torch.zeros(3) + # color[gender] = 1 - np.clip((humanoid_mass - 70) / 70, 0, 1) + if flags.test: + color_vec = gymapi.Vec3(*agt_color(env_id + 0)) + # if env_id == 0: + # color_vec = gymapi.Vec3(0.23192618223760095, 0.5456516724336793, 0.7626143790849673) + # elif env_id == 1: + # color_vec = gymapi.Vec3(0.907912341407151, 0.20284505959246443, 0.16032295271049596) + + else: + color_vec = gymapi.Vec3(0.54, 0.85, 0.2) + + for j in range(self.num_bodies): + self.gym.set_rigid_body_color(env_ptr, humanoid_handle, j, gymapi.MESH_VISUAL, color_vec) + + dof_prop = self.gym.get_asset_dof_properties(humanoid_asset) + if self.has_shape_variation: + pd_scale = humanoid_mass / self.cfg['env'].get('default_humanoid_mass', 77.0 if self._real_weight else 35.0) + self._kp_scale = pd_scale * self._kp_scale + self._kd_scale = pd_scale * self._kd_scale + + if (self.control_mode == "isaac_pd"): + dof_prop["driveMode"][:] = gymapi.DOF_MODE_POS + dof_prop['stiffness'] *= self._kp_scale + dof_prop['damping'] *= self._kd_scale + else: + if self.control_mode == "pd": + # self.kp_gains = to_torch(self._kp_scale * dof_prop['stiffness'], device=self.device) + # self.kd_gains = to_torch(self._kd_scale * dof_prop['damping'], device=self.device) + self.kp_gains = to_torch(self._kp_scale * dof_prop['stiffness']/4, device=self.device) + self.kd_gains = to_torch(self._kd_scale * dof_prop['damping']/4, device=self.device) + dof_prop['velocity'][:] = 100 + dof_prop['stiffness'][:] = 0 + dof_prop['friction'][:] = 1 + dof_prop['damping'][:] = 0.001 + elif self.control_mode == "force": + dof_prop['velocity'][:] = 100 + dof_prop['stiffness'][:] = 0 + dof_prop['friction'][:] = 1 + dof_prop['damping'][:] = 0.001 + + dof_prop["driveMode"][:] = gymapi.DOF_MODE_EFFORT + self.gym.set_actor_dof_properties(env_ptr, humanoid_handle, dof_prop) + + if self.humanoid_type in ["smpl", "smplh", "smplx"] and self._has_self_collision: + # compliance_vals = [0.1] * 24 + # thickness_vals = [1.0] * 24 + if self._has_mesh: + filter_ints = [0, 1, 224, 512, 384, 1, 1792, 64, 1056, 4096, 6, 6168, 0, 2048, 0, 20, 0, 0, 0, 0, 10, 0, 0, 0] + else: + if self.humanoid_type == "smpl": + filter_ints = [0, 0, 7, 16, 12, 0, 56, 2, 33, 128, 0, 192, 0, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + elif self.humanoid_type in ["smplh", "smplx"]: + filter_ints = [0, 0, 7, 16, 12, 0, 56, 2, 33, 128, 0, 192, 0, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + + props = self.gym.get_actor_rigid_shape_properties(env_ptr, humanoid_handle) + + assert (len(filter_ints) == len(props)) + for p_idx in range(len(props)): + props[p_idx].filter = filter_ints[p_idx] + self.gym.set_actor_rigid_shape_properties(env_ptr, humanoid_handle, props) + + self.humanoid_handles.append(humanoid_handle) + + return + + def _build_pd_action_offset_scale(self): + num_joints = len(self._dof_offsets) - 1 + + lim_low = self.dof_limits_lower.cpu().numpy() + lim_high = self.dof_limits_upper.cpu().numpy() + + for j in range(num_joints): + dof_offset = self._dof_offsets[j] + dof_size = self._dof_offsets[j + 1] - self._dof_offsets[j] + if not self._bias_offset: + if (dof_size == 3): + curr_low = lim_low[dof_offset:(dof_offset + dof_size)] + curr_high = lim_high[dof_offset:(dof_offset + dof_size)] + curr_low = np.max(np.abs(curr_low)) + curr_high = np.max(np.abs(curr_high)) + curr_scale = max([curr_low, curr_high]) + curr_scale = 1.2 * curr_scale + curr_scale = min([curr_scale, np.pi]) + + lim_low[dof_offset:(dof_offset + dof_size)] = -curr_scale + lim_high[dof_offset:(dof_offset + dof_size)] = curr_scale + + #lim_low[dof_offset:(dof_offset + dof_size)] = -np.pi + #lim_high[dof_offset:(dof_offset + dof_size)] = np.pi + + elif (dof_size == 1): + curr_low = lim_low[dof_offset] + curr_high = lim_high[dof_offset] + curr_mid = 0.5 * (curr_high + curr_low) + + # extend the action range to be a bit beyond the joint limits so that the motors + # don't lose their strength as they approach the joint limits + curr_scale = 0.7 * (curr_high - curr_low) + curr_low = curr_mid - curr_scale + curr_high = curr_mid + curr_scale + + lim_low[dof_offset] = curr_low + lim_high[dof_offset] = curr_high + else: + curr_low = lim_low[dof_offset:(dof_offset + dof_size)] + curr_high = lim_high[dof_offset:(dof_offset + dof_size)] + curr_mid = 0.5 * (curr_high + curr_low) + + # extend the action range to be a bit beyond the joint limits so that the motors + # don't lose their strength as they approach the joint limits + curr_scale = 0.7 * (curr_high - curr_low) + curr_low = curr_mid - curr_scale + curr_high = curr_mid + curr_scale + + lim_low[dof_offset:(dof_offset + dof_size)] = curr_low + lim_high[dof_offset:(dof_offset + dof_size)] = curr_high + + self._pd_action_offset = 0.5 * (lim_high + lim_low) + self._pd_action_scale = 0.5 * (lim_high - lim_low) + self._pd_action_offset = to_torch(self._pd_action_offset, device=self.device) + self._pd_action_scale = to_torch(self._pd_action_scale, device=self.device) + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + self._L_knee_dof_idx = self._dof_names.index("L_Knee") * 3 + 1 + self._R_knee_dof_idx = self._dof_names.index("R_Knee") * 3 + 1 + + # ZL: Modified SMPL to give stronger knee + self._pd_action_scale[self._L_knee_dof_idx] = 5 + self._pd_action_scale[self._R_knee_dof_idx] = 5 + + if self._has_smpl_pd_offset: + if self._has_upright_start: + self._pd_action_offset[self._dof_names.index("L_Shoulder") * 3] = -np.pi / 2 + self._pd_action_offset[self._dof_names.index("R_Shoulder") * 3] = np.pi / 2 + else: + self._pd_action_offset[self._dof_names.index("L_Shoulder") * 3] = -np.pi / 6 + self._pd_action_offset[self._dof_names.index("L_Shoulder") * 3 + 2] = -np.pi / 2 + self._pd_action_offset[self._dof_names.index("R_Shoulder") * 3] = -np.pi / 3 + self._pd_action_offset[self._dof_names.index("R_Shoulder") * 3 + 2] = np.pi / 2 + + return + + def _compute_reward(self, actions): + self.rew_buf[:] = compute_humanoid_reward(self.obs_buf) + return + + def _compute_reset(self): + self.reset_buf[:], self._terminate_buf[:] = compute_humanoid_reset(self.reset_buf, self.progress_buf, self._contact_forces, self._contact_body_ids, self._rigid_body_pos, self.max_episode_length, self._enable_early_termination, self._termination_heights) + return + + def _refresh_sim_tensors(self): + self.gym.refresh_dof_state_tensor(self.sim) + self.gym.refresh_actor_root_state_tensor(self.sim) + self.gym.refresh_rigid_body_state_tensor(self.sim) + + self.gym.refresh_force_sensor_tensor(self.sim) + self.gym.refresh_dof_force_tensor(self.sim) + self.gym.refresh_net_contact_force_tensor(self.sim) + + return + + def _compute_observations(self, env_ids=None): + obs = self._compute_humanoid_obs(env_ids) + + return + + def _compute_humanoid_obs(self, env_ids=None): + if (ENABLE_MAX_COORD_OBS): + if (env_ids is None): + body_pos = self._rigid_body_pos + body_rot = self._rigid_body_rot + body_vel = self._rigid_body_vel + body_ang_vel = self._rigid_body_ang_vel + if self.self_obs_v == 2: + body_pos = torch.cat([self._rigid_body_pos_hist, body_pos.unsqueeze(1)], dim=1) + body_rot = torch.cat([self._rigid_body_rot_hist, body_rot.unsqueeze(1)], dim=1) + body_vel = torch.cat([self._rigid_body_vel_hist, body_vel.unsqueeze(1)], dim=1) + body_ang_vel = torch.cat([self._rigid_body_ang_vel_hist, body_ang_vel.unsqueeze(1)], dim=1) + if self.self_obs_v == 3: + force_sensor_readings = self.vec_sensor_tensor + + + else: + body_pos = self._rigid_body_pos[env_ids] + body_rot = self._rigid_body_rot[env_ids] + body_vel = self._rigid_body_vel[env_ids] + body_ang_vel = self._rigid_body_ang_vel[env_ids] + if self.self_obs_v == 2: + body_pos = torch.cat([self._rigid_body_pos_hist[env_ids], body_pos.unsqueeze(1)], dim=1) + body_rot = torch.cat([self._rigid_body_rot_hist[env_ids], body_rot.unsqueeze(1)], dim=1) + body_vel = torch.cat([self._rigid_body_vel_hist[env_ids], body_vel.unsqueeze(1)], dim=1) + body_ang_vel = torch.cat([self._rigid_body_ang_vel_hist[env_ids], body_ang_vel.unsqueeze(1)], dim=1) + if self.self_obs_v == 3: + force_sensor_readings = self.vec_sensor_tensor[env_ids] + + + + if self.humanoid_type in ["smpl", "smplh", "smplx"] : + if (env_ids is None): + body_shape_params = self.humanoid_shapes[:, :-6] if self.humanoid_type in ["smpl", "smplh", "smplx"] else self.humanoid_shapes + limb_weights = self.humanoid_limb_and_weights + else: + body_shape_params = self.humanoid_shapes[env_ids, :-6] if self.humanoid_type in ["smpl", "smplh", "smplx"] else self.humanoid_shapes[env_ids] + limb_weights = self.humanoid_limb_and_weights[env_ids] + + if self.self_obs_v == 1: + obs = compute_humanoid_observations_smpl_max(body_pos, body_rot, body_vel, body_ang_vel, body_shape_params, limb_weights, self._local_root_obs, self._root_height_obs, self._has_upright_start, self._has_shape_obs, self._has_limb_weight_obs) + elif self.self_obs_v == 2: + obs = compute_humanoid_observations_smpl_max_v2(body_pos, body_rot, body_vel, body_ang_vel, body_shape_params, limb_weights, self._local_root_obs, self._root_height_obs, self._has_upright_start, self._has_shape_obs, self._has_limb_weight_obs, self.past_track_steps + 1) + elif self.self_obs_v == 3: + obs = compute_humanoid_observations_smpl_max_v3(body_pos, body_rot, body_vel, body_ang_vel, force_sensor_readings, body_shape_params, limb_weights, self._local_root_obs, self._root_height_obs, self._has_upright_start, self._has_shape_obs, self._has_limb_weight_obs) + + + else: + obs = compute_humanoid_observations_max(body_pos, body_rot, body_vel, body_ang_vel, self._local_root_obs, self._root_height_obs) + + else: + if (env_ids is None): + root_pos = self._rigid_body_pos[:, 0, :] + root_rot = self._rigid_body_rot[:, 0, :] + root_vel = self._rigid_body_vel[:, 0, :] + root_ang_vel = self._rigid_body_ang_vel[:, 0, :] + dof_pos = self._dof_pos + dof_vel = self._dof_vel + key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :] + else: + root_pos = self._rigid_body_pos[env_ids][:, 0, :] + root_rot = self._rigid_body_rot[env_ids][:, 0, :] + root_vel = self._rigid_body_vel[env_ids][:, 0, :] + root_ang_vel = self._rigid_body_ang_vel[env_ids][:, 0, :] + dof_pos = self._dof_pos[env_ids] + dof_vel = self._dof_vel[env_ids] + key_body_pos = self._rigid_body_pos[env_ids][:, self._key_body_ids, :] + + if (self.humanoid_type in ["smpl", "smplh", "smplx"] ) and self.self.has_shape_obs: + if (env_ids is None): + body_shape_params = self.humanoid_shapes + else: + body_shape_params = self.humanoid_shapes[env_ids] + obs = compute_humanoid_observations_smpl(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, self._dof_obs_size, self._dof_offsets, body_shape_params, self._local_root_obs, self._root_height_obs, self._has_upright_start, self._has_shape_obs) + else: + obs = compute_humanoid_observations(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, self._local_root_obs, self._root_height_obs, self._dof_obs_size, self._dof_offsets) + return obs + + def _reset_actors(self, env_ids): + self._humanoid_root_states[env_ids] = self._initial_humanoid_root_states[env_ids] + self._dof_pos[env_ids] = self._initial_dof_pos[env_ids] + self._dof_vel[env_ids] = self._initial_dof_vel[env_ids] + return + + + def pre_physics_step(self, actions): + # if flags.debug: + # actions *= 0 + + self.actions = actions.to(self.device).clone() + if len(self.actions.shape) == 1: + self.actions = self.actions[None, ] + + if (self._pd_control): + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + if self.reduce_action: + actions_full = torch.zeros([actions.shape[0], self._dof_size]).to(self.device) + actions_full[:, self.action_idx] = self.actions + pd_tar = self._action_to_pd_targets(actions_full) + + else: + pd_tar = self._action_to_pd_targets(self.actions) + if self._freeze_hand: + pd_tar[:, self._dof_names.index("L_Hand") * 3:(self._dof_names.index("L_Hand") * 3 + 3)] = 0 + pd_tar[:, self._dof_names.index("R_Hand") * 3:(self._dof_names.index("R_Hand") * 3 + 3)] = 0 + if self._freeze_toe: + pd_tar[:, self._dof_names.index("L_Toe") * 3:(self._dof_names.index("L_Toe") * 3 + 3)] = 0 + pd_tar[:, self._dof_names.index("R_Toe") * 3:(self._dof_names.index("R_Toe") * 3 + 3)] = 0 + + pd_tar_tensor = gymtorch.unwrap_tensor(pd_tar) + self.gym.set_dof_position_target_tensor(self.sim, pd_tar_tensor) + + else: + if self.control_mode == "force": + actions_full = self.actions + forces = actions_full * self.motor_efforts.unsqueeze(0) * self.power_scale + force_tensor = gymtorch.unwrap_tensor(forces) + self.gym.set_dof_actuation_force_tensor(self.sim, force_tensor) + elif self.control_mode == "pd": + self.pd_tar = self._action_to_pd_targets(self.actions) + return + + + def _compute_torques(self, actions): + """ Compute torques from actions. + Actions can be interpreted as position or velocity targets given to a PD controller, or directly as scaled torques. + [NOTE]: torques must have the same dimension as the number of DOFs, even if some DOFs are not actuated. + Args: + actions (torch.Tensor): Actions + Returns: + [torch.Tensor]: Torques sent to the simulation + """ + #pd controller + action_scale = 1 + control_type = "P" # self.cfg.control.control_type + if control_type=="P": # default + torques = self.kp_gains*(actions - self._dof_pos) - self.kd_gains*self._dof_vel + else: + raise NameError(f"Unknown controller type: {control_type}") + # if self.cfg.domain_rand.randomize_torque_rfi: + # torques = torques + (torch.rand_like(torques)*2.-1.) * self.cfg.domain_rand.rfi_lim * self.torque_limits + + return torch.clip(torques, -self.torque_limits, self.torque_limits) + + + def _physics_step(self): + for i in range(self.control_freq_inv): + self.control_i = i + self.render() + if not self.paused and self.enable_viewer_sync: + if self.control_mode == "pd": #### Using simple pd controller. + self.torques = self._compute_torques(self.pd_tar) + self.gym.set_dof_actuation_force_tensor(self.sim, gymtorch.unwrap_tensor(self.torques)) + self.gym.simulate(self.sim) + if self.device == 'cpu': + self.gym.fetch_results(self.sim, True) + self.gym.refresh_dof_state_tensor(self.sim) + else: + self.gym.simulate(self.sim) + + return + + + + + def _init_tensor_history(self, env_ids): + self._rigid_body_pos_hist[env_ids] = self._rigid_body_pos[env_ids].unsqueeze(1).repeat(1, self.past_track_steps, 1, 1) + self._rigid_body_rot_hist[env_ids] = self._rigid_body_rot[env_ids].unsqueeze(1).repeat(1, self.past_track_steps, 1, 1) + self._rigid_body_vel_hist[env_ids] = self._rigid_body_vel[env_ids].unsqueeze(1).repeat(1, self.past_track_steps, 1, 1) + self._rigid_body_ang_vel_hist[env_ids] = self._rigid_body_ang_vel[env_ids].unsqueeze(1).repeat(1, self.past_track_steps, 1, 1) + + def _update_tensor_history(self): + self._rigid_body_pos_hist = torch.cat([self._rigid_body_pos_hist[:, 1:], self._rigid_body_pos.unsqueeze(1)], dim=1) + self._rigid_body_rot_hist = torch.cat([self._rigid_body_rot_hist[:, 1:], self._rigid_body_rot.unsqueeze(1)], dim=1) + self._rigid_body_vel_hist = torch.cat([self._rigid_body_vel_hist[:, 1:], self._rigid_body_vel.unsqueeze(1)], dim=1) + self._rigid_body_ang_vel_hist = torch.cat([self._rigid_body_ang_vel_hist[:, 1:], self._rigid_body_ang_vel.unsqueeze(1)], dim=1) + + + def post_physics_step(self): + # This is after stepping, so progress buffer got + 1. Compute reset/reward do not need to forward 1 timestep since they are for "this" frame now. + if not self.paused: + self.progress_buf += 1 + + + if self.self_obs_v == 2: + self._update_tensor_history() + + self._refresh_sim_tensors() + self._compute_reward(self.actions) # ZL swapped order of reward & objecation computes. should be fine. + self._compute_reset() + + self._compute_observations() # observation for the next step. + + self.extras["terminate"] = self._terminate_buf + self.extras["reward_raw"] = self.reward_raw.detach() + + # debug viz + if self.viewer and self.debug_viz: + self._update_debug_viz() + + + # Debugging + # if flags.debug: + # body_vel = self._rigid_body_vel.clone() + # speeds = body_vel.norm(dim = -1).mean(dim = -1) + # sorted_speed, sorted_idx = speeds.sort() + # print(sorted_speed.numpy()[::-1][:20], sorted_idx.numpy()[::-1][:20].tolist()) + # # import ipdb; ipdb.set_trace() + + return + + def render(self, sync_frame_time=False): + if self.viewer or flags.server_mode: + self._update_camera() + + super().render(sync_frame_time) + return + + def _build_key_body_ids_tensor(self, key_body_names): + if self.humanoid_type in ["smpl", "smplh", "smplx"] : + body_ids = [self._body_names.index(name) for name in key_body_names] + body_ids = to_torch(body_ids, device=self.device, dtype=torch.long) + + else: + env_ptr = self.envs[0] + actor_handle = self.humanoid_handles[0] + body_ids = [] + + for body_name in key_body_names: + body_id = self.gym.find_actor_rigid_body_handle(env_ptr, actor_handle, body_name) + assert (body_id != -1) + body_ids.append(body_id) + + body_ids = to_torch(body_ids, device=self.device, dtype=torch.long) + + return body_ids + + def _build_key_body_ids_orig_tensor(self, key_body_names): + body_ids = [self._body_names_orig.index(name) for name in key_body_names] + body_ids = to_torch(body_ids, device=self.device, dtype=torch.long) + return body_ids + + def _build_contact_body_ids_tensor(self, contact_body_names): + env_ptr = self.envs[0] + actor_handle = self.humanoid_handles[0] + body_ids = [] + + for body_name in contact_body_names: + body_id = self.gym.find_actor_rigid_body_handle(env_ptr, actor_handle, body_name) + assert (body_id != -1) + body_ids.append(body_id) + + body_ids = to_torch(body_ids, device=self.device, dtype=torch.long) + return body_ids + + def _action_to_pd_targets(self, action): + pd_tar = self._pd_action_offset + self._pd_action_scale * action + return pd_tar + + def _init_camera(self): + self.gym.refresh_actor_root_state_tensor(self.sim) + self._cam_prev_char_pos = self._humanoid_root_states[0, 0:3].cpu().numpy() + + cam_pos = gymapi.Vec3(self._cam_prev_char_pos[0], self._cam_prev_char_pos[1] - 3.0, 1.0) + cam_target = gymapi.Vec3(self._cam_prev_char_pos[0], self._cam_prev_char_pos[1], 1.0) + if self.viewer: + self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target) + return + + def _update_camera(self): + self.gym.refresh_actor_root_state_tensor(self.sim) + char_root_pos = self._humanoid_root_states[0, 0:3].cpu().numpy() + + cam_trans = self.gym.get_viewer_camera_transform(self.viewer, None) + + cam_pos = np.array([cam_trans.p.x, cam_trans.p.y, cam_trans.p.z]) + cam_delta = cam_pos - self._cam_prev_char_pos + + new_cam_target = gymapi.Vec3(char_root_pos[0], char_root_pos[1], 1.0) + new_cam_pos = gymapi.Vec3(char_root_pos[0] + cam_delta[0], char_root_pos[1] + cam_delta[1], cam_pos[2]) + + self.gym.set_camera_location(self.recorder_camera_handle, self.envs[0], new_cam_pos, new_cam_target) + + if self.viewer: + self.gym.viewer_camera_look_at(self.viewer, None, new_cam_pos, new_cam_target) + + self._cam_prev_char_pos[:] = char_root_pos + return + + def _update_debug_viz(self): + self.gym.clear_lines(self.viewer) + return + + +##################################################################### +###=========================jit functions=========================### +##################################################################### + + +@torch.jit.script +def dof_to_obs_smpl(pose): + # type: (Tensor) -> Tensor + joint_obs_size = 6 + B, jts = pose.shape + num_joints = int(jts / 3) + + joint_dof_obs = torch_utils.quat_to_tan_norm(torch_utils.exp_map_to_quat(pose.reshape(-1, 3))).reshape(B, -1) + assert ((num_joints * joint_obs_size) == joint_dof_obs.shape[1]) + + return joint_dof_obs + + +@torch.jit.script +def dof_to_obs(pose, dof_obs_size, dof_offsets): + # ZL this can be sped up for SMPL + # type: (Tensor, int, List[int]) -> Tensor + joint_obs_size = 6 + num_joints = len(dof_offsets) - 1 + + dof_obs_shape = pose.shape[:-1] + (dof_obs_size,) + dof_obs = torch.zeros(dof_obs_shape, device=pose.device) + dof_obs_offset = 0 + + for j in range(num_joints): + dof_offset = dof_offsets[j] + dof_size = dof_offsets[j + 1] - dof_offsets[j] + joint_pose = pose[:, dof_offset:(dof_offset + dof_size)] + + # assume this is a spherical joint + if (dof_size == 3): + joint_pose_q = torch_utils.exp_map_to_quat(joint_pose) + elif (dof_size == 1): + axis = torch.tensor([0.0, 1.0, 0.0], dtype=joint_pose.dtype, device=pose.device) + joint_pose_q = quat_from_angle_axis(joint_pose[..., 0], axis) + else: + joint_pose_q = None + assert (False), "Unsupported joint type" + + joint_dof_obs = torch_utils.quat_to_tan_norm(joint_pose_q) + dof_obs[:, (j * joint_obs_size):((j + 1) * joint_obs_size)] = joint_dof_obs + + assert ((num_joints * joint_obs_size) == dof_obs_size) + + return dof_obs + + +@torch.jit.script +def compute_humanoid_observations(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, local_root_obs, root_height_obs, dof_obs_size, dof_offsets): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, int, List[int]) -> Tensor + root_h = root_pos[:, 2:3] + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + + if (local_root_obs): + root_rot_obs = quat_mul(heading_rot, root_rot) + else: + root_rot_obs = root_rot + root_rot_obs = torch_utils.quat_to_tan_norm(root_rot_obs) + + if (not root_height_obs): + root_h_obs = torch.zeros_like(root_h) + else: + root_h_obs = root_h + + local_root_vel = torch_utils.my_quat_rotate(heading_rot, root_vel) + local_root_ang_vel = torch_utils.my_quat_rotate(heading_rot, root_ang_vel) + + root_pos_expand = root_pos.unsqueeze(-2) + local_key_body_pos = key_body_pos - root_pos_expand + + heading_rot_expand = heading_rot.unsqueeze(-2) + heading_rot_expand = heading_rot_expand.repeat((1, local_key_body_pos.shape[1], 1)) + flat_end_pos = local_key_body_pos.view(local_key_body_pos.shape[0] * local_key_body_pos.shape[1], local_key_body_pos.shape[2]) + flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], heading_rot_expand.shape[2]) + local_end_pos = torch_utils.my_quat_rotate(flat_heading_rot, flat_end_pos) + flat_local_key_pos = local_end_pos.view(local_key_body_pos.shape[0], local_key_body_pos.shape[1] * local_key_body_pos.shape[2]) + + dof_obs = dof_to_obs(dof_pos, dof_obs_size, dof_offsets) + + obs = torch.cat((root_h_obs, root_rot_obs, local_root_vel, local_root_ang_vel, dof_obs, dof_vel, flat_local_key_pos), dim=-1) + return obs + + +@torch.jit.script +def compute_humanoid_observations_max(body_pos, body_rot, body_vel, body_ang_vel, local_root_obs, root_height_obs): + # type: (Tensor, Tensor, Tensor, Tensor, bool, bool) -> Tensor + root_pos = body_pos[:, 0, :] + root_rot = body_rot[:, 0, :] + + root_h = root_pos[:, 2:3] + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + + if (not root_height_obs): + root_h_obs = torch.zeros_like(root_h) + else: + root_h_obs = root_h + + heading_rot_expand = heading_rot.unsqueeze(-2) + heading_rot_expand = heading_rot_expand.repeat((1, body_pos.shape[1], 1)) + flat_heading_rot = heading_rot_expand.reshape(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], heading_rot_expand.shape[2]) + + root_pos_expand = root_pos.unsqueeze(-2) + local_body_pos = body_pos - root_pos_expand + flat_local_body_pos = local_body_pos.reshape(local_body_pos.shape[0] * local_body_pos.shape[1], local_body_pos.shape[2]) + flat_local_body_pos = torch_utils.my_quat_rotate(flat_heading_rot, flat_local_body_pos) + local_body_pos = flat_local_body_pos.reshape(local_body_pos.shape[0], local_body_pos.shape[1] * local_body_pos.shape[2]) + local_body_pos = local_body_pos[..., 3:] # remove root pos + + flat_body_rot = body_rot.reshape(body_rot.shape[0] * body_rot.shape[1], body_rot.shape[2]) # global body rotation + flat_local_body_rot = quat_mul(flat_heading_rot, flat_body_rot) + flat_local_body_rot_obs = torch_utils.quat_to_tan_norm(flat_local_body_rot) + local_body_rot_obs = flat_local_body_rot_obs.reshape(body_rot.shape[0], body_rot.shape[1] * flat_local_body_rot_obs.shape[1]) + + if (local_root_obs): + root_rot_obs = torch_utils.quat_to_tan_norm(root_rot) + local_body_rot_obs[..., 0:6] = root_rot_obs + + flat_body_vel = body_vel.reshape(body_vel.shape[0] * body_vel.shape[1], body_vel.shape[2]) + flat_local_body_vel = torch_utils.my_quat_rotate(flat_heading_rot, flat_body_vel) + local_body_vel = flat_local_body_vel.reshape(body_vel.shape[0], body_vel.shape[1] * body_vel.shape[2]) + + flat_body_ang_vel = body_ang_vel.reshape(body_ang_vel.shape[0] * body_ang_vel.shape[1], body_ang_vel.shape[2]) + flat_local_body_ang_vel = torch_utils.my_quat_rotate(flat_heading_rot, flat_body_ang_vel) + local_body_ang_vel = flat_local_body_ang_vel.reshape(body_ang_vel.shape[0], body_ang_vel.shape[1] * body_ang_vel.shape[2]) + + obs = torch.cat((root_h_obs, local_body_pos, local_body_rot_obs, local_body_vel, local_body_ang_vel), dim=-1) + return obs + + +@torch.jit.script +def compute_humanoid_reward(obs_buf): + # type: (Tensor) -> Tensor + reward = torch.ones_like(obs_buf[:, 0]) + return reward + + +@torch.jit.script +def compute_humanoid_reset(reset_buf, progress_buf, contact_buf, contact_body_ids, rigid_body_pos, max_episode_length, enable_early_termination, termination_heights): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, float, bool, Tensor) -> Tuple[Tensor, Tensor] + terminated = torch.zeros_like(reset_buf) + + if (enable_early_termination): + masked_contact_buf = contact_buf.clone() + masked_contact_buf[:, contact_body_ids, :] = 0 + fall_contact = torch.any(torch.abs(masked_contact_buf) > 0.1, dim=-1) + fall_contact = torch.any(fall_contact, dim=-1) + + # if fall_contact.any(): + # print(masked_contact_buf[0, :, 0].nonzero()) + # import ipdb + # ipdb.set_trace() + + body_height = rigid_body_pos[..., 2] + fall_height = body_height < termination_heights + fall_height[:, contact_body_ids] = False + fall_height = torch.any(fall_height, dim=-1) + + ############################## Debug ############################## + # mujoco_joint_names = np.array(['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand']); print( mujoco_joint_names[masked_contact_buf[0, :, 0].nonzero().cpu().numpy()]) + ############################## Debug ############################## + + has_fallen = torch.logical_and(fall_contact, fall_height) + + # first timestep can sometimes still have nonzero contact forces + # so only check after first couple of steps + has_fallen *= (progress_buf > 1) + terminated = torch.where(has_fallen, torch.ones_like(reset_buf), terminated) + + reset = torch.where(progress_buf >= max_episode_length - 1, torch.ones_like(reset_buf), terminated) + # import ipdb + # ipdb.set_trace() + + return reset, terminated + + +##################################################################### +###=========================jit functions=========================### +##################################################################### + + +@torch.jit.script +def remove_base_rot(quat): + base_rot = quat_conjugate(torch.tensor([[0.5, 0.5, 0.5, 0.5]]).to(quat)) #SMPL + shape = quat.shape[0] + return quat_mul(quat, base_rot.repeat(shape, 1)) + + +@torch.jit.script +def compute_humanoid_observations_smpl(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, dof_obs_size, dof_offsets, smpl_params, local_root_obs, root_height_obs, upright, has_smpl_params): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, int, List[int], Tensor, bool, bool,bool, bool) -> Tensor + root_h = root_pos[:, 2:3] + if not upright: + root_rot = remove_base_rot(root_rot) + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + + if (local_root_obs): + root_rot_obs = quat_mul(heading_rot, root_rot) + else: + root_rot_obs = root_rot + root_rot_obs = torch_utils.quat_to_tan_norm(root_rot_obs) + + if (not root_height_obs): + root_h_obs = torch.zeros_like(root_h) + else: + root_h_obs = root_h + + local_root_vel = torch_utils.my_quat_rotate(heading_rot, root_vel) + local_root_ang_vel = torch_utils.my_quat_rotate(heading_rot, root_ang_vel) + + root_pos_expand = root_pos.unsqueeze(-2) + local_key_body_pos = key_body_pos - root_pos_expand + + heading_rot_expand = heading_rot.unsqueeze(-2) + heading_rot_expand = heading_rot_expand.repeat((1, local_key_body_pos.shape[1], 1)) + flat_end_pos = local_key_body_pos.view(local_key_body_pos.shape[0] * local_key_body_pos.shape[1], local_key_body_pos.shape[2]) + flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], heading_rot_expand.shape[2]) + local_end_pos = torch_utils.my_quat_rotate(flat_heading_rot, flat_end_pos) + flat_local_key_pos = local_end_pos.view(local_key_body_pos.shape[0], local_key_body_pos.shape[1] * local_key_body_pos.shape[2]) + + dof_obs = dof_to_obs(dof_pos, dof_obs_size, dof_offsets) + + obs_list = [] + if root_height_obs: + obs_list.append(root_h_obs) + obs_list += [ + root_rot_obs, + local_root_vel, + local_root_ang_vel, + dof_obs, + dof_vel, + flat_local_key_pos, + ] + if has_smpl_params: + obs_list.append(smpl_params) + obs = torch.cat(obs_list, dim=-1) + + return obs + + +@torch.jit.script +def compute_humanoid_observations_smpl_max(body_pos, body_rot, body_vel, body_ang_vel, smpl_params, limb_weight_params, local_root_obs, root_height_obs, upright, has_smpl_params, has_limb_weight_params): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, bool, bool, bool) -> Tensor + root_pos = body_pos[:, 0, :] + root_rot = body_rot[:, 0, :] + + root_h = root_pos[:, 2:3] + if not upright: + root_rot = remove_base_rot(root_rot) + heading_rot_inv = torch_utils.calc_heading_quat_inv(root_rot) + + if (not root_height_obs): + root_h_obs = torch.zeros_like(root_h) + else: + root_h_obs = root_h + + heading_rot_inv_expand = heading_rot_inv.unsqueeze(-2) + heading_rot_inv_expand = heading_rot_inv_expand.repeat((1, body_pos.shape[1], 1)) + flat_heading_rot_inv = heading_rot_inv_expand.reshape(heading_rot_inv_expand.shape[0] * heading_rot_inv_expand.shape[1], heading_rot_inv_expand.shape[2]) + + root_pos_expand = root_pos.unsqueeze(-2) + local_body_pos = body_pos - root_pos_expand + flat_local_body_pos = local_body_pos.reshape(local_body_pos.shape[0] * local_body_pos.shape[1], local_body_pos.shape[2]) + flat_local_body_pos = torch_utils.my_quat_rotate(flat_heading_rot_inv, flat_local_body_pos) + local_body_pos = flat_local_body_pos.reshape(local_body_pos.shape[0], local_body_pos.shape[1] * local_body_pos.shape[2]) + local_body_pos = local_body_pos[..., 3:] # remove root pos + + flat_body_rot = body_rot.reshape(body_rot.shape[0] * body_rot.shape[1], body_rot.shape[2]) # This is global rotation of the body + flat_local_body_rot = quat_mul(flat_heading_rot_inv, flat_body_rot) + flat_local_body_rot_obs = torch_utils.quat_to_tan_norm(flat_local_body_rot) + local_body_rot_obs = flat_local_body_rot_obs.reshape(body_rot.shape[0], body_rot.shape[1] * flat_local_body_rot_obs.shape[1]) + + if not (local_root_obs): + root_rot_obs = torch_utils.quat_to_tan_norm(root_rot) # If not local root obs, you override it. + local_body_rot_obs[..., 0:6] = root_rot_obs + + flat_body_vel = body_vel.reshape(body_vel.shape[0] * body_vel.shape[1], body_vel.shape[2]) + flat_local_body_vel = torch_utils.my_quat_rotate(flat_heading_rot_inv, flat_body_vel) + local_body_vel = flat_local_body_vel.reshape(body_vel.shape[0], body_vel.shape[1] * body_vel.shape[2]) + + flat_body_ang_vel = body_ang_vel.reshape(body_ang_vel.shape[0] * body_ang_vel.shape[1], body_ang_vel.shape[2]) + flat_local_body_ang_vel = torch_utils.my_quat_rotate(flat_heading_rot_inv, flat_body_ang_vel) + local_body_ang_vel = flat_local_body_ang_vel.reshape(body_ang_vel.shape[0], body_ang_vel.shape[1] * body_ang_vel.shape[2]) + + obs_list = [] + if root_height_obs: + obs_list.append(root_h_obs) + obs_list += [local_body_pos, local_body_rot_obs, local_body_vel, local_body_ang_vel] + + if has_smpl_params: + obs_list.append(smpl_params) + + if has_limb_weight_params: + obs_list.append(limb_weight_params) + + obs = torch.cat(obs_list, dim=-1) + return obs + + +@torch.jit.script +def compute_humanoid_observations_smpl_max_v2(body_pos, body_rot, body_vel, body_ang_vel, smpl_params, limb_weight_params, local_root_obs, root_height_obs, upright, has_smpl_params, has_limb_weight_params, time_steps): + ### V2 has time steps. + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, bool, bool, bool, int) -> Tensor + root_pos = body_pos[:, -1, 0, :] + root_rot = body_rot[:, -1, 0, :] + B, T, J, C = body_pos.shape + + if not upright: + root_rot = remove_base_rot(root_rot) + + root_h_obs = root_pos[:, 2:3] + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot = torch_utils.calc_heading_quat(root_rot) + # heading_rot_inv_expand = heading_inv_rot.unsqueeze(-2) + # heading_rot_inv_expand = heading_rot_inv_expand.repeat((1, body_pos.shape[1], 1)) + # flat_heading_rot_inv = heading_rot_inv_expand.reshape(heading_rot_inv_expand.shape[0] * heading_rot_inv_expand.shape[1], heading_rot_inv_expand.shape[2]) + + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, J, 1)).repeat_interleave(time_steps, 0).view(-1, 4) + heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, J, 1)).repeat_interleave(time_steps, 0) + + root_pos_expand = root_pos.unsqueeze(-2).unsqueeze(-2) + local_body_pos = body_pos - root_pos_expand + flat_local_body_pos = torch_utils.my_quat_rotate(heading_inv_rot_expand, local_body_pos.view(-1, 3)) + local_body_pos = flat_local_body_pos.reshape(B, time_steps, J * C) + local_body_pos = local_body_pos[..., 3:] # remove root pos + + flat_local_body_rot = quat_mul(heading_inv_rot_expand, body_rot.view(-1, 4)) + local_body_rot_obs = torch_utils.quat_to_tan_norm(flat_local_body_rot).view(B, time_steps, J * 6) + + if not (local_root_obs): + root_rot_obs = torch_utils.quat_to_tan_norm(body_rot[:, :, 0].view(-1, 4)) # If not local root obs, you override it. + local_body_rot_obs[..., 0:6] = root_rot_obs + + local_body_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand, body_vel.view(-1, 3)).view(B, time_steps, J * 3) + + local_body_ang_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand, body_ang_vel.view(-1, 3)).view(B, time_steps, J * 3) + + ##################### Compute_history ##################### + body_obs = torch.cat([local_body_pos, local_body_rot_obs, local_body_vel, local_body_ang_vel], dim = -1) + + obs_list = [] + if root_height_obs: + body_obs = torch.cat([body_pos[:, :, 0, 2:3], body_obs], dim = -1) + + + obs_list += [local_body_pos, local_body_rot_obs, local_body_vel, local_body_ang_vel] + + if has_smpl_params: + raise NotImplementedError + + if has_limb_weight_params: + raise NotImplementedError + + obs = body_obs.view(B, -1) + return obs + + + +@torch.jit.script +def compute_humanoid_observations_smpl_max_v3(body_pos, body_rot, body_vel, body_ang_vel, force_sensor_readings, smpl_params, limb_weight_params, local_root_obs, root_height_obs, upright, has_smpl_params, has_limb_weight_params): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, bool, bool, bool) -> Tensor + root_pos = body_pos[:, 0, :] + root_rot = body_rot[:, 0, :] + + root_h = root_pos[:, 2:3] + if not upright: + root_rot = remove_base_rot(root_rot) + heading_rot_inv = torch_utils.calc_heading_quat_inv(root_rot) + + if (not root_height_obs): + root_h_obs = torch.zeros_like(root_h) + else: + root_h_obs = root_h + + heading_rot_inv_expand = heading_rot_inv.unsqueeze(-2) + heading_rot_inv_expand = heading_rot_inv_expand.repeat((1, body_pos.shape[1], 1)) + flat_heading_rot_inv = heading_rot_inv_expand.reshape(heading_rot_inv_expand.shape[0] * heading_rot_inv_expand.shape[1], heading_rot_inv_expand.shape[2]) + + root_pos_expand = root_pos.unsqueeze(-2) + local_body_pos = body_pos - root_pos_expand + flat_local_body_pos = local_body_pos.reshape(local_body_pos.shape[0] * local_body_pos.shape[1], local_body_pos.shape[2]) + flat_local_body_pos = torch_utils.my_quat_rotate(flat_heading_rot_inv, flat_local_body_pos) + local_body_pos = flat_local_body_pos.reshape(local_body_pos.shape[0], local_body_pos.shape[1] * local_body_pos.shape[2]) + local_body_pos = local_body_pos[..., 3:] # remove root pos + + flat_body_rot = body_rot.reshape(body_rot.shape[0] * body_rot.shape[1], body_rot.shape[2]) # This is global rotation of the body + flat_local_body_rot = quat_mul(flat_heading_rot_inv, flat_body_rot) + flat_local_body_rot_obs = torch_utils.quat_to_tan_norm(flat_local_body_rot) + local_body_rot_obs = flat_local_body_rot_obs.reshape(body_rot.shape[0], body_rot.shape[1] * flat_local_body_rot_obs.shape[1]) + + if not (local_root_obs): + root_rot_obs = torch_utils.quat_to_tan_norm(root_rot) # If not local root obs, you override it. + local_body_rot_obs[..., 0:6] = root_rot_obs + + flat_body_vel = body_vel.reshape(body_vel.shape[0] * body_vel.shape[1], body_vel.shape[2]) + flat_local_body_vel = torch_utils.my_quat_rotate(flat_heading_rot_inv, flat_body_vel) + local_body_vel = flat_local_body_vel.reshape(body_vel.shape[0], body_vel.shape[1] * body_vel.shape[2]) + + flat_body_ang_vel = body_ang_vel.reshape(body_ang_vel.shape[0] * body_ang_vel.shape[1], body_ang_vel.shape[2]) + flat_local_body_ang_vel = torch_utils.my_quat_rotate(flat_heading_rot_inv, flat_body_ang_vel) + local_body_ang_vel = flat_local_body_ang_vel.reshape(body_ang_vel.shape[0], body_ang_vel.shape[1] * body_ang_vel.shape[2]) + + + obs_list = [] + if root_height_obs: + obs_list.append(root_h_obs) + obs_list += [local_body_pos, local_body_rot_obs, local_body_vel, local_body_ang_vel, force_sensor_readings] + + if has_smpl_params: + obs_list.append(smpl_params) + + if has_limb_weight_params: + obs_list.append(limb_weight_params) + + obs = torch.cat(obs_list, dim=-1) + return obs \ No newline at end of file diff --git a/phc/env/tasks/humanoid_amp.py b/phc/env/tasks/humanoid_amp.py new file mode 100644 index 0000000..e0ce495 --- /dev/null +++ b/phc/env/tasks/humanoid_amp.py @@ -0,0 +1,1009 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from ast import Try +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) +from enum import Enum +from matplotlib.pyplot import flag +import numpy as np +import torch +from torch import Tensor +from typing import Dict, Optional + +from isaacgym import gymapi +from isaacgym import gymtorch + +from phc.env.tasks.humanoid import Humanoid, dof_to_obs, remove_base_rot, dof_to_obs_smpl +from phc.env.util import gym_util +from phc.utils.motion_lib_smpl import MotionLibSMPL +from phc.utils.motion_lib_base import FixHeightMode +from easydict import EasyDict + +from isaacgym.torch_utils import * +from phc.utils import torch_utils + +from smpl_sim.smpllib.smpl_parser import ( + SMPL_Parser, + SMPLH_Parser, + SMPLX_Parser, +) +import gc +from phc.utils.flags import flags +from collections import OrderedDict + +HACK_MOTION_SYNC = False +# HACK_MOTION_SYNC = True +HACK_CONSISTENCY_TEST = False +HACK_OUTPUT_MOTION = False +HACK_OUTPUT_MOTION_ALL = False + + +class HumanoidAMP(Humanoid): + + class StateInit(Enum): + Default = 0 + Start = 1 + Random = 2 + Hybrid = 3 + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + if (HACK_MOTION_SYNC or HACK_CONSISTENCY_TEST): + control_freq_inv = cfg["env"]["controlFrequencyInv"] + self._motion_sync_dt = control_freq_inv * sim_params.dt + cfg["env"]["controlFrequencyInv"] = 1 + cfg["env"]["pd_control"] = False + + state_init = cfg["env"]["stateInit"] + + self._state_init = HumanoidAMP.StateInit[state_init] + self._hybrid_init_prob = cfg["env"]["hybridInitProb"] + self._num_amp_obs_steps = cfg["env"]["numAMPObsSteps"] + self._amp_root_height_obs = cfg["env"].get("ampRootHeightObs", cfg["env"].get("root_height_obs", True)) + + self._num_amp_obs_enc_steps = cfg["env"].get("numAMPEncObsSteps", self._num_amp_obs_steps) # Calm + + assert (self._num_amp_obs_steps >= 2) + + if ("enableHistObs" in cfg["env"]): + self._enable_hist_obs = cfg["env"]["enableHistObs"] + else: + self._enable_hist_obs = False + + self._reset_default_env_ids = [] + self._reset_ref_env_ids = [] + self._state_reset_happened = False + + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + + self._motion_start_times = torch.zeros(self.num_envs).to(self.device) + self._sampled_motion_ids = torch.zeros(self.num_envs).long().to(self.device) + motion_file = cfg['env']['motion_file'] + self._load_motion(motion_file) + + self._amp_obs_buf = torch.zeros((self.num_envs, self._num_amp_obs_steps, self._num_amp_obs_per_step), device=self.device, dtype=torch.float) + self._curr_amp_obs_buf = self._amp_obs_buf[:, 0] + self._hist_amp_obs_buf = self._amp_obs_buf[:, 1:] + + self._amp_obs_demo_buf = None + + data_dir = "data/smpl" + self.smpl_parser_n = SMPL_Parser(model_path=data_dir, gender="neutral").to(self.device) + self.smpl_parser_m = SMPL_Parser(model_path=data_dir, gender="male").to(self.device) + self.smpl_parser_f = SMPL_Parser(model_path=data_dir, gender="female").to(self.device) + + self.start = True # camera flag + self.ref_motion_cache = {} + + # ZL Hack + self._add_amp_input_noise = cfg["env"].get("add_amp_input_noise", False) + return + + ## Disabled. + # def get_self_obs_size(self): + # if self.obs_v == 2: + # return self._num_self_obs * self.past_track_steps + # else: + # return self._num_self_obs + + def _compute_observations(self, env_ids=None): + if env_ids is None: + env_ids = torch.arange(self.num_envs).to(self.device) + obs = self._compute_humanoid_obs(env_ids) + + + if self.obs_v == 2: + # Double sub will return a copy. + B, N = obs.shape + sums = self.obs_buf[env_ids, 0:self.past_track_steps].abs().sum(dim=1) + zeros = sums == 0 + nonzero = ~zeros + obs_slice = self.obs_buf[env_ids] + obs_slice[zeros] = torch.tile(obs[zeros], (1, self.past_track_steps)) + obs_slice[nonzero] = torch.cat([obs_slice[nonzero, N:], obs[nonzero]], dim=-1) + self.obs_buf[env_ids] = obs_slice + else: + self.obs_buf[env_ids] = obs + + return + + def resample_motions(self): + # self.gym.destroy_sim(self.sim) + # del self.sim + # if not self.headless: + # self.gym.destroy_viewer(self.viewer) + # self.create_sim() + # self.gym.prepare_sim(self.sim) + # self.create_viewer() + # self._setup_tensors() + print("Partial solution, only resample motions...") + self._motion_lib.load_motions(skeleton_trees=self.skeleton_trees, limb_weights=self.humanoid_limb_and_weights.cpu(), gender_betas=self.humanoid_shapes.cpu()) # For now, only need to sample motions since there are only 400 hmanoids + # self.reset() + # torch.cuda.empty_cache() + # gc.collect() + + def pre_physics_step(self, actions): + if (HACK_MOTION_SYNC or HACK_CONSISTENCY_TEST): + actions *= 0 + + super().pre_physics_step(actions) + return + + def get_task_obs_size_detail(self): + task_obs_detail = OrderedDict() + + + return task_obs_detail + + def post_physics_step(self): + super().post_physics_step() + + + if (HACK_MOTION_SYNC): + self._hack_motion_sync() + + if (HACK_OUTPUT_MOTION): + self._hack_output_motion() + + self._update_hist_amp_obs() # One step for the amp obs + + self._compute_amp_observations() + + amp_obs_flat = self._amp_obs_buf.view(-1, self.get_num_amp_obs()) + self.extras["amp_obs"] = amp_obs_flat ## ZL: hooks for adding amp_obs for trianing + return + + def get_num_amp_obs(self): + return self._num_amp_obs_steps * self._num_amp_obs_per_step + + def fetch_amp_obs_demo(self, num_samples): + # Creates the reference motion amp obs. For discrinminiator + + if (self._amp_obs_demo_buf is None): + self._build_amp_obs_demo_buf(num_samples) + else: + assert (self._amp_obs_demo_buf.shape[0] == num_samples) + + motion_ids = self._motion_lib.sample_motions(num_samples) + motion_times0 = self._sample_time(motion_ids) + amp_obs_demo = self.build_amp_obs_demo(motion_ids, motion_times0) + self._amp_obs_demo_buf[:] = amp_obs_demo.view(self._amp_obs_demo_buf.shape) + amp_obs_demo_flat = self._amp_obs_demo_buf.view(-1, self.get_num_amp_obs()) + + + return amp_obs_demo_flat + + def build_amp_obs_demo_steps(self, motion_ids, motion_times0, num_steps): + dt = self.dt + + motion_ids = torch.tile(motion_ids.unsqueeze(-1), [1, num_steps]) + motion_times = motion_times0.unsqueeze(-1) + time_steps = -dt * torch.arange(0, num_steps, device=self.device) + motion_times = torch.clip(motion_times + time_steps, min=0) + + motion_ids = motion_ids.view(-1) + motion_times = motion_times.view(-1) + motion_res = self._motion_lib.get_motion_state(motion_ids, motion_times) + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, smpl_params, limb_weights, pose_aa, rb_pos, rb_rot, body_vel, body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + key_pos = rb_pos[:, self._key_body_ids] + key_vel = body_vel[:, self._key_body_ids] + amp_obs_demo = self._compute_amp_observations_from_state(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_pos, key_vel, smpl_params, limb_weights, self.dof_subset, self._local_root_obs, self._amp_root_height_obs, self._has_dof_subset, self._has_shape_obs_disc, self._has_limb_weight_obs_disc, + self._has_upright_start) + return amp_obs_demo + + def build_amp_obs_demo(self, motion_ids, motion_times0): + # Compute observation for the motion starting point + dt = self.dt + motion_ids = torch.tile(motion_ids.unsqueeze(-1), [1, self._num_amp_obs_steps]) + + motion_times = motion_times0.unsqueeze(-1) + time_steps = -dt * torch.arange(0, self._num_amp_obs_steps, device=self.device) + motion_times = motion_times + time_steps + + motion_ids = motion_ids.view(-1) + motion_times = motion_times.view(-1) + + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + motion_res = self._get_state_from_motionlib_cache(motion_ids, motion_times) + + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, smpl_params, limb_weights, pose_aa, rb_pos, rb_rot, body_vel, body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + key_pos = rb_pos[:, self._key_body_ids] + key_vel = body_vel[:, self._key_body_ids] + amp_obs_demo = self._compute_amp_observations_from_state(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_pos, key_vel, smpl_params, limb_weights, self.dof_subset, self._local_root_obs, self._amp_root_height_obs, self._has_dof_subset, self._has_shape_obs_disc, self._has_limb_weight_obs_disc, + self._has_upright_start) + else: + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos = self._motion_lib.get_motion_state_amp(motion_ids, motion_times) + + amp_obs_demo = build_amp_observations(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_pos, self._local_root_obs, self._amp_root_height_obs, self._dof_obs_size, self._dof_offsets) + + if self._add_amp_input_noise: + amp_obs_demo = amp_obs_demo + torch.randn_like(amp_obs_demo) * 0.01 + + return amp_obs_demo + + def _build_amp_obs_demo_buf(self, num_samples): + self._amp_obs_demo_buf = torch.zeros((num_samples, self._num_amp_obs_steps, self._num_amp_obs_per_step), device=self.device, dtype=torch.float32) + return + + def _setup_character_props(self, key_bodies): + super()._setup_character_props(key_bodies) + # ZL: + + asset_file = self.cfg.robot.asset.assetFileName + num_key_bodies = len(key_bodies) + + if (asset_file == "mjcf/amp_humanoid.xml"): + self._num_amp_obs_per_step = 13 + self._dof_obs_size + 28 + 3 * num_key_bodies # [root_h, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos] + elif self.humanoid_type in ["smpl", "smplh", "smplx"]: + if self.amp_obs_v == 1: + self._num_amp_obs_per_step = 13 + self._dof_obs_size + len(self._dof_names) * 3 + 3 * num_key_bodies # [root_h, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos] + else: + self._num_amp_obs_per_step = 13 + self._dof_obs_size + len(self._dof_names) * 3 + 6 * num_key_bodies # [root_h, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, key_body_vel] + + if not self._amp_root_height_obs: + self._num_amp_obs_per_step -= 1 + + if self._has_dof_subset: + self._num_amp_obs_per_step -= (6 + 3) * int((len(self._dof_names) * 3 - len(self.dof_subset)) / 3) + + if self._has_shape_obs_disc: + self._num_amp_obs_per_step += 11 if (asset_file == "mjcf/smpl_humanoid.xml") else 12 + if self._has_limb_weight_obs_disc: + self._num_amp_obs_per_step += 10 + else: + print("Unsupported character config file: {s}".format(asset_file)) + assert (False) + + if (self._enable_hist_obs): + self._num_self_obs += self._num_amp_obs_steps * self._num_amp_obs_per_step + return + + def _load_motion(self, motion_file): + assert (self._dof_offsets[-1] == self.num_dof) + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + motion_lib_cfg = EasyDict({ + "motion_file": motion_file, + "device": torch.device("cpu"), + "fix_height": FixHeightMode.full_fix, + "min_length": -1, + "max_length": -1, + "im_eval": flags.im_eval, + "multi_thread": True , + "smpl_type": self.humanoid_type, + "randomrize_heading": True, + "device": self.device, + "min_length": self._min_motion_len, + }) + self._motion_lib = MotionLibSMPL(motion_lib_cfg=motion_lib_cfg) + self._motion_lib.load_motions(skeleton_trees=self.skeleton_trees, gender_betas=self.humanoid_shapes.cpu(), limb_weights=self.humanoid_limb_and_weights.cpu(), random_sample=not HACK_MOTION_SYNC) + else: + self._motion_lib = MotionLib(motion_file=motion_file, dof_body_ids=self._dof_body_ids, dof_offsets=self._dof_offsets, key_body_ids=self._key_body_ids.cpu().numpy(), device=self.device) + + return + + def _reset_envs(self, env_ids): + self._reset_default_env_ids = [] + self._reset_ref_env_ids = [] + if len(env_ids) > 0: + self._state_reset_happened = True + + super()._reset_envs(env_ids) + self._init_amp_obs(env_ids) + + return + + def _reset_actors(self, env_ids): + if (self._state_init == HumanoidAMP.StateInit.Default): + self._reset_default(env_ids) + elif (self._state_init == HumanoidAMP.StateInit.Start or self._state_init == HumanoidAMP.StateInit.Random): + self._reset_ref_state_init(env_ids) + elif (self._state_init == HumanoidAMP.StateInit.Hybrid): + self._reset_hybrid_state_init(env_ids) + else: + assert (False), "Unsupported state initialization strategy: {:s}".format(str(self._state_init)) + return + + def _reset_default(self, env_ids): + self._humanoid_root_states[env_ids] = self._initial_humanoid_root_states[env_ids] + self._dof_pos[env_ids] = self._initial_dof_pos[env_ids] + self._dof_vel[env_ids] = self._initial_dof_vel[env_ids] + self._reset_default_env_ids = env_ids + return + + def _sample_time(self, motion_ids): + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + return self._motion_lib.sample_time_interval(motion_ids) + else: + return self._motion_lib.sample_time(motion_ids) + + def _get_fixed_smpl_state_from_motionlib(self, motion_ids, motion_times, curr_gender_betas): + # Used for intialization. Not used for sampling. Only used for AMP, not imitation. + motion_res = self._get_state_from_motionlib_cache(motion_ids, motion_times) + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, _, pose_aa, rb_pos, rb_rot, body_vel, body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + with torch.no_grad(): + gender = curr_gender_betas[:, 0] + betas = curr_gender_betas[:, 1:] + B, _ = betas.shape + + genders_curr = gender == 2 + height_tolorance = 0.02 + if genders_curr.sum() > 0: + poses_curr = pose_aa[genders_curr] + root_pos_curr = root_pos[genders_curr] + betas_curr = betas[genders_curr] + vertices_curr, joints_curr = self.smpl_parser_f.get_joints_verts(poses_curr, betas_curr, root_pos_curr) + offset = joints_curr[:, 0] - root_pos[genders_curr] + diff_fix = ((vertices_curr - offset[:, None])[..., -1].min(dim=-1).values - height_tolorance) + root_pos[genders_curr, ..., -1] -= diff_fix + rb_pos[genders_curr, ..., -1] -= diff_fix[:, None] + + genders_curr = gender == 1 + if genders_curr.sum() > 0: + poses_curr = pose_aa[genders_curr] + root_pos_curr = root_pos[genders_curr] + betas_curr = betas[genders_curr] + vertices_curr, joints_curr = self.smpl_parser_m.get_joints_verts(poses_curr, betas_curr, root_pos_curr) + + offset = joints_curr[:, 0] - root_pos[genders_curr] + diff_fix = ((vertices_curr - offset[:, None])[..., -1].min(dim=-1).values - height_tolorance) + root_pos[genders_curr, ..., -1] -= diff_fix + rb_pos[genders_curr, ..., -1] -= diff_fix[:, None] + + genders_curr = gender == 0 + if genders_curr.sum() > 0: + poses_curr = pose_aa[genders_curr] + root_pos_curr = root_pos[genders_curr] + betas_curr = betas[genders_curr] + vertices_curr, joints_curr = self.smpl_parser_n.get_joints_verts(poses_curr, betas_curr, root_pos_curr) + + offset = joints_curr[:, 0] - root_pos[genders_curr] + diff_fix = ((vertices_curr - offset[:, None])[..., -1].min(dim=-1).values - height_tolorance) + root_pos[genders_curr, ..., -1] -= diff_fix + rb_pos[genders_curr, ..., -1] -= diff_fix[:, None] + + return root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel + + def _get_state_from_motionlib_cache(self, motion_ids, motion_times, offset=None): + ## Cache the motion + offset + if offset is None or not "motion_ids" in self.ref_motion_cache or self.ref_motion_cache['offset'] is None or len(self.ref_motion_cache['motion_ids']) != len(motion_ids) or len(self.ref_motion_cache['offset']) != len(offset) \ + or (self.ref_motion_cache['motion_ids'] - motion_ids).abs().sum() + (self.ref_motion_cache['motion_times'] - motion_times).abs().sum() + (self.ref_motion_cache['offset'] - offset).abs().sum() > 0 : + self.ref_motion_cache['motion_ids'] = motion_ids.clone() # need to clone; otherwise will be overriden + self.ref_motion_cache['motion_times'] = motion_times.clone() # need to clone; otherwise will be overriden + self.ref_motion_cache['offset'] = offset.clone() if not offset is None else None + else: + return self.ref_motion_cache + motion_res = self._motion_lib.get_motion_state(motion_ids, motion_times, offset=offset) + + self.ref_motion_cache.update(motion_res) + + return self.ref_motion_cache + + def _sample_ref_state(self, env_ids): + + num_envs = env_ids.shape[0] + motion_ids = self._motion_lib.sample_motions(num_envs) + + if (self._state_init == HumanoidAMP.StateInit.Random or self._state_init == HumanoidAMP.StateInit.Hybrid): + motion_times = self._sample_time(motion_ids) + elif (self._state_init == HumanoidAMP.StateInit.Start): + motion_times = torch.zeros(num_envs, device=self.device) + else: + assert (False), "Unsupported state initialization strategy: {:s}".format(str(self._state_init)) + + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + curr_gender_betas = self.humanoid_shapes[env_ids] + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel = self._get_fixed_smpl_state_from_motionlib(motion_ids, motion_times, curr_gender_betas) + else: + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos = self._motion_lib.get_motion_state_amp(motion_ids, motion_times) + rb_pos, rb_rot = None, None + + return motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel + + def _reset_ref_state_init(self, env_ids): + num_envs = env_ids.shape[0] + motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel = self._sample_ref_state(env_ids) + + # if flags.debug: + # print('raising for debug') + # root_pos[..., 2] += 0.5 + + # if flags.fixed: + # x_grid, y_grid = torch.meshgrid(torch.arange(64), torch.arange(64)) + # root_pos[:, 0], root_pos[:, 1] = x_grid.flatten()[env_ids] * 2, y_grid.flatten()[env_ids] * 2 + self._set_env_state(env_ids=env_ids, root_pos=root_pos, root_rot=root_rot, dof_pos=dof_pos, root_vel=root_vel, root_ang_vel=root_ang_vel, dof_vel=dof_vel, rigid_body_pos=rb_pos, rigid_body_rot=rb_rot, rigid_body_vel=body_vel, rigid_body_ang_vel=body_ang_vel) + + self._reset_ref_env_ids = env_ids + self._reset_ref_motion_ids = motion_ids + self._reset_ref_motion_times = motion_times + self._motion_start_times[env_ids] = motion_times + self._sampled_motion_ids[env_ids] = motion_ids + if flags.follow: + self.start = True ## Updating camera when reset + return + + def _reset_hybrid_state_init(self, env_ids): + num_envs = env_ids.shape[0] + ref_probs = to_torch(np.array([self._hybrid_init_prob] * num_envs), device=self.device) + ref_init_mask = torch.bernoulli(ref_probs) == 1.0 + + ref_reset_ids = env_ids[ref_init_mask] + + if (len(ref_reset_ids) > 0): + self._reset_ref_state_init(ref_reset_ids) + + default_reset_ids = env_ids[torch.logical_not(ref_init_mask)] + if (len(default_reset_ids) > 0): + self._reset_default(default_reset_ids) + + return + + def _compute_humanoid_obs(self, env_ids=None): + obs = super()._compute_humanoid_obs(env_ids) + + if (self._enable_hist_obs): + if (env_ids is None): + hist_obs = self._amp_obs_buf.view(-1, self.get_num_amp_obs()) + else: + hist_obs = self._amp_obs_buf[env_ids].view(-1, self.get_num_amp_obs()) + + obs = torch.cat([obs, hist_obs], dim=-1) + + return obs + + def _init_amp_obs(self, env_ids): + self._compute_amp_observations(env_ids) + + if (len(self._reset_default_env_ids) > 0): + self._init_amp_obs_default(self._reset_default_env_ids) + + if (len(self._reset_ref_env_ids) > 0): + self._init_amp_obs_ref(self._reset_ref_env_ids, self._reset_ref_motion_ids, self._reset_ref_motion_times) + + return + + def _init_amp_obs_default(self, env_ids): + curr_amp_obs = self._curr_amp_obs_buf[env_ids].unsqueeze(-2) + self._hist_amp_obs_buf[env_ids] = curr_amp_obs + return + + def _init_amp_obs_ref(self, env_ids, motion_ids, motion_times): + dt = self.dt + motion_ids = torch.tile(motion_ids.unsqueeze(-1), [1, self._num_amp_obs_steps - 1]) + motion_times = motion_times.unsqueeze(-1) + + time_steps = -dt * (torch.arange(0, self._num_amp_obs_steps - 1, device=self.device) + 1) + motion_times = motion_times + time_steps + + motion_ids = motion_ids.view(-1) + motion_times = motion_times.view(-1) + + if self.humanoid_type in ["smpl", "smplh", "smplx"] : + motion_res = self._get_state_from_motionlib_cache(motion_ids, motion_times) + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, smpl_params, limb_weights, pose_aa, rb_pos, rb_rot, body_vel, body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + key_pos = rb_pos[:, self._key_body_ids] + key_vel = body_vel[:, self._key_body_ids] + amp_obs_demo = self._compute_amp_observations_from_state(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_pos, key_vel, smpl_params, limb_weights, self.dof_subset, self._local_root_obs, self._amp_root_height_obs, self._has_dof_subset, self._has_shape_obs_disc, self._has_limb_weight_obs_disc, + self._has_upright_start) + + else: + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos \ + = self._motion_lib.get_motion_state_amp(motion_ids, motion_times) + amp_obs_demo = build_amp_observations(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_pos, self._local_root_obs, self._amp_root_height_obs, self._dof_obs_size, self._dof_offsets) + + self._hist_amp_obs_buf[env_ids] = amp_obs_demo.view(self._hist_amp_obs_buf[env_ids].shape) + return + + def _set_env_state( + self, + env_ids, + root_pos, + root_rot, + dof_pos, + root_vel, + root_ang_vel, + dof_vel, + rigid_body_pos=None, + rigid_body_rot=None, + rigid_body_vel=None, + rigid_body_ang_vel=None, + ): + self._humanoid_root_states[env_ids, 0:3] = root_pos + self._humanoid_root_states[env_ids, 3:7] = root_rot + self._humanoid_root_states[env_ids, 7:10] = root_vel + self._humanoid_root_states[env_ids, 10:13] = root_ang_vel + self._dof_pos[env_ids] = dof_pos + self._dof_vel[env_ids] = dof_vel + + if (not rigid_body_pos is None) and (not rigid_body_rot is None): + self._rigid_body_pos[env_ids] = rigid_body_pos + self._rigid_body_rot[env_ids] = rigid_body_rot + self._rigid_body_vel[env_ids] = rigid_body_vel + self._rigid_body_ang_vel[env_ids] = rigid_body_ang_vel + + self._reset_rb_pos = self._rigid_body_pos[env_ids].clone() + self._reset_rb_rot = self._rigid_body_rot[env_ids].clone() + self._reset_rb_vel = self._rigid_body_vel[env_ids].clone() + self._reset_rb_ang_vel = self._rigid_body_ang_vel[env_ids].clone() + + return + + def _refresh_sim_tensors(self): + + self.gym.refresh_dof_state_tensor(self.sim) + self.gym.refresh_actor_root_state_tensor(self.sim) + self.gym.refresh_rigid_body_state_tensor(self.sim) + + if self._state_reset_happened and "_reset_rb_pos" in self.__dict__: + # ZL: Hack to get rigidbody pos and rot to be the correct values. Needs to be called after _set_env_state + # Also needs to be after refresh_rigid_body_state_tensor + env_ids = self._reset_ref_env_ids + if len(env_ids) > 0: + self._rigid_body_pos[env_ids] = self._reset_rb_pos + self._rigid_body_rot[env_ids] = self._reset_rb_rot + self._rigid_body_vel[env_ids] = self._reset_rb_vel + self._rigid_body_ang_vel[env_ids] = self._reset_rb_ang_vel + self._state_reset_happened = False + + self.gym.refresh_force_sensor_tensor(self.sim) + self.gym.refresh_dof_force_tensor(self.sim) + self.gym.refresh_net_contact_force_tensor(self.sim) + + return + + def _update_hist_amp_obs(self, env_ids=None): + if (env_ids is None): + try: + self._hist_amp_obs_buf[:] = self._amp_obs_buf[:, 0:(self._num_amp_obs_steps - 1)] + except: + self._hist_amp_obs_buf[:] = self._amp_obs_buf[:, 0:(self._num_amp_obs_steps - 1)].clone() + else: + self._hist_amp_obs_buf[env_ids] = self._amp_obs_buf[env_ids, 0:(self._num_amp_obs_steps - 1)] + return + + def _compute_amp_observations(self, env_ids=None): + key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :] + key_body_vel = self._rigid_body_vel[:, self._key_body_ids, :] + + if self.humanoid_type in ["smpl", "smplh", "smplx"] and self.dof_subset is None: + # ZL hack + self._dof_pos[:, 9:12], self._dof_pos[:, 21:24], self._dof_pos[:, 51:54], self._dof_pos[:, 66:69] = 0, 0, 0, 0 + self._dof_vel[:, 9:12], self._dof_vel[:, 21:24], self._dof_vel[:, 51:54], self._dof_vel[:, 66:69] = 0, 0, 0, 0 + + # if (key_body_pos[..., 2].mean(dim = -1) > 2).sum(): + # self.humanoid_shapes[torch.where((key_body_pos[.. + # ., 2].mean(dim = -1) > 2))].cpu().numpy() + # import ipdb; ipdb.set_trace() + # print('bugg') + # if flags.debug: + # print(torch.topk(self._dof_pos.abs().sum(dim=-1), 5)) + + if (env_ids is None): + if self.humanoid_type in ["smpl", "smplh", "smplx"] : + self._curr_amp_obs_buf[:] = self._compute_amp_observations_from_state(self._rigid_body_pos[:, 0, :], self._rigid_body_rot[:, 0, :], self._rigid_body_vel[:, 0, :], self._rigid_body_ang_vel[:, 0, :], self._dof_pos, self._dof_vel, key_body_pos, key_body_vel, self.humanoid_shapes, self.humanoid_limb_and_weights, + self.dof_subset, self._local_root_obs, self._amp_root_height_obs, self._has_dof_subset, self._has_shape_obs_disc, self._has_limb_weight_obs_disc, self._has_upright_start) + + else: + self._curr_amp_obs_buf[:] = build_amp_observations(self._rigid_body_pos[:, 0, :], self._rigid_body_rot[:, 0, :], self._rigid_body_vel[:, 0, :], self._rigid_body_ang_vel[:, 0, :], self._dof_pos, self._dof_vel, key_body_pos, self._local_root_obs, self._amp_root_height_obs, + self._dof_obs_size, self._dof_offsets) + else: + if len(env_ids) == 0: + return + if self.humanoid_type in ["smpl", "smplh", "smplx"] : + self._curr_amp_obs_buf[env_ids] = self._compute_amp_observations_from_state(self._rigid_body_pos[env_ids][:, 0, :], self._rigid_body_rot[env_ids][:, 0, :], self._rigid_body_vel[env_ids][:, 0, :], self._rigid_body_ang_vel[env_ids][:, 0, :], self._dof_pos[env_ids], self._dof_vel[env_ids], + key_body_pos[env_ids], key_body_vel[env_ids], self.humanoid_shapes[env_ids], self.humanoid_limb_and_weights[env_ids], self.dof_subset, self._local_root_obs, self._amp_root_height_obs, self._has_dof_subset, self._has_shape_obs_disc, + self._has_limb_weight_obs_disc, self._has_upright_start) + else: + self._curr_amp_obs_buf[env_ids] = build_amp_observations(self._rigid_body_pos[env_ids][:, 0, :], self._rigid_body_rot[env_ids][:, 0, :], self._rigid_body_vel[env_ids][:, 0, :], self._rigid_body_ang_vel[env_ids][:, 0, :], self._dof_pos[env_ids], self._dof_vel[env_ids], + key_body_pos[env_ids], self._local_root_obs, self._amp_root_height_obs, self._dof_obs_size, self._dof_offsets) + return + + def _compute_amp_observations_from_state(self, root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, key_body_vels, smpl_params, limb_weight_params, dof_subset, local_root_obs, root_height_obs, has_dof_subset, has_shape_obs_disc, has_limb_weight_obs, upright): + if self.amp_obs_v == 1: + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + smpl_params = smpl_params[:, :-6] + return build_amp_observations_smpl(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, smpl_params, limb_weight_params, dof_subset, local_root_obs, root_height_obs, has_dof_subset, has_shape_obs_disc, has_limb_weight_obs, upright) + elif self.amp_obs_v == 2: + return build_amp_observations_smpl_v2(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, key_body_vels, smpl_params, limb_weight_params, dof_subset, local_root_obs, root_height_obs, has_dof_subset, has_shape_obs_disc, has_limb_weight_obs, upright) + + def _hack_motion_sync(self): + + if (not hasattr(self, "_hack_motion_time")): + self._hack_motion_time = 0.0 + + num_motions = self._motion_lib.num_motions() + motion_ids = np.arange(self.num_envs, dtype=np.int) + motion_ids = torch.from_numpy(np.mod(motion_ids, num_motions)) + # motion_ids[:] = 2 + motion_times = torch.tensor([self._hack_motion_time] * self.num_envs, dtype=torch.float32, device=self.device) + if self.humanoid_type in ["smpl", "smplh", "smplx"] : + motion_res = self._get_state_from_motionlib_cache(motion_ids, motion_times) + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, smpl_params, limb_weights, pose_aa, rb_pos, rb_rot, body_vel, body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + # betas = self.humanoid_shapes[0:1, 1:] # ZL Hack before real body variation kicks in + # vertices, joints = self.smpl_parser_n.get_joints_verts( + # torch.cat([ + # torch_utils.quat_to_exp_map(root_rot).to(dof_pos), dof_pos + # ], + # dim=-1), betas, root_pos) + # offset = joints[:, 0] - root_pos + # root_pos[...,-1] -= (vertices - offset[:, None])[..., -1].min(dim=-1).values + # root_pos[...,-1] += 0.03 # ALways slightly above the ground to avoid issue + + else: + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos \ + = self._motion_lib.get_motion_state_amp(motion_ids, motion_times) + rb_pos, rb_rot = None, None + + env_ids = torch.arange(self.num_envs, dtype=torch.long, device=self.device) + + self._set_env_state(env_ids=env_ids, root_pos=root_pos, root_rot=root_rot, dof_pos=dof_pos, root_vel=root_vel, root_ang_vel=root_ang_vel, dof_vel=dof_vel, rigid_body_pos=rb_pos, rigid_body_rot=rb_rot, rigid_body_vel=body_vel, rigid_body_ang_vel=body_ang_vel) + + self._reset_env_tensors(env_ids) + motion_dur = self._motion_lib._motion_lengths[0] + self._hack_motion_time = np.fmod(self._hack_motion_time + self._motion_sync_dt, motion_dur.cpu().numpy()) + + return + + def _update_camera(self): + self.gym.refresh_actor_root_state_tensor(self.sim) + char_root_pos = self._humanoid_root_states[self.viewing_env_idx, 0:3].cpu().numpy() + + if self.viewer: + cam_trans = self.gym.get_viewer_camera_transform(self.viewer, None) + cam_pos = np.array([cam_trans.p.x, cam_trans.p.y, cam_trans.p.z]) + else: + cam_pos = np.array([char_root_pos[0] + 2.5, char_root_pos[1] + 2.5, char_root_pos[2]]) + + cam_delta = cam_pos - self._cam_prev_char_pos + + new_cam_target = gymapi.Vec3(char_root_pos[0], char_root_pos[1], char_root_pos[2]) + # if np.abs(cam_pos[2] - char_root_pos[2]) > 5: + cam_pos[2] = char_root_pos[2] + 0.5 + new_cam_pos = gymapi.Vec3(char_root_pos[0] + cam_delta[0], char_root_pos[1] + cam_delta[1], cam_pos[2]) + + self.gym.set_camera_location(self.recorder_camera_handle, self.envs[self.viewing_env_idx], new_cam_pos, new_cam_target) + + if flags.follow: + self.start = True + else: + self.start = False + + if self.start: + self.gym.viewer_camera_look_at(self.viewer, None, new_cam_pos, new_cam_target) + + self._cam_prev_char_pos[:] = char_root_pos + return + + def _hack_consistency_test(self): + if (not hasattr(self, "_hack_motion_time")): + self._hack_motion_time = 0.0 + + motion_ids = np.array([0] * self.num_envs, dtype=np.int) + motion_times = np.array([self._hack_motion_time] * self.num_envs) + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos \ + = self._motion_lib.get_motion_state_amp(motion_ids, motion_times) + + env_ids = torch.arange(self.num_envs, dtype=torch.long, device=self.device) + self._set_env_state(env_ids=env_ids, root_pos=root_pos, root_rot=root_rot, dof_pos=dof_pos, root_vel=root_vel, root_ang_vel=root_ang_vel, dof_vel=dof_vel) + + self._reset_env_tensors(env_ids) + + motion_dur = self._motion_lib._motion_lengths[0] + self._hack_motion_time = np.fmod(self._hack_motion_time + self.dt, motion_dur) + + self._refresh_sim_tensors() + + sim_key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :] + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + print("ZL NOT FIXED YET") + sim_amp_obs = build_amp_observations_smpl(self._rigid_body_pos[:, 0, :], self._rigid_body_rot[:, 0, :], self._rigid_body_vel[:, 0, :], self._rigid_body_ang_vel[:, 0, :], self._dof_pos, self._dof_vel, sim_key_body_pos, self._local_root_obs, self._amp_root_height_obs, self._dof_offsets) + + ref_amp_obs = build_amp_observations_smpl(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_pos, self._local_root_obs, self._amp_root_height_obs, self._dof_offsets) + else: + sim_amp_obs = build_amp_observations(self._rigid_body_pos[:, 0, :], self._rigid_body_rot[:, 0, :], self._rigid_body_vel[:, 0, :], self._rigid_body_ang_vel[:, 0, :], self._dof_pos, self._dof_vel, sim_key_body_pos, self._local_root_obs, self._amp_root_height_obs, self._dof_obs_size, + self._dof_offsets) + + ref_amp_obs = build_amp_observations(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_pos, self._local_root_obs, self._amp_root_height_obs, self._dof_obs_size, self._dof_offsets) + + obs_diff = sim_amp_obs - ref_amp_obs + obs_diff = torch.abs(obs_diff) + obs_err = torch.max(obs_diff, dim=0) + + return + + def _hack_output_motion(self): + fps = 1.0 / self.dt + from poselib.poselib.skeleton.skeleton3d import SkeletonMotion, SkeletonState + from poselib.poselib.visualization.common import plot_skeleton_motion_interactive + + if (not hasattr(self, '_output_motion_root_pos')): + self._output_motion_root_pos = [] + self._output_motion_global_rot = [] + + root_pos = self._humanoid_root_states[..., 0:3].cpu().numpy() + self._output_motion_root_pos.append(root_pos) + + body_rot = self._rigid_body_rot.cpu().numpy() + rot_mask = body_rot[..., -1] < 0 + body_rot[rot_mask] = -body_rot[rot_mask] + self._output_motion_global_rot.append(body_rot) + + reset = self.reset_buf[0].cpu().numpy() == 1 + + if (reset and len(self._output_motion_root_pos) > 1): + output_root_pos = np.array(self._output_motion_root_pos) + output_body_rot = np.array(self._output_motion_global_rot) + output_root_pos = to_torch(output_root_pos, device='cpu') + output_body_rot = to_torch(output_body_rot, device='cpu') + + skeleton_tree = self._motion_lib._motions[0].skeleton_tree + + if (HACK_OUTPUT_MOTION_ALL): + num_envs = self.num_envs + else: + num_envs = 1 + + for i in range(num_envs): + curr_body_rot = output_body_rot[:, i, :] + curr_root_pos = output_root_pos[:, i, :] + sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, curr_body_rot, curr_root_pos, is_local=False) + sk_motion = SkeletonMotion.from_skeleton_state(sk_state, fps=fps) + + output_file = 'output/record_char_motion{:04d}.npy'.format(i) + sk_motion.to_file(output_file) + + #plot_skeleton_motion_interactive(sk_motion) + + self._output_motion_root_pos = [] + self._output_motion_global_rot = [] + + return + + def get_num_enc_amp_obs(self): + return self._num_amp_obs_enc_steps * self._num_amp_obs_per_step + + def fetch_amp_obs_demo_enc_pair(self, num_samples): + motion_ids = self._motion_lib.sample_motions(num_samples) + + # since negative times are added to these values in build_amp_obs_demo, + # we shift them into the range [0 + truncate_time, end of clip] + enc_window_size = self.dt * (self._num_amp_obs_enc_steps - 1) + + enc_motion_times = self._motion_lib.sample_time(motion_ids, truncate_time=enc_window_size) + # make sure not to add more than motion clip length, negative amp_obs will show zero index amp_obs instead + enc_motion_times += torch.clip(self._motion_lib._motion_lengths[motion_ids], max=enc_window_size) + + # sub-window-size is for the amp_obs contained within the enc-amp-obs. make sure we sample only within the valid portion of the motion + sub_window_size = torch.clip(self._motion_lib._motion_lengths[motion_ids], max=enc_window_size) - self.dt * self._num_amp_obs_steps + motion_times = enc_motion_times - torch.rand(enc_motion_times.shape, device=self.device) * sub_window_size + enc_amp_obs_demo = self.build_amp_obs_demo_steps(motion_ids, enc_motion_times, self._num_amp_obs_enc_steps).view(-1, self._num_amp_obs_enc_steps, self._num_amp_obs_per_step) + amp_obs_demo = self.build_amp_obs_demo_steps(motion_ids, motion_times, self._num_amp_obs_steps).view(-1, self._num_amp_obs_steps, self._num_amp_obs_per_step) + + enc_amp_obs_demo_flat = enc_amp_obs_demo.to(self.device).view(-1, self.get_num_enc_amp_obs()) + amp_obs_demo_flat = amp_obs_demo.to(self.device).view(-1, self.get_num_amp_obs()) + + return motion_ids, enc_motion_times, enc_amp_obs_demo_flat, motion_times, amp_obs_demo_flat + + def fetch_amp_obs_demo_pair(self, num_samples): + motion_ids = self._motion_lib.sample_motions(num_samples) + cat_motion_ids = torch.cat((motion_ids, motion_ids), dim=0) + + # since negative times are added to these values in build_amp_obs_demo, + # we shift them into the range [0 + truncate_time, end of clip] + enc_window_size = self.dt * (self._num_amp_obs_enc_steps - 1) + + motion_times0 = self._motion_lib.sample_time(motion_ids, truncate_time=enc_window_size) + motion_times0 += torch.clip(self._motion_lib._motion_lengths[motion_ids], max=enc_window_size) + + motion_times1 = motion_times0 + torch.rand(motion_times0.shape, device=self._motion_lib._device) * 0.5 + motion_times1 = torch.min(motion_times1, self._motion_lib._motion_lengths[motion_ids]) + + motion_times = torch.cat((motion_times0, motion_times1), dim=0) + + amp_obs_demo = self.build_amp_obs_demo_steps(cat_motion_ids, motion_times, self._num_amp_obs_enc_steps).view(-1, self._num_amp_obs_enc_steps, self._num_amp_obs_per_step) + amp_obs_demo0, amp_obs_demo1 = torch.split(amp_obs_demo, num_samples) + + amp_obs_demo0_flat = amp_obs_demo0.to(self.device).view(-1, self.get_num_enc_amp_obs()) + + amp_obs_demo1_flat = amp_obs_demo1.to(self.device).view(-1, self.get_num_enc_amp_obs()) + + return motion_ids, motion_times0, amp_obs_demo0_flat, motion_times1, amp_obs_demo1_flat + + ################## Calm stuff. Patchy I konw... ################## + + +##################################################################### +###=========================jit functions=========================### +##################################################################### +@torch.jit.script +def build_amp_observations(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, local_root_obs, root_height_obs, dof_obs_size, dof_offsets): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, int, List[int]) -> Tensor + root_h = root_pos[:, 2:3] + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + + if (local_root_obs): + root_rot_obs = quat_mul(heading_rot, root_rot) + else: + root_rot_obs = root_rot + root_rot_obs = torch_utils.quat_to_tan_norm(root_rot_obs) + + if (not root_height_obs): + root_h_obs = torch.zeros_like(root_h) + else: + root_h_obs = root_h + + local_root_vel = torch_utils.my_quat_rotate(heading_rot, root_vel) + local_root_ang_vel = torch_utils.my_quat_rotate(heading_rot, root_ang_vel) + + root_pos_expand = root_pos.unsqueeze(-2) + local_key_body_pos = key_body_pos - root_pos_expand + + heading_rot_expand = heading_rot.unsqueeze(-2) + heading_rot_expand = heading_rot_expand.repeat((1, local_key_body_pos.shape[1], 1)) + flat_end_pos = local_key_body_pos.view(local_key_body_pos.shape[0] * local_key_body_pos.shape[1], local_key_body_pos.shape[2]) + flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], heading_rot_expand.shape[2]) + local_end_pos = torch_utils.my_quat_rotate(flat_heading_rot, flat_end_pos) + flat_local_key_pos = local_end_pos.view(local_key_body_pos.shape[0], local_key_body_pos.shape[1] * local_key_body_pos.shape[2]) + + dof_obs = dof_to_obs(dof_pos, dof_obs_size, dof_offsets) + obs = torch.cat((root_h_obs, root_rot_obs, local_root_vel, local_root_ang_vel, dof_obs, dof_vel, flat_local_key_pos), dim=-1) + return obs + + +@torch.jit.script +def build_amp_observations_smpl(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, shape_params, limb_weight_params, dof_subset, local_root_obs, root_height_obs, has_dof_subset, has_shape_obs_disc, has_limb_weight_obs, upright): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, bool, bool, bool, bool) -> Tensor + B, N = root_pos.shape + root_h = root_pos[:, 2:3] + if not upright: + root_rot = remove_base_rot(root_rot) + heading_rot_inv = torch_utils.calc_heading_quat_inv(root_rot) + + if (local_root_obs): + root_rot_obs = quat_mul(heading_rot_inv, root_rot) + else: + root_rot_obs = root_rot + + root_rot_obs = torch_utils.quat_to_tan_norm(root_rot_obs) + + local_root_vel = torch_utils.my_quat_rotate(heading_rot_inv, root_vel) + local_root_ang_vel = torch_utils.my_quat_rotate(heading_rot_inv, root_ang_vel) + + root_pos_expand = root_pos.unsqueeze(-2) + local_key_body_pos = key_body_pos - root_pos_expand + + heading_rot_expand = heading_rot_inv.unsqueeze(-2) + heading_rot_expand = heading_rot_expand.repeat((1, local_key_body_pos.shape[1], 1)) + flat_end_pos = local_key_body_pos.view(local_key_body_pos.shape[0] * local_key_body_pos.shape[1], local_key_body_pos.shape[2]) + flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], heading_rot_expand.shape[2]) + local_end_pos = torch_utils.my_quat_rotate(flat_heading_rot, flat_end_pos) + flat_local_key_pos = local_end_pos.view(local_key_body_pos.shape[0], local_key_body_pos.shape[1] * local_key_body_pos.shape[2]) + + if has_dof_subset: + dof_vel = dof_vel[:, dof_subset] + dof_pos = dof_pos[:, dof_subset] + + dof_obs = dof_to_obs_smpl(dof_pos) + obs_list = [] + if root_height_obs: + obs_list.append(root_h) + obs_list += [root_rot_obs, local_root_vel, local_root_ang_vel, dof_obs, dof_vel, flat_local_key_pos] + # 1? + 6 + 3 + 3 + 114 + 57 + 12 + if has_shape_obs_disc: + obs_list.append(shape_params) + if has_limb_weight_obs: + obs_list.append(limb_weight_params) + obs = torch.cat(obs_list, dim=-1) + + return obs + + +@torch.jit.script +def build_amp_observations_smpl_v2(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, key_body_vel, shape_params, limb_weight_params, dof_subset, local_root_obs, root_height_obs, has_dof_subset, has_shape_obs_disc, has_limb_weight_obs, upright): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, bool, bool, bool, bool) -> Tensor + B, N = root_pos.shape + root_h = root_pos[:, 2:3] + if not upright: + root_rot = remove_base_rot(root_rot) + heading_rot_inv = torch_utils.calc_heading_quat_inv(root_rot) + + if (local_root_obs): + root_rot_obs = quat_mul(heading_rot_inv, root_rot) + else: + root_rot_obs = root_rot + + root_rot_obs = torch_utils.quat_to_tan_norm(root_rot_obs) + + local_root_vel = torch_utils.my_quat_rotate(heading_rot_inv, root_vel) + local_root_ang_vel = torch_utils.my_quat_rotate(heading_rot_inv, root_ang_vel) + + root_pos_expand = root_pos.unsqueeze(-2) + local_key_body_pos = key_body_pos - root_pos_expand + + heading_rot_expand = heading_rot_inv.unsqueeze(-2) + heading_rot_expand = heading_rot_expand.repeat((1, local_key_body_pos.shape[1], 1)) + flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], heading_rot_expand.shape[2]) + local_end_pos = torch_utils.my_quat_rotate(flat_heading_rot, local_key_body_pos.view(-1, 3)).view(local_key_body_pos.shape[0], local_key_body_pos.shape[1] * local_key_body_pos.shape[2]) + + local_vel = torch_utils.my_quat_rotate(flat_heading_rot, key_body_vel.view(-1, 3)).view(key_body_vel.shape[0], key_body_vel.shape[1] * key_body_vel.shape[2]) + + if has_dof_subset: + dof_vel = dof_vel[:, dof_subset] + dof_pos = dof_pos[:, dof_subset] + + dof_obs = dof_to_obs_smpl(dof_pos) + obs_list = [] + if root_height_obs: + obs_list.append(root_h) + obs_list += [root_rot_obs, local_root_vel, local_root_ang_vel, dof_obs, dof_vel, local_end_pos, local_vel] + # 1 + 6 + 3 + 3 + 114 + 57 + 12 + if has_shape_obs_disc: + obs_list.append(shape_params) + if has_limb_weight_obs: + obs_list.append(limb_weight_params) + obs = torch.cat(obs_list, dim=-1) + + return obs diff --git a/phc/env/tasks/humanoid_amp_getup.py b/phc/env/tasks/humanoid_amp_getup.py new file mode 100644 index 0000000..07bafaa --- /dev/null +++ b/phc/env/tasks/humanoid_amp_getup.py @@ -0,0 +1,170 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import torch + +from isaacgym import gymapi +from isaacgym import gymtorch + +from env.util import gym_util +from phc.env.tasks.humanoid_amp import HumanoidAMP +from isaacgym.torch_utils import * + +from phc.utils import torch_utils + + +class HumanoidAMPGetup(HumanoidAMP): + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + + self._recovery_episode_prob = cfg["env"]["recoveryEpisodeProb"] + self._recovery_steps = cfg["env"]["recoverySteps"] + self._fall_init_prob = cfg["env"]["fallInitProb"] + + self._reset_fall_env_ids = [] + + super().__init__(cfg=cfg, + sim_params=sim_params, + physics_engine=physics_engine, + device_type=device_type, + device_id=device_id, + headless=headless) + + self._recovery_counter = torch.zeros(self.num_envs, device=self.device, dtype=torch.int) + + self._generate_fall_states() + + return + + + def pre_physics_step(self, actions): + super().pre_physics_step(actions) + + self._update_recovery_count() + return + + def _generate_fall_states(self): + max_steps = 150 + + env_ids = to_torch(np.arange(self.num_envs), device=self.device, dtype=torch.long) + root_states = self._initial_humanoid_root_states[env_ids].clone() + root_states[..., 3:7] = torch.randn_like(root_states[..., 3:7]) + root_states[..., 3:7] = torch.nn.functional.normalize(root_states[..., 3:7], dim=-1) + self._humanoid_root_states[env_ids] = root_states + + env_ids_int32 = self._humanoid_actor_ids[env_ids] + self.gym.set_actor_root_state_tensor_indexed(self.sim, + gymtorch.unwrap_tensor(self._root_states), + gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + self.gym.set_dof_state_tensor_indexed(self.sim, + gymtorch.unwrap_tensor(self._dof_state), + gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + + + rand_actions = np.random.uniform(-0.5, 0.5, size=[self.num_envs, self.get_action_size()]) + rand_actions = to_torch(rand_actions, device=self.device) + self.pre_physics_step(rand_actions) + + # step physics and render each frame + for i in range(max_steps): + self.render() + self.gym.simulate(self.sim) + + self._refresh_sim_tensors() + + self._fall_root_states = self._humanoid_root_states.clone() + self._fall_root_states[:, 7:13] = 0 + self._fall_dof_pos = self._dof_pos.clone() + self._fall_dof_vel = torch.zeros_like(self._dof_vel, device=self.device, dtype=torch.float) + + return + + def _reset_actors(self, env_ids): + num_envs = env_ids.shape[0] + recovery_probs = to_torch(np.array([self._recovery_episode_prob] * num_envs), device=self.device) + recovery_mask = torch.bernoulli(recovery_probs) == 1.0 + terminated_mask = (self._terminate_buf[env_ids] == 1) + recovery_mask = torch.logical_and(recovery_mask, terminated_mask) + + recovery_ids = env_ids[recovery_mask] + if (len(recovery_ids) > 0): + self._reset_recovery_episode(recovery_ids) + + + nonrecovery_ids = env_ids[torch.logical_not(recovery_mask)] + fall_probs = to_torch(np.array([self._fall_init_prob] * nonrecovery_ids.shape[0]), device=self.device) + fall_mask = torch.bernoulli(fall_probs) == 1.0 + fall_ids = nonrecovery_ids[fall_mask] + if (len(fall_ids) > 0): + self._reset_fall_episode(fall_ids) + + + nonfall_ids = nonrecovery_ids[torch.logical_not(fall_mask)] + if (len(nonfall_ids) > 0): + super()._reset_actors(nonfall_ids) + self._recovery_counter[nonfall_ids] = 0 + + return + + def _reset_recovery_episode(self, env_ids): + self._recovery_counter[env_ids] = self._recovery_steps + return + + def _reset_fall_episode(self, env_ids): + fall_state_ids = torch.randint_like(env_ids, low=0, high=self._fall_root_states.shape[0]) + self._humanoid_root_states[env_ids] = self._fall_root_states[fall_state_ids] + self._dof_pos[env_ids] = self._fall_dof_pos[fall_state_ids] + self._dof_vel[env_ids] = self._fall_dof_vel[fall_state_ids] + self._recovery_counter[env_ids] = self._recovery_steps + self._reset_fall_env_ids = env_ids + return + + def _reset_envs(self, env_ids): + self._reset_fall_env_ids = [] + super()._reset_envs(env_ids) + return + + def _init_amp_obs(self, env_ids): + super()._init_amp_obs(env_ids) + + if (len(self._reset_fall_env_ids) > 0): + self._init_amp_obs_default(self._reset_fall_env_ids) + + return + + def _update_recovery_count(self): + self._recovery_counter -= 1 + self._recovery_counter = torch.clamp_min(self._recovery_counter, 0) + return + + def _compute_reset(self): + super()._compute_reset() + + is_recovery = self._recovery_counter > 0 + self.reset_buf[is_recovery] = 0 + self._terminate_buf[is_recovery] = 0 + return \ No newline at end of file diff --git a/phc/env/tasks/humanoid_amp_task.py b/phc/env/tasks/humanoid_amp_task.py new file mode 100644 index 0000000..e2acd53 --- /dev/null +++ b/phc/env/tasks/humanoid_amp_task.py @@ -0,0 +1,116 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import torch + +import phc.env.tasks.humanoid_amp as humanoid_amp +from phc.utils.flags import flags +class HumanoidAMPTask(humanoid_amp.HumanoidAMP): + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + self._enable_task_obs = cfg["env"]["enableTaskObs"] + + super().__init__(cfg=cfg, + sim_params=sim_params, + physics_engine=physics_engine, + device_type=device_type, + device_id=device_id, + headless=headless) + self.has_task = True + return + + + def get_obs_size(self): + obs_size = super().get_obs_size() + if (self._enable_task_obs): + task_obs_size = self.get_task_obs_size() + obs_size += task_obs_size + return obs_size + + def get_task_obs_size(self): + return 0 + + def pre_physics_step(self, actions): + super().pre_physics_step(actions) + self._update_task() + + return + + def render(self, sync_frame_time=False): + super().render(sync_frame_time) + + if self.viewer or flags.server_mode: + self._draw_task() + return + + def _update_task(self): + return + + def _reset_envs(self, env_ids): + super()._reset_envs(env_ids) + self._reset_task(env_ids) + return + + def _reset_task(self, env_ids): + return + + def _compute_observations(self, env_ids=None): + # env_ids is used for resetting + if env_ids is None: + env_ids = torch.arange(self.num_envs).to(self.device) + humanoid_obs = self._compute_humanoid_obs(env_ids) + + if (self._enable_task_obs): + task_obs = self._compute_task_obs(env_ids) + obs = torch.cat([humanoid_obs, task_obs], dim=-1) + else: + obs = humanoid_obs + + + if self.obs_v == 2: + # Double sub will return a copy. + B, N = obs.shape + sums = self.obs_buf[env_ids, 0:self.past_track_steps].abs().sum(dim=1) + zeros = sums == 0 + nonzero = ~zeros + obs_slice = self.obs_buf[env_ids] + obs_slice[zeros] = torch.tile(obs[zeros], (1, self.past_track_steps)) + obs_slice[nonzero] = torch.cat([obs_slice[nonzero, N:], obs[nonzero]], dim=-1) + self.obs_buf[env_ids] = obs_slice + else: + self.obs_buf[env_ids] = obs + + return + + def _compute_task_obs(self, env_ids=None): + return NotImplemented + + def _compute_reward(self, actions): + return NotImplemented + + def _draw_task(self): + return diff --git a/phc/env/tasks/humanoid_amp_z.py b/phc/env/tasks/humanoid_amp_z.py new file mode 100644 index 0000000..45f3733 --- /dev/null +++ b/phc/env/tasks/humanoid_amp_z.py @@ -0,0 +1,232 @@ +import time +import torch +import phc.env.tasks.humanoid_amp as humanoid_amp +from phc.env.tasks.humanoid_amp import remove_base_rot +from phc.utils import torch_utils +from typing import OrderedDict + +from isaacgym.torch_utils import * +from phc.utils.flags import flags +from rl_games.algos_torch import torch_ext +import torch.nn as nn +from phc.learning.pnn import PNN +from collections import deque +from phc.utils.torch_utils import project_to_norm + +from phc.utils.motion_lib import MotionLib +from phc.utils.motion_lib_smpl import MotionLibSMPL + +from phc.learning.network_loader import load_z_encoder, load_z_decoder + +HACK_MOTION_SYNC = False + +class HumanoidAMPZ(humanoid_amp.HumanoidAMP): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + + check_points = [torch_ext.load_checkpoint(ck_path) for ck_path in self.models_path] + + ### Loading Distill Model ### + self.distill_model_config = self.cfg['env']['distill_model_config'] + self.embedding_size_distill = self.distill_model_config['embedding_size'] + self.embedding_norm_distill = self.distill_model_config['embedding_norm'] + self.fut_tracks_distill = self.distill_model_config['fut_tracks'] + self.num_traj_samples_distill = self.distill_model_config['numTrajSamples'] + self.traj_sample_timestep_distill = self.distill_model_config['trajSampleTimestepInv'] + self.fut_tracks_dropout_distill = self.distill_model_config['fut_tracks_dropout'] + self.z_activation = self.distill_model_config['z_activation'] + self.distill_z_type = self.distill_model_config.get("z_type", "sphere") + + self.embedding_partition_distill = self.distill_model_config.get("embedding_partion", 1) + self.dict_size_distill = self.distill_model_config.get("dict_size", 1) + ### Loading Distill Model ### + + self.z_all = self.cfg['env'].get("z_all", False) + + self.use_vae_prior_loss = self.cfg['env'].get("use_vae_prior_loss", False) + self.use_vae_prior = self.cfg['env'].get("use_vae_prior", False) + self.use_vae_fixed_prior = self.cfg['env'].get("use_vae_fixed_prior", False) + self.use_vae_sphere_prior = self.cfg['env'].get("use_vae_sphere_prior", False) + self.use_vae_sphere_posterior = self.cfg['env'].get("use_vae_sphere_posterior", False) + + + self.decoder = load_z_decoder(check_points[0], activation = self.z_activation, z_type = self.distill_z_type, device = self.device) + self.encoder = load_z_encoder(check_points[0], activation = self.z_activation, z_type = self.distill_z_type, device = self.device) + self.power_acc = torch.zeros((self.num_envs, 2)).to(self.device) + self.power_usage_coefficient = cfg["env"].get("power_usage_coefficient", 0.005) + + self.running_mean, self.running_var = check_points[-1]['running_mean_std']['running_mean'], check_points[-1]['running_mean_std']['running_var'] + + if self.save_kin_info: + self.kin_dict = OrderedDict() + self.kin_dict.update({ + "gt_z": torch.zeros([self.num_envs,self.cfg['env'].get("embedding_size", 256) ]), + }) # current root pos + root for future aggergration + + return + + def _load_motion(self, motion_file): + assert (self._dof_offsets[-1] == self.num_dof) + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + self._motion_lib = MotionLibSMPL(motion_file=motion_file, device=self.device, masterfoot_conifg=self._masterfoot_config) + + self._motion_lib.load_motions(skeleton_trees=self.skeleton_trees, gender_betas=self.humanoid_shapes.cpu(), limb_weights=self.humanoid_limb_and_weights.cpu(), random_sample=not HACK_MOTION_SYNC) + + else: + self._motion_lib = MotionLib(motion_file=motion_file, dof_body_ids=self._dof_body_ids, dof_offsets=self._dof_offsets, key_body_ids=self._key_body_ids.cpu().numpy(), device=self.device) + + return + + + def load_pnn(self, pnn_ck): + mlp_args = {'input_size': pnn_ck['model']['a2c_network.pnn.actors.0.0.weight'].shape[1], 'units': pnn_ck['model']['a2c_network.pnn.actors.0.2.weight'].shape[::-1], 'activation': "relu", 'dense_func': torch.nn.Linear} + pnn = PNN(mlp_args, output_size=69, numCols=self.num_prim, has_lateral=self.has_lateral) + state_dict = pnn.state_dict() + for k in pnn_ck['model'].keys(): + if "pnn" in k: + pnn_dict_key = k.split("pnn.")[1] + state_dict[pnn_dict_key].copy_(pnn_ck['model'][k]) + pnn.freeze_pnn(self.num_prim) + pnn.to(self.device) + return pnn + + + + def get_task_obs_size_detail(self): + task_obs_detail = OrderedDict() + + ### For Z + task_obs_detail['proj_norm'] = self.cfg['env'].get("proj_norm", True) + task_obs_detail['embedding_norm'] = self.cfg['env'].get("embedding_norm", 3) + task_obs_detail['embedding_size'] = self.cfg['env'].get("embedding_size", 256) + task_obs_detail['z_readout'] = self.cfg['env'].get("z_readout", False) + task_obs_detail['z_type'] = self.cfg['env'].get("z_type", "sphere") + task_obs_detail['num_unique_motions'] = self._motion_lib._num_unique_motions + + return task_obs_detail + + def _compute_reward(self, actions): + super()._compute_reward(actions) + + # power_all = torch.abs(torch.multiply(self.dof_force_tensor, self._dof_vel)) + # power_all = power_all.reshape(-1, 23, 3) + # left_power = power_all[:, self.left_indexes].reshape(self.num_envs, -1).sum(dim = -1) + # right_power = power_all[:, self.right_indexes].reshape(self.num_envs, -1).sum(dim = -1) + # self.power_acc[:, 0] += left_power + # self.power_acc[:, 1] += right_power + # self.power_acc[self.progress_buf <= 3] = 0 + # power_usage_reward = self.power_acc/(self.progress_buf + 1)[:, None] + # power_usage_reward = - self.power_usage_coefficient * (power_usage_reward[:, 0] - power_usage_reward[:, 1]).abs() + # power_usage_reward[self.progress_buf <= 3] = 0 # First 3 frame power reward should not be counted. since they could be dropped. on the ground to balance. + # self.rew_buf[:] = power_usage_reward + + # import ipdb; ipdb.set_trace() + + return + + + def step(self, action_z): + + # if self.dr_randomizations.get('actions', None): + # actions = self.dr_randomizations['actions']['noise_lambda'](actions) + # if flags.server_mode: + # t_s = time.time() + # t_s = time.time() + with torch.no_grad(): + # Apply trained Model. + + ################ GT-Z ################ + + self_obs_size = self.get_self_obs_size() + if self.obs_v == 2: + self_obs_size = self_obs_size//self.past_track_steps + obs_buf = self.obs_buf.view(self.num_envs, self.past_track_steps, -1) + curr_obs = obs_buf[:, -1] + self_obs = ((curr_obs[:, :self_obs_size] - self.running_mean.float()[:self_obs_size]) / torch.sqrt(self.running_var.float()[:self_obs_size] + 1e-05)) + else: + self_obs = (self.obs_buf[:, :self_obs_size] - self.running_mean.float()[:self_obs_size]) / torch.sqrt(self.running_var.float()[:self_obs_size] + 1e-05) + + if self.distill_z_type == "hyper": + action_z = self.decoder.hyper_layer(action_z) + if self.distill_z_type == "vq_vae": + if self.is_discrete: + indexes = action_z + else: + B, F = action_z.shape + indexes = action_z.reshape(B, -1, self.embedding_size_distill).argmax(dim = -1) + task_out_proj = self.decoder.quantizer.embedding.weight[indexes.view(-1)] + print(f"\r {indexes.numpy()[0]}", end = '') + action_z = task_out_proj.view(-1, self.embedding_size_distill) + elif self.distill_z_type == "vae": + if self.use_vae_prior: + z_prior_out = self.decoder.z_prior(self_obs) + prior_mu = self.decoder.z_prior_mu(z_prior_out) + action_z = prior_mu + action_z + + if self.use_vae_sphere_posterior: + action_z = project_to_norm(action_z, 1, "sphere") + else: + action_z = project_to_norm(action_z, self.cfg['env'].get("embedding_norm", 5), "none") + + else: + action_z = project_to_norm(action_z, self.cfg['env'].get("embedding_norm", 5), self.distill_z_type) + + + if self.z_all: + x_all = self.decoder.decoder(action_z) + else: + self_obs = torch.clamp(self_obs, min=-5.0, max=5.0) + x_all = self.decoder.decoder(torch.cat([self_obs, action_z], dim = -1)) + + # z_prior_out = self.decoder.z_prior(self_obs); prior_mu, prior_log_var = self.decoder.z_prior_mu(z_prior_out), self.decoder.z_prior_logvar(z_prior_out); print(prior_mu.max(), prior_mu.min()) + # print('....') + + actions = x_all + + # actions = x_all[:, 3] # Debugging + + # apply actions + self.pre_physics_step(actions) + + # step physics and render each frame + self._physics_step() + + # to fix! + if self.device == 'cpu': + self.gym.fetch_results(self.sim, True) + + # compute observations, rewards, resets, ... + self.post_physics_step() + if flags.server_mode: + dt = time.time() - t_s + print(f'\r {1/dt:.2f} fps', end='') + + # dt = time.time() - t_s + # self.fps.append(1/dt) + # print(f'\r {np.mean(self.fps):.2f} fps', end='') + + + if self.dr_randomizations.get('observations', None): + self.obs_buf = self.dr_randomizations['observations']['noise_lambda'](self.obs_buf) + + +@torch.jit.script +def compute_z_target(root_pos, root_rot, ref_body_pos, ref_body_vel, time_steps, upright): + # type: (Tensor, Tensor, Tensor, Tensor, int, bool) -> Tensor + # No rotation information. Leave IK for RL. + # Future tracks in this obs will not contain future diffs. + obs = [] + B, J, _ = ref_body_pos.shape + + if not upright: + root_rot = remove_base_rot(root_rot) + + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot = torch_utils.calc_heading_quat(root_rot) + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, J, 1)).repeat_interleave(time_steps, 0) + local_ref_body_pos = ref_body_pos.view(B, time_steps, J, 3) - root_pos.view(B, 1, 1, 3) # preserves the body position + local_ref_body_pos = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), local_ref_body_pos.view(-1, 3)) + + + return local_ref_body_pos.view(B, J, -1) \ No newline at end of file diff --git a/phc/env/tasks/humanoid_im.py b/phc/env/tasks/humanoid_im.py new file mode 100644 index 0000000..45a5574 --- /dev/null +++ b/phc/env/tasks/humanoid_im.py @@ -0,0 +1,1676 @@ + + +import os.path as osp +from typing import OrderedDict +import torch +import numpy as np +from phc.utils.torch_utils import quat_to_tan_norm +import phc.env.tasks.humanoid_amp_task as humanoid_amp_task +from phc.env.tasks.humanoid_amp import HumanoidAMP, remove_base_rot +from phc.utils.motion_lib_smpl import MotionLibSMPL +from phc.utils.motion_lib_base import FixHeightMode +from easydict import EasyDict + +from phc.utils import torch_utils + +from isaacgym import gymapi +from isaacgym import gymtorch +from isaacgym.torch_utils import * +from phc.utils.flags import flags +import joblib +import gc +from collections import defaultdict + +from poselib.poselib.skeleton.skeleton3d import SkeletonTree, SkeletonMotion, SkeletonState +from scipy.spatial.transform import Rotation as sRot +import open3d as o3d +from datetime import datetime +import imageio +from collections import deque +from tqdm import tqdm +import copy + + +class HumanoidIm(humanoid_amp_task.HumanoidAMPTask): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + self._full_body_reward = cfg["env"].get("full_body_reward", True) + self._fut_tracks = cfg["env"].get("fut_tracks", False) + self._fut_tracks_dropout = cfg["env"].get("fut_tracks_dropout", False) + self.seq_motions = cfg["env"].get("seq_motions", False) + if self._fut_tracks: + self._num_traj_samples = cfg["env"]["numTrajSamples"] + else: + self._num_traj_samples = 1 + self._min_motion_len = cfg["env"].get("min_length", -1) + self._traj_sample_timestep = 1 / cfg["env"].get("trajSampleTimestepInv", 30) + + self.load_humanoid_configs(cfg) + self.cfg = cfg + self.num_envs = cfg["env"]["num_envs"] + self.device_type = cfg.get("device_type", "cuda") + self.device_id = cfg.get("device_id", 0) + self.headless = cfg["headless"] + self.start_idx = 0 + + self.reward_specs = cfg["env"].get("reward_specs", {"k_pos": 100, "k_rot": 10, "k_vel": 0.1, "k_ang_vel": 0.1, "w_pos": 0.5, "w_rot": 0.3, "w_vel": 0.1, "w_ang_vel": 0.1}) + + self._num_joints = len(self._body_names) + self.device = "cpu" + if self.device_type == "cuda" or self.device_type == "GPU": + self.device = "cuda" + ":" + str(self.device_id) + + self._track_bodies = cfg["env"].get("trackBodies", self._full_track_bodies) + self._track_bodies_id = self._build_key_body_ids_tensor(self._track_bodies) + self._reset_bodies = cfg["env"].get("reset_bodies", self._track_bodies) + + self._reset_bodies_id = self._build_key_body_ids_tensor(self._reset_bodies) + + self._full_track_bodies_id = self._build_key_body_ids_tensor(self._full_track_bodies) + self._eval_track_bodies_id = self._build_key_body_ids_tensor(self._eval_bodies) + self._motion_start_times_offset = torch.zeros(self.num_envs).to(self.device) + self._cycle_counter = torch.zeros(self.num_envs, device=self.device, dtype=torch.int) + + spacing = 5 + side_lenght = torch.ceil(torch.sqrt(torch.tensor(self.num_envs))) + pos_x, pos_y = torch.meshgrid(torch.arange(side_lenght) * spacing, torch.arange(side_lenght) * spacing) + self.start_pos_x, self.start_pos_y = pos_x.flatten(), pos_y.flatten() + self._global_offset = torch.zeros([self.num_envs, 3]).to(self.device) + # self._global_offset[:, 0], self._global_offset[:, 1] = self.start_pos_x[:self.num_envs], self.start_pos_y[:self.num_envs] + + self.offset_range = 0.8 + + ## ZL Hack Devs + #################### Devs #################### + self._point_goal = torch.zeros(self.num_envs, device=self.device) + self.random_occlu_idx = torch.zeros((self.num_envs, len(self._track_bodies)), device=self.device, dtype=torch.bool) + self.random_occlu_count = torch.zeros((self.num_envs, len(self._track_bodies)), device=self.device).long() + #################### Devs #################### + + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + # Overriding + self.reward_raw = torch.zeros((self.num_envs, 5 if self.power_reward else 4)).to(self.device) + self.power_coefficient = cfg["env"].get("power_coefficient", 0.0005) + + if (not self.headless or flags.server_mode): + self._build_marker_state_tensors() + + self.ref_body_pos = torch.zeros_like(self._rigid_body_pos) + self.ref_body_vel = torch.zeros_like(self._rigid_body_vel) + self.ref_body_rot = torch.zeros_like(self._rigid_body_rot) + self.ref_body_pos_subset = torch.zeros_like(self._rigid_body_pos[:, self._track_bodies_id]) + self.ref_dof_pos = torch.zeros_like(self._dof_pos) + + + self.viewer_o3d = flags.render_o3d + self.vis_ref = True + self.vis_contact = False + self._sampled_motion_ids = torch.arange(self.num_envs).to(self.device) + self.create_o3d_viewer() + self.setup_kin_info() + return + + def setup_kin_info(self): + if self.cfg.env.save_kin_info: + root_pos, root_rot = self._rigid_body_pos[:, 0, :], self._rigid_body_rot[:, 0, :] + self.kin_dict = OrderedDict() + self.kin_dict.update({ # default set of kinemaitc information + "root_pos": root_pos.clone(), + "root_rot": root_rot.clone(), + "body_pos": self._rigid_body_pos.clone(), + "dof_pos": self._dof_pos.clone(), + "ref_body_pos": self.ref_body_pos.clone(), + "ref_body_vel": self.ref_body_vel.clone(), + "ref_body_rot": self.ref_body_rot.clone(), + }) # current root pos + root for future aggergration + + def pause_func(self, action): + self.paused = not self.paused + + def next_func(self, action): + self.resample_motions() + + def reset_func(self, action): + self.reset() + + def record_func(self, action): + self.recording = not self.recording + self.recording_state_change_o3d = True + self.recording_state_change_o3d_img = True + self.recording_state_change = True # only intialize from o3d. + + + def hide_ref(self, action): + flags.show_traj = not flags.show_traj + + def create_o3d_viewer(self): + ################################################ ZL Hack: o3d viewers. ################################################ + if self.viewer_o3d : + o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Debug) + self.o3d_vis = o3d.visualization.VisualizerWithKeyCallback() + self.o3d_vis.create_window() + + box = o3d.geometry.TriangleMesh() + ground_size, height = 5, 0.01 + box = box.create_box(width=ground_size, height=height, depth=ground_size) + box.translate(np.array([-ground_size / 2, -height, -ground_size / 2])) + box.compute_vertex_normals() + box.vertex_colors = o3d.utility.Vector3dVector(np.array([[0.1, 0.1, 0.1]]).repeat(8, axis=0)) + + + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + from smpl_sim.smpllib.smpl_joint_names import SMPL_BONE_ORDER_NAMES, SMPLX_BONE_ORDER_NAMES, SMPLH_BONE_ORDER_NAMES, SMPL_MUJOCO_NAMES, SMPLH_MUJOCO_NAMES + + if self.humanoid_type == "smpl": + self.mujoco_2_smpl = [self._body_names_orig.index(q) for q in SMPL_BONE_ORDER_NAMES if q in self._body_names_orig] + elif self.humanoid_type in ["smplh", "smplx"]: + self.mujoco_2_smpl = [self._body_names_orig.index(q) for q in SMPLH_BONE_ORDER_NAMES if q in self._body_names_orig] + + with torch.no_grad(): + verts, joints = self._motion_lib.mesh_parsers[0].get_joints_verts(pose = torch.zeros(1, len(self._body_names_orig) * 3)) + np_triangles = self._motion_lib.mesh_parsers[0].faces + if self._has_upright_start: + self.pre_rot = sRot.from_quat([0.5, 0.5, 0.5, 0.5]) + else: + self.pre_rot = sRot.identity() + box.rotate(sRot.from_euler("xyz", [np.pi / 2, 0, 0]).as_matrix()) + self.mesh_parser = copy.deepcopy(self._motion_lib.mesh_parsers[0]) + self.mesh_parser = self.mesh_parser.cuda() + + self.sim_mesh = o3d.geometry.TriangleMesh() + self.sim_mesh.vertices = o3d.utility.Vector3dVector(verts.numpy()[0]) + self.sim_mesh.triangles = o3d.utility.Vector3iVector(np_triangles) + self.sim_mesh.vertex_colors = o3d.utility.Vector3dVector(np.array([[0, 0.5, 0.5]]).repeat(verts.shape[1], axis=0)) + if self.vis_ref: + self.ref_mesh = o3d.geometry.TriangleMesh() + self.ref_mesh.vertices = o3d.utility.Vector3dVector(verts.numpy()[0]) + self.ref_mesh.triangles = o3d.utility.Vector3iVector(np_triangles) + self.ref_mesh.vertex_colors = o3d.utility.Vector3dVector(np.array([[0.5, 0., 0.]]).repeat(verts.shape[1], axis=0)) + self.o3d_vis.add_geometry(self.ref_mesh) + + self.o3d_vis.add_geometry(box) + self.o3d_vis.add_geometry(self.sim_mesh) + self.coord_trans = torch.from_numpy(sRot.from_euler("xyz", [-np.pi / 2, 0, 0]).as_matrix()).float().cuda() + + self.o3d_vis.register_key_callback(32, self.pause_func) # space + self.o3d_vis.register_key_callback(82, self.reset_func) # R + self.o3d_vis.register_key_callback(76, self.record_func) # L + self.o3d_vis.register_key_callback(84, self.next_func) # T + self.o3d_vis.register_key_callback(75, self.hide_ref) # K + + self._video_queue_o3d = deque(maxlen=self.max_video_queue_size) + self._video_path_o3d = osp.join("output", "renderings", f"{self.cfg_name}-%s-o3d.mp4") + self.recording_state_change_o3d = False + + # if self.humanoid_type in ["smpl", "smplh", "smplx"]: + # self.control = control = self.o3d_vis.get_view_control() + # control.unset_constant_z_far() + # control.unset_constant_z_near() + # control.set_up(np.array([0, 0, 1])) + # control.set_front(np.array([1, 0, 0])) + # control.set_zoom(0.001) + + + def render(self, sync_frame_time = False): + super().render(sync_frame_time=sync_frame_time) + + if self.viewer_o3d and self.control_i == 0: + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + assert(self._rigid_body_rot.shape[0] == 1) + if self._has_upright_start: + body_quat = self._rigid_body_rot + root_trans = self._rigid_body_pos[:, 0, :] + + if self.vis_ref and len(self.ref_motion_cache['dof_pos']) == self.num_envs: + ref_body_quat = self.ref_motion_cache['rb_rot'] + ref_root_trans = self.ref_motion_cache['root_pos'] + + body_quat = torch.cat([body_quat, ref_body_quat]) + root_trans = torch.cat([root_trans, ref_root_trans]) + + N = body_quat.shape[0] + offset = self.skeleton_trees[0].local_translation[0].cuda() + root_trans_offset = root_trans - offset + + pose_quat = (sRot.from_quat(body_quat.reshape(-1, 4).numpy()) * self.pre_rot).as_quat().reshape(N, -1, 4) + new_sk_state = SkeletonState.from_rotation_and_root_translation(self.skeleton_trees[0], torch.from_numpy(pose_quat), root_trans.cpu(), is_local=False) + local_rot = new_sk_state.local_rotation + pose_aa = sRot.from_quat(local_rot.reshape(-1, 4).numpy()).as_rotvec().reshape(N, -1, 3) + pose_aa = torch.from_numpy(pose_aa[:, self.mujoco_2_smpl, :].reshape(N, -1)).cuda() + else: + dof_pos = self._dof_pos + root_trans = self._rigid_body_pos[:, 0, :] + root_rot = self._rigid_body_rot[:, 0, :] + pose_aa = torch.cat([torch_utils.quat_to_exp_map(root_rot), dof_pos], dim=1).reshape(1, -1) + + if self.vis_ref and len(self.ref_motion_cache['dof_pos']) == self.num_envs: + ref_dof_pos = self.ref_motion_cache['dof_pos'] + ref_root_rot = self.ref_motion_cache['rb_rot'][:, 0, :] + ref_root_trans = self.ref_motion_cache['root_pos'] + + ref_pose_aa = torch.cat([torch_utils.quat_to_exp_map(ref_root_rot), ref_dof_pos], dim=1) + + pose_aa = torch.cat([pose_aa, ref_pose_aa]) + root_trans = torch.cat([root_trans, ref_root_trans]) + N = pose_aa.shape[0] + offset = self.skeleton_trees[0].local_translation[0].cuda() + root_trans_offset = root_trans - offset + pose_aa = pose_aa.view(N, -1, 3)[:, self.mujoco_2_smpl, :] + + + with torch.no_grad(): + verts, joints = self.mesh_parser.get_joints_verts(pose=pose_aa, th_trans=root_trans_offset.cuda()) + + sim_verts = verts.numpy()[0] + self.sim_mesh.vertices = o3d.utility.Vector3dVector(sim_verts) + if N > 1: + ref_verts = verts.numpy()[1] + if not flags.show_traj: + ref_verts[..., 0] += 2 + self.ref_mesh.vertices = o3d.utility.Vector3dVector(ref_verts) + + self.sim_mesh.compute_vertex_normals() + self.o3d_vis.update_geometry(self.sim_mesh) + if N > 1: + self.o3d_vis.update_geometry(self.ref_mesh) + + self.sim_mesh.compute_vertex_normals() + if self.vis_ref: + self.ref_mesh.compute_vertex_normals() + self.o3d_vis.poll_events() + self.o3d_vis.update_renderer() + + if self.recording_state_change_o3d: + if not self.recording: + curr_date_time = datetime.now().strftime('%Y-%m-%d-%H:%M:%S') + curr_video_file_name = self._video_path_o3d % curr_date_time + fps = 30 + writer = imageio.get_writer(curr_video_file_name, fps=fps, macro_block_size=None) + height, width, c = self._video_queue_o3d[0].shape + height, width = height if height % 2 == 0 else height - 1, width if width % 2 == 0 else width - 1 + + for frame in tqdm(np.array(self._video_queue_o3d)): + try: + writer.append_data(frame[:height, :width, :]) + except: + print('image size changed???') + import ipdb + ipdb.set_trace() + + writer.close() + self._video_queue_o3d = deque(maxlen=self.max_video_queue_size) + + print(f"============ Video finished writing O3D {curr_video_file_name}============") + else: + print(f"============ Writing video O3D ============") + + self.recording_state_change_o3d = False + + if self.recording: + rgb = self.o3d_vis.capture_screen_float_buffer() + rgb = (np.asarray(rgb) * 255).astype(np.uint8) + # w, h, _ = rgb.shape + # w, h = math.floor(w / 2.) * 2, math.floor(h / 2.) * 2 + # rgb = rgb[:w, :h, :] + self._video_queue_o3d.append(rgb) + + + + + + def _load_motion(self, motion_train_file, motion_test_file=[]): + assert (self._dof_offsets[-1] == self.num_dof) + + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + motion_lib_cfg = EasyDict({ + "motion_file": motion_train_file, + "device": torch.device("cpu"), + "fix_height": FixHeightMode.full_fix, + "min_length": self._min_motion_len, + "max_length": -1, + "im_eval": flags.im_eval, + "multi_thread": True , + "smpl_type": self.humanoid_type, + "randomrize_heading": True, + "device": self.device, + }) + motion_eval_file = motion_train_file + self._motion_train_lib = MotionLibSMPL(motion_lib_cfg) + motion_lib_cfg.im_eval = True + self._motion_eval_lib = MotionLibSMPL(motion_lib_cfg) + + self._motion_lib = self._motion_train_lib + self._motion_lib.load_motions(skeleton_trees=self.skeleton_trees, gender_betas=self.humanoid_shapes.cpu(), limb_weights=self.humanoid_limb_and_weights.cpu(), random_sample=(not flags.test) and (not self.seq_motions), max_len=-1 if flags.test else self.max_len) + + else: + self._motion_lib = MotionLib(motion_file=motion_train_file, dof_body_ids=self._dof_body_ids, dof_offsets=self._dof_offsets, device=self.device) + + return + + def resample_motions(self): + # self.gym.destroy_sim(self.sim) + # del self.sim + # if not self.headless: + # self.gym.destroy_viewer(self.viewer) + # self.create_sim() + # self.gym.prepare_sim(self.sim) + # self.create_viewer() + # self._setup_tensors() + + print("Partial solution, only resample motions...") + # if self.hard_negative: + # self._motion_lib.update_sampling_weight() + + if flags.test: + self.forward_motion_samples() + else: + self._motion_lib.load_motions(skeleton_trees=self.skeleton_trees, limb_weights=self.humanoid_limb_and_weights.cpu(), gender_betas=self.humanoid_shapes.cpu(), random_sample=(not flags.test) and (not self.seq_motions), + max_len=-1 if flags.test else self.max_len) # For now, only need to sample motions since there are only 400 hmanoids + + # self.reset() # + # print("Reasmpling and resett!!!.") + + time = self.progress_buf * self.dt + self._motion_start_times + self._motion_start_times_offset + root_res = self._motion_lib.get_root_pos_smpl(self._sampled_motion_ids, time) + self._global_offset[:, :2] = self._humanoid_root_states[:, :2] - root_res['root_pos'][:, :2] + self.reset() + + + def get_motion_lengths(self): + return self._motion_lib.get_motion_lengths() + + def _record_states(self): + super()._record_states() + self.state_record['ref_body_pos_subset'].append(self.ref_body_pos_subset.cpu().clone()) + self.state_record['ref_body_pos_full'].append(self.ref_body_pos.cpu().clone()) + # self.state_record['ref_dof_pos'].append(self.ref_dof_pos.cpu().clone()) + + def _write_states_to_file(self, file_name): + self.state_record['skeleton_trees'] = self.skeleton_trees + self.state_record['humanoid_betas'] = self.humanoid_shapes + print(f"Dumping states into {file_name}") + + progress = torch.stack(self.state_record['progress'], dim=1) + progress_diff = torch.cat([progress, -10 * torch.ones(progress.shape[0], 1).to(progress)], dim=-1) + + diff = torch.abs(progress_diff[:, :-1] - progress_diff[:, 1:]) + split_idx = torch.nonzero(diff > 1) + split_idx[:, 1] += 1 + data_to_dump = {k: torch.stack(v) for k, v in self.state_record.items() if k not in ['skeleton_trees', 'humanoid_betas', "progress"]} + fps = 60 + motion_dict_dump = {} + num_for_this_humanoid = 0 + curr_humanoid_index = 0 + + for idx in range(len(split_idx)): + split_info = split_idx[idx] + humanoid_index = split_info[0] + + if humanoid_index != curr_humanoid_index: + num_for_this_humanoid = 0 + curr_humanoid_index = humanoid_index + + if num_for_this_humanoid == 0: + start = 0 + else: + start = split_idx[idx - 1][-1] + + end = split_idx[idx][-1] + + dof_pos_seg = data_to_dump['dof_pos'][start:end, humanoid_index] + B, H = dof_pos_seg.shape + root_states_seg = data_to_dump['root_states'][start:end, humanoid_index] + body_quat = torch.cat([root_states_seg[:, None, 3:7], torch_utils.exp_map_to_quat(dof_pos_seg.reshape(B, -1, 3))], dim=1) + + motion_dump = { + "skeleton_tree": self.state_record['skeleton_trees'][humanoid_index].to_dict(), + "body_quat": body_quat, + "trans": root_states_seg[:, :3], + "root_states_seg": root_states_seg, + "dof_pos": dof_pos_seg, + } + motion_dump['fps'] = fps + motion_dump['betas'] = self.humanoid_shapes[humanoid_index].detach().cpu().numpy() + motion_dump.update({k: v[start:end, humanoid_index] for k, v in data_to_dump.items() if k not in ['dof_pos', 'root_states', 'skeleton_trees', 'humanoid_betas', "progress"]}) + motion_dict_dump[f"{humanoid_index}_{num_for_this_humanoid}"] = motion_dump + num_for_this_humanoid += 1 + joblib.dump(motion_dict_dump, file_name) + self.state_record = defaultdict(list) + + def begin_seq_motion_samples(self): + # For evaluation + self.start_idx = 0 + self._motion_lib.load_motions(skeleton_trees=self.skeleton_trees, gender_betas=self.humanoid_shapes.cpu(), limb_weights=self.humanoid_limb_and_weights.cpu(), random_sample=False, start_idx=self.start_idx) + self.reset() + + def forward_motion_samples(self): + self.start_idx += self.num_envs + self._motion_lib.load_motions(skeleton_trees=self.skeleton_trees, gender_betas=self.humanoid_shapes.cpu(), limb_weights=self.humanoid_limb_and_weights.cpu(), random_sample=False, start_idx=self.start_idx) + self.reset() + + # Disabled. + # def get_self_obs_size(self): + # if self.obs_v == 4: + # return self._num_self_obs * self.past_track_steps + # else: + # return self._num_self_obs + + def get_task_obs_size(self): + obs_size = 0 + if (self._enable_task_obs): + if self.obs_v == 1: + obs_size = len(self._track_bodies) * self._num_traj_samples * 15 + elif self.obs_v == 2: # + dofdiff + obs_size = len(self._track_bodies) * self._num_traj_samples * 15 + obs_size += (len(self._track_bodies) - 1) * self._num_traj_samples * 3 + elif self.obs_v == 3: # reduced number + obs_size = len(self._track_bodies) * self._num_traj_samples * 9 + elif self.obs_v == 4: # 10 steps + v6 + + # obs_size = len(self._track_bodies) * self._num_traj_samples * 15 * 5 + obs_size = len(self._track_bodies) * 15 + obs_size += len(self._track_bodies) * self._num_traj_samples * 9 + obs_size *= self.past_track_steps + + elif self.obs_v == 5: # one hot vector for type of motions + obs_size = len(self._track_bodies) * self._num_traj_samples * 24 + 30 # Hard coded. + elif self.obs_v == 6: # local+ dof + pos (not diff) + obs_size = len(self._track_bodies) * self._num_traj_samples * 24 + + elif self.obs_v == 7: # local+ dof + pos (not diff) + obs_size = len(self._track_bodies) * self._num_traj_samples * 9 # linear position + velocity + + elif self.obs_v == 8: # local+ dof + pos (not diff) + vel (no diff). + obs_size = len(self._track_bodies) * 15 + obs_size += len(self._track_bodies) * self._num_traj_samples * 15 + + elif self.obs_v == 9: # local+ dof + pos (not diff) + vel (no diff). + obs_size = len(self._track_bodies) * self._num_traj_samples * 24 + obs_size -= (len(self._track_bodies) - 1) * self._num_traj_samples * 6 + + + return obs_size + + def get_task_obs_size_detail(self): + task_obs_detail = OrderedDict() + task_obs_detail['target'] = self.get_task_obs_size() + task_obs_detail['fut_tracks'] = self._fut_tracks + task_obs_detail['num_traj_samples'] = self._num_traj_samples + task_obs_detail['obs_v'] = self.obs_v + task_obs_detail['track_bodies'] = self._track_bodies + task_obs_detail['models_path'] = self.models_path + + # Dev + task_obs_detail['num_prim'] = self.cfg['env'].get("num_prim", 2) + task_obs_detail['training_prim'] = self.cfg['env'].get("training_prim", 1) + task_obs_detail['actors_to_load'] = self.cfg['env'].get("actors_to_load", 2) + task_obs_detail['has_lateral'] = self.cfg['env'].get("has_lateral", True) + + ### For Z + task_obs_detail['proj_norm'] = self.cfg['env'].get("proj_norm", True) + task_obs_detail['embedding_norm'] = self.cfg['env'].get("embedding_norm", 3) + task_obs_detail['embedding_size'] = self.cfg['env'].get("embedding_size", 256) + task_obs_detail['z_readout'] = self.cfg['env'].get("z_readout", False) + task_obs_detail['z_type'] = self.cfg['env'].get("z_type", "sphere") + task_obs_detail['z_all'] = self.cfg['env'].get("z_all", False) + task_obs_detail['use_vae_prior'] = self.cfg['env'].get("use_vae_prior", False) + task_obs_detail['use_vae_fixed_prior'] = self.cfg['env'].get("use_vae_fixed_prior", False) + task_obs_detail['use_vae_sphere_prior'] = self.cfg['env'].get("use_vae_sphere_prior", False) + task_obs_detail['use_vae_sphere_posterior'] = self.cfg['env'].get("use_vae_sphere_posterior", False) + task_obs_detail['use_vae_clamped_prior'] = self.cfg['env'].get("use_vae_clamped_prior", False) + task_obs_detail['vae_var_clamp_max'] = self.cfg['env'].get("vae_var_clamp_max", 0) + task_obs_detail['vae_prior_fixed_logvar'] = self.cfg['env'].get("vae_prior_fixed_logvar", 0) + task_obs_detail['num_unique_motions'] = self._motion_lib._num_unique_motions + task_obs_detail['vae_reader'] = self.cfg['env'].get("vae_reader", False) + task_obs_detail['dict_size'] = self.cfg['env'].get("dict_size", 1024) + task_obs_detail['embedding_partion'] = self.cfg['env'].get("embedding_partion", 1) + + + return task_obs_detail + + def _build_termination_heights(self): + super()._build_termination_heights() + termination_distance = self.cfg["env"].get("terminationDistance", 0.5) + self._termination_distances = to_torch(np.array([termination_distance] * self.num_bodies), device=self.device) + return + + def init_root_points(self): + # For debugging purpose + y = torch.tensor(np.linspace(-0.5, 0.5, 5), device=self.device, requires_grad=False) + x = torch.tensor(np.linspace(0, 1, 5), device=self.device, requires_grad=False) + grid_x, grid_y = torch.meshgrid(x, y) + + self.num_root_points = grid_x.numel() + points = torch.zeros(self.num_envs, self.num_root_points, 3, device=self.device, requires_grad=False) + points[:, :, 0] = grid_x.flatten() + points[:, :, 1] = grid_y.flatten() + return points + + def _create_envs(self, num_envs, spacing, num_per_row): + if (not self.headless or flags.server_mode): + self._marker_handles = [[] for _ in range(num_envs)] + self._load_marker_asset() + + if flags.add_proj: + self._proj_handles = [] + self._load_proj_asset() + + super()._create_envs(num_envs, spacing, num_per_row) + return + + def _load_marker_asset(self): + asset_root = "phc/data/assets/urdf/" + + asset_options = gymapi.AssetOptions() + asset_options.angular_damping = 0.0 + asset_options.linear_damping = 0.0 + asset_options.max_angular_velocity = 0.0 + asset_options.density = 0 + asset_options.fix_base_link = True + asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + + self._marker_asset = self.gym.load_asset(self.sim, asset_root, "traj_marker.urdf", asset_options) + + self._marker_asset_small = self.gym.load_asset(self.sim, asset_root, "traj_marker_small.urdf", asset_options) + + return + + def _build_env(self, env_id, env_ptr, humanoid_asset): + super()._build_env(env_id, env_ptr, humanoid_asset) + + if (not self.headless or flags.server_mode): + self._build_marker(env_id, env_ptr) + + if flags.add_proj: + self._build_proj(env_id, env_ptr) + + return + + def _update_marker(self): + if flags.show_traj: + + motion_times = (self.progress_buf + 1) * self.dt + self._motion_start_times + self._motion_start_times_offset # + 1 for target. + motion_res = self._get_state_from_motionlib_cache(self._sampled_motion_ids, motion_times, self._global_offset) + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, smpl_params, limb_weights, pose_aa, ref_rb_pos, ref_rb_rot, ref_body_vel, ref_body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + self._marker_pos[:] = ref_rb_pos + # self._marker_rotation[..., self._track_bodies_id, :] = ref_rb_rot[..., self._track_bodies_id, :] + + ## Only update the tracking points. + if flags.real_traj: + self._marker_pos[:] = 1000 + + self._marker_pos[..., self._track_bodies_id, :] = ref_rb_pos[..., self._track_bodies_id, :] + + if self._occl_training: + self._marker_pos[self.random_occlu_idx] = 0 + + else: + self._marker_pos[:] = 1000 + + # ######### Heading debug ####### + # points = self.init_root_points() + # base_quat = self._rigid_body_rot[0, 0:1] + # base_quat = remove_base_rot(base_quat) + # heading_rot = torch_utils.calc_heading_quat(base_quat) + # show_points = quat_apply(heading_rot.repeat(1, points.shape[0]).reshape(-1, 4), points) + (self._rigid_body_pos[0, 0:1]).unsqueeze(1) + # self._marker_pos[:] = show_points[:, :self._marker_pos.shape[1]] + # ######### Heading debug ####### + + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), gymtorch.unwrap_tensor(self._marker_actor_ids), len(self._marker_actor_ids)) + return + + def _build_marker(self, env_id, env_ptr): + default_pose = gymapi.Transform() + for i in range(self._num_joints): + # Giving hands smaller balls to indicate positions + if self.humanoid_type in ['smplx'] and self._body_names_orig[i] in ["L_Wrist", "R_Wrist", "L_Index1", "L_Index2", "L_Index3","L_Middle1","L_Middle2","L_Middle3","L_Pinky1","L_Pinky2", "L_Pinky3", "L_Ring1", "L_Ring2", "L_Ring3", "L_Thumb1", "L_Thumb2", "L_Thumb3", "R_Index1", "R_Index2", "R_Index3", "R_Middle1", "R_Middle2", "R_Middle3", "R_Pinky1", "R_Pinky2", "R_Pinky3", "R_Ring1", "R_Ring2", "R_Ring3", "R_Thumb1", "R_Thumb2", "R_Thumb3",]: + marker_handle = self.gym.create_actor(env_ptr, self._marker_asset_small, default_pose, "marker", self.num_envs + 10, 1, 0) + else: + marker_handle = self.gym.create_actor(env_ptr, self._marker_asset, default_pose, "marker", self.num_envs + 10, 1, 0) + + if i in self._track_bodies_id: + self.gym.set_rigid_body_color(env_ptr, marker_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.8, 0.0, 0.0)) + else: + self.gym.set_rigid_body_color(env_ptr, marker_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(1.0, 1.0, 1.0)) + self._marker_handles[env_id].append(marker_handle) + + return + + def _build_marker_state_tensors(self): + num_actors = self._root_states.shape[0] // self.num_envs + self._marker_states = self._root_states.view(self.num_envs, num_actors, self._root_states.shape[-1])[..., 1:(1 + self._num_joints), :] + self._marker_pos = self._marker_states[..., :3] + self._marker_rotation = self._marker_states[..., 3:7] + + self._marker_actor_ids = self._humanoid_actor_ids.unsqueeze(-1) + to_torch(self._marker_handles, dtype=torch.int32, device=self.device) + self._marker_actor_ids = self._marker_actor_ids.flatten() + + return + + def _sample_time(self, motion_ids): + # Motion imitation, no more blending and only sample at certain locations + return self._motion_lib.sample_time_interval(motion_ids) + # return self._motion_lib.sample_time(motion_ids) + + def _reset_task(self, env_ids): + super()._reset_task(env_ids) + # imitation task is resetted with the actions + return + + def post_physics_step(self): + if self.save_kin_info: # this needs to happen BEFORE the next time-step observation is computed, to collect the "current time-step target" + self.extras['kin_dict'] = self.kin_dict + super().post_physics_step() + + if flags.im_eval: + motion_times = (self.progress_buf) * self.dt + self._motion_start_times + self._motion_start_times_offset # already has time + 1, so don't need to + 1 to get the target for "this frame" + motion_res = self._get_state_from_motionlib_cache(self._sampled_motion_ids, motion_times, self._global_offset) # pass in the env_ids such that the motion is in synced. + body_pos = self._rigid_body_pos + self.extras['mpjpe'] = (body_pos - motion_res['rg_pos']).norm(dim=-1).mean(dim=-1) + self.extras['body_pos'] = body_pos.cpu().numpy() + self.extras['body_pos_gt'] = motion_res['rg_pos'].cpu().numpy() + + return + + def _compute_observations(self, env_ids=None): + # env_ids is used for resetting + if env_ids is None: + env_ids = torch.arange(self.num_envs).to(self.device) + + self_obs = self._compute_humanoid_obs(env_ids) + self.self_obs_buf[env_ids] = self_obs + + if (self._enable_task_obs): + task_obs = self._compute_task_obs(env_ids) + obs = torch.cat([self_obs, task_obs], dim=-1) + else: + obs = self_obs + + if self.add_obs_noise and not flags.test: + obs = obs + torch.randn_like(obs) * 0.1 + + if self.obs_v == 4: + # Double sub will return a copy. + B, N = obs.shape + sums = self.obs_buf[env_ids, 0:self.past_track_steps].abs().sum(dim=1) + zeros = sums == 0 + nonzero = ~zeros + obs_slice = self.obs_buf[env_ids] + obs_slice[zeros] = torch.tile(obs[zeros], (1, self.past_track_steps)) + obs_slice[nonzero] = torch.cat([obs_slice[nonzero, N:], obs[nonzero]], dim=-1) + self.obs_buf[env_ids] = obs_slice + else: + self.obs_buf[env_ids] = obs + return obs + + def _compute_task_obs(self, env_ids=None, save_buffer = True): + if (env_ids is None): + body_pos = self._rigid_body_pos + body_rot = self._rigid_body_rot + body_vel = self._rigid_body_vel + body_ang_vel = self._rigid_body_ang_vel + env_ids = torch.arange(self.num_envs, dtype=torch.long, device=self.device) + else: + body_pos = self._rigid_body_pos[env_ids] + body_rot = self._rigid_body_rot[env_ids] + body_vel = self._rigid_body_vel[env_ids] + body_ang_vel = self._rigid_body_ang_vel[env_ids] + + curr_gender_betas = self.humanoid_shapes[env_ids] + + if self._fut_tracks: + time_steps = self._num_traj_samples + B = env_ids.shape[0] + time_internals = torch.arange(time_steps).to(self.device).repeat(B).view(-1, time_steps) * self._traj_sample_timestep + motion_times_steps = ((self.progress_buf[env_ids, None] + 1) * self.dt + time_internals + self._motion_start_times[env_ids, None] + self._motion_start_times_offset[env_ids, None]).flatten() # Next frame, so +1 + env_ids_steps = self._sampled_motion_ids[env_ids].repeat_interleave(time_steps) + motion_res = self._get_state_from_motionlib_cache(env_ids_steps, motion_times_steps, self._global_offset[env_ids].repeat_interleave(time_steps, dim=0).view(-1, 3)) # pass in the env_ids such that the motion is in synced. + + else: + motion_times = (self.progress_buf[env_ids] + 1) * self.dt + self._motion_start_times[env_ids] + self._motion_start_times_offset[env_ids] # Next frame, so +1 + time_steps = 1 + motion_res = self._get_state_from_motionlib_cache(self._sampled_motion_ids[env_ids], motion_times, self._global_offset[env_ids]) # pass in the env_ids such that the motion is in synced. + + ref_root_pos, ref_root_rot, ref_dof_pos, ref_root_vel, ref_root_ang_vel, ref_dof_vel, ref_smpl_params, ref_limb_weights, ref_pose_aa, ref_rb_pos, ref_rb_rot, ref_body_vel, ref_body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + root_pos = body_pos[..., 0, :] + root_rot = body_rot[..., 0, :] + + body_pos_subset = body_pos[..., self._track_bodies_id, :] + body_rot_subset = body_rot[..., self._track_bodies_id, :] + body_vel_subset = body_vel[..., self._track_bodies_id, :] + body_ang_vel_subset = body_ang_vel[..., self._track_bodies_id, :] + + ref_rb_pos_subset = ref_rb_pos[..., self._track_bodies_id, :] + ref_rb_rot_subset = ref_rb_rot[..., self._track_bodies_id, :] + ref_body_vel_subset = ref_body_vel[..., self._track_bodies_id, :] + ref_body_ang_vel_subset = ref_body_ang_vel[..., self._track_bodies_id, :] + + if self.obs_v == 1 : + obs = compute_imitation_observations(root_pos, root_rot, body_pos_subset, body_rot_subset, body_vel_subset, body_ang_vel_subset, ref_rb_pos_subset, ref_rb_rot_subset, ref_body_vel_subset, ref_body_ang_vel_subset, time_steps, self._has_upright_start) + + elif self.obs_v == 2: + ref_dof_pos_subset = ref_dof_pos.reshape(-1, len(self._dof_names), 3)[..., self._track_bodies_id[1:] - 1, :] # Remove root from dof dim + dof_pos_subset = self._dof_pos[env_ids].reshape(-1, len(self._dof_names), 3)[..., self._track_bodies_id[1:] - 1, :] + obs = compute_imitation_observations_v2(root_pos, root_rot, body_pos_subset, body_rot_subset, body_vel_subset, body_ang_vel_subset, dof_pos_subset, ref_rb_pos_subset, ref_rb_rot_subset, ref_body_vel_subset, ref_body_ang_vel_subset, ref_dof_pos_subset, time_steps, self._has_upright_start) + elif self.obs_v == 3: + obs = compute_imitation_observations_v3(root_pos, root_rot, body_pos_subset, body_rot_subset, body_vel_subset, body_ang_vel_subset, ref_rb_pos_subset, ref_rb_rot_subset, ref_body_vel_subset, ref_body_ang_vel_subset, time_steps, self._has_upright_start) + elif self.obs_v == 4 or self.obs_v == 5 or self.obs_v == 6 or self.obs_v == 8 or self.obs_v == 9: + + if self.zero_out_far: + close_distance = self.close_distance + distance = torch.norm(root_pos - ref_rb_pos_subset[..., 0, :], dim=-1) + + zeros_subset = distance > close_distance + ref_rb_pos_subset[zeros_subset, 1:] = body_pos_subset[zeros_subset, 1:] + ref_rb_rot_subset[zeros_subset, 1:] = body_rot_subset[zeros_subset, 1:] + ref_body_vel_subset[zeros_subset, :] = body_vel_subset[zeros_subset, :] + ref_body_ang_vel_subset[zeros_subset, :] = body_ang_vel_subset[zeros_subset, :] + self._point_goal[env_ids] = distance + + far_distance = self.far_distance # does not seem to need this in particular... + vector_zero_subset = distance > far_distance # > 5 meters, it become just a direction + ref_rb_pos_subset[vector_zero_subset, 0] = ((ref_rb_pos_subset[vector_zero_subset, 0] - body_pos_subset[vector_zero_subset, 0]) / distance[vector_zero_subset, None] * far_distance) + body_pos_subset[vector_zero_subset, 0] + + if self._occl_training: + # ranomly occlude some of the body parts + random_occlu_idx = self.random_occlu_idx[env_ids] + ref_rb_pos_subset[random_occlu_idx] = body_pos_subset[random_occlu_idx] + ref_rb_rot_subset[random_occlu_idx] = body_rot_subset[random_occlu_idx] + ref_body_vel_subset[random_occlu_idx] = body_vel_subset[random_occlu_idx] + ref_body_ang_vel_subset[random_occlu_idx] = body_ang_vel_subset[random_occlu_idx] + + if self.obs_v == 4 or self.obs_v == 6: + obs = compute_imitation_observations_v6(root_pos, root_rot, body_pos_subset, body_rot_subset, body_vel_subset, body_ang_vel_subset, ref_rb_pos_subset, ref_rb_rot_subset, ref_body_vel_subset, ref_body_ang_vel_subset, time_steps, self._has_upright_start) + + # obs[:, -1] = env_ids.clone().float(); print('debugging') + # obs[:, -2] = self.progress_buf[env_ids].clone().float(); print('debugging') + + elif self.obs_v == 5: + obs = compute_imitation_observations_v6(root_pos, root_rot, body_pos_subset, body_rot_subset, body_vel_subset, body_ang_vel_subset, ref_rb_pos_subset, ref_rb_rot_subset, ref_body_vel_subset, ref_body_ang_vel_subset, time_steps, self._has_upright_start) + one_hots = self._motion_lib.one_hot_motions[env_ids] + obs = torch.cat([obs, one_hots], dim=-1) + + elif self.obs_v == 8: + obs = compute_imitation_observations_v8(root_pos, root_rot, body_pos_subset, body_rot_subset, body_vel_subset, body_ang_vel_subset, ref_rb_pos_subset, ref_rb_rot_subset, ref_body_vel_subset, ref_body_ang_vel_subset, time_steps, self._has_upright_start) + elif self.obs_v == 9: + ref_root_vel_subset = ref_body_vel_subset[:, 0] + ref_root_ang_vel_subset =ref_body_ang_vel_subset[:, 0] + obs = compute_imitation_observations_v9(root_pos, root_rot, body_pos_subset, body_rot_subset, body_vel_subset, body_ang_vel_subset, ref_rb_pos_subset, ref_rb_rot_subset, ref_root_vel_subset, ref_root_ang_vel_subset, time_steps, self._has_upright_start) + + if self._fut_tracks_dropout and not flags.test: + dropout_rate = 0.1 + curr_num_envs = env_ids.shape[0] + obs = obs.view(curr_num_envs, self._num_traj_samples, -1) + mask = torch.rand(curr_num_envs, self._num_traj_samples) < dropout_rate + obs[mask, :] = 0 + obs = obs.view(curr_num_envs, -1) + + elif self.obs_v == 7: + + if self.zero_out_far: + close_distance = self.close_distance + distance = torch.norm(root_pos - ref_rb_pos_subset[..., 0, :], dim=-1) + + zeros_subset = distance > close_distance + ref_rb_pos_subset[zeros_subset, 1:] = body_pos_subset[zeros_subset, 1:] + ref_body_vel_subset[zeros_subset, :] = body_vel_subset[zeros_subset, :] + self._point_goal[env_ids] = distance + + far_distance = self.far_distance # does not seem to need this in particular... + vector_zero_subset = distance > far_distance # > 5 meters, it become just a direction + ref_rb_pos_subset[vector_zero_subset, 0] = ((ref_rb_pos_subset[vector_zero_subset, 0] - body_pos_subset[vector_zero_subset, 0]) / distance[vector_zero_subset, None] * far_distance) + body_pos_subset[vector_zero_subset, 0] + + if self._occl_training: + # ranomly occlude some of the body parts + random_occlu_idx = self.random_occlu_idx[env_ids] + ref_rb_pos_subset[random_occlu_idx] = body_pos_subset[random_occlu_idx] + ref_rb_rot_subset[random_occlu_idx] = body_rot_subset[random_occlu_idx] + + obs = compute_imitation_observations_v7(root_pos, root_rot, body_pos_subset, body_vel_subset, ref_rb_pos_subset, ref_body_vel_subset, time_steps, self._has_upright_start) + + if save_buffer: + if self._fut_tracks: + self.ref_body_pos[env_ids] = ref_rb_pos[..., 0, :, :] + self.ref_body_vel[env_ids] = ref_body_vel[..., 0, :, :] + self.ref_body_rot[env_ids] = ref_rb_rot[..., 0, :, :] + self.ref_body_pos_subset[env_ids] = ref_rb_pos_subset[..., 0, :, :] + self.ref_dof_pos[env_ids] = ref_dof_pos[..., 0, :] + + else: + self.ref_body_pos[env_ids] = ref_rb_pos + self.ref_body_vel[env_ids] = ref_body_vel + self.ref_body_rot[env_ids] = ref_rb_rot + self.ref_body_pos_subset[env_ids] = ref_rb_pos_subset + self.ref_dof_pos[env_ids] = ref_dof_pos + + + return obs + + def _compute_reward(self, actions): + body_pos = self._rigid_body_pos + body_rot = self._rigid_body_rot + body_vel = self._rigid_body_vel + body_ang_vel = self._rigid_body_ang_vel + + motion_times = self.progress_buf * self.dt + self._motion_start_times + self._motion_start_times_offset # reward is computed after phsycis step, and progress_buf is already updated for next time step. + + motion_res = self._get_state_from_motionlib_cache(self._sampled_motion_ids, motion_times, self._global_offset) + + ref_root_pos, ref_root_rot, ref_dof_pos, ref_root_vel, ref_root_ang_vel, ref_dof_vel, ref_smpl_params, ref_limb_weights, ref_pose_aa, ref_rb_pos, ref_rb_rot, ref_body_vel, ref_body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + root_pos = body_pos[..., 0, :] + root_rot = body_rot[..., 0, :] + + if self.zero_out_far: + transition_distance = 0.25 + distance = torch.norm(root_pos - ref_root_pos, dim=-1) + + zeros_subset = distance > transition_distance # For those that are outside, no imitation reward + self.reward_raw = torch.zeros((self.num_envs, 4)).to(self.device) + + # self.rew_buf, self.reward_raw[:, 0] = compute_location_reward(root_pos, ref_rb_pos[..., 0, :]) + self.rew_buf, self.reward_raw[:, 0] = compute_point_goal_reward(self._point_goal, distance) + + im_reward, im_reward_raw = compute_imitation_reward(root_pos[~zeros_subset, :], root_rot[~zeros_subset, :], body_pos[~zeros_subset, :], body_rot[~zeros_subset, :], body_vel[~zeros_subset, :], body_ang_vel[~zeros_subset, :], ref_rb_pos[~zeros_subset, :], ref_rb_rot[~zeros_subset, :], + ref_body_vel[~zeros_subset, :], ref_body_ang_vel[~zeros_subset, :], self.reward_specs) + + # self.rew_buf, self.reward_raw = self.rew_buf * 0.5, self.reward_raw * 0.5 # Half the reward for the location reward + self.rew_buf[~zeros_subset] = self.rew_buf[~zeros_subset] + im_reward * 0.5 # for those are inside, add imitation reward + self.reward_raw[~zeros_subset, :4] = self.reward_raw[~zeros_subset, :4] + im_reward_raw * 0.5 + + # local_rwd, _ = compute_location_reward(root_pos, ref_rb_pos[:, ..., 0, :]) + # im_rwd, _ = compute_imitation_reward( + # root_pos, root_rot, body_pos, body_rot, body_vel, body_ang_vel, + # ref_rb_pos, ref_rb_rot, ref_body_vel, ref_body_ang_vel, + # self.reward_specs) + # print(local_rwd, im_rwd) + + else: + if self._full_body_reward: + self.rew_buf[:], self.reward_raw = compute_imitation_reward(root_pos, root_rot, body_pos, body_rot, body_vel, body_ang_vel, ref_rb_pos, ref_rb_rot, ref_body_vel, ref_body_ang_vel, self.reward_specs) + else: + body_pos_subset = body_pos[..., self._track_bodies_id, :] + body_rot_subset = body_rot[..., self._track_bodies_id, :] + body_vel_subset = body_vel[..., self._track_bodies_id, :] + body_ang_vel_subset = body_ang_vel[..., self._track_bodies_id, :] + + ref_rb_pos_subset = ref_rb_pos[..., self._track_bodies_id, :] + ref_rb_rot_subset = ref_rb_rot[..., self._track_bodies_id, :] + ref_body_vel_subset = ref_body_vel[..., self._track_bodies_id, :] + ref_body_ang_vel_subset = ref_body_ang_vel[..., self._track_bodies_id, :] + self.rew_buf[:], self.reward_raw = compute_imitation_reward(root_pos, root_rot, body_pos_subset, body_rot_subset, body_vel_subset, body_ang_vel_subset, ref_rb_pos_subset, ref_rb_rot_subset, ref_body_vel_subset, ref_body_ang_vel_subset, self.reward_specs) + + # print(self.dof_force_tensor.abs().max()) + if self.power_reward: + power = torch.abs(torch.multiply(self.dof_force_tensor, self._dof_vel)).sum(dim=-1) + # power_reward = -0.00005 * (power ** 2) + power_reward = -self.power_coefficient * power + power_reward[self.progress_buf <= 3] = 0 # First 3 frame power reward should not be counted. since they could be dropped. + + self.rew_buf[:] += power_reward + self.reward_raw = torch.cat([self.reward_raw, power_reward[:, None]], dim=-1) + + return + + def _reset_ref_state_init(self, env_ids): + self._motion_start_times_offset[env_ids] = 0 # Reset the motion time offsets + self._global_offset[env_ids] = 0 # Reset the global offset when resampling. + # self._global_offset[:, 0], self._global_offset[:, 1] = self.start_pos_x[:self.num_envs], self.start_pos_y[:self.num_envs] + + self._cycle_counter[env_ids] = 0 + super()._reset_ref_state_init(env_ids) # This function does not use the offset + # self._motion_lib.update_sampling_history(env_ids) + + if self.obs_v == 4: + self.obs_buf[env_ids] = 0 + if self.zero_out_far and self.zero_out_far_train: + # if self.zero_out_far and not flags.test: + # Moving the start position to a random location + # env_ids_pick = env_ids + env_ids_pick = env_ids[torch.arange(env_ids.shape[0]).long()] # All far away start. + max_distance = 5 + rand_distance = torch.sqrt(torch.rand(env_ids_pick.shape[0]).to(self.device)) * max_distance + + rand_angle = torch.rand(env_ids_pick.shape[0]).to(self.device) * np.pi * 2 + + self._global_offset[env_ids_pick, 0] = torch.cos(rand_angle) * rand_distance + self._global_offset[env_ids_pick, 1] = torch.sin(rand_angle) * rand_distance + + # self._global_offset[env_ids] + self._cycle_counter[env_ids_pick] = self._zero_out_far_steps + + return + + def _get_state_from_motionlib_cache(self, motion_ids, motion_times, offset=None): + ## Cache the motion + offset + if offset is None or not "motion_ids" in self.ref_motion_cache or self.ref_motion_cache['offset'] is None or len(self.ref_motion_cache['motion_ids']) != len(motion_ids) or len(self.ref_motion_cache['offset']) != len(offset) \ + or (self.ref_motion_cache['motion_ids'] - motion_ids).abs().sum() + (self.ref_motion_cache['motion_times'] - motion_times).abs().sum() + (self.ref_motion_cache['offset'] - offset).abs().sum() > 0 : + self.ref_motion_cache['motion_ids'] = motion_ids.clone() # need to clone; otherwise will be overriden + self.ref_motion_cache['motion_times'] = motion_times.clone() # need to clone; otherwise will be overriden + self.ref_motion_cache['offset'] = offset.clone() if not offset is None else None + else: + return self.ref_motion_cache + + motion_res = self._motion_lib.get_motion_state(motion_ids, motion_times, offset=offset) + + self.ref_motion_cache.update(motion_res) + + return self.ref_motion_cache + + def _sample_ref_state(self, env_ids): + num_envs = env_ids.shape[0] + + if (self._state_init == HumanoidAMP.StateInit.Random or self._state_init == HumanoidAMP.StateInit.Hybrid): + motion_times = self._sample_time(self._sampled_motion_ids[env_ids]) + elif (self._state_init == HumanoidAMP.StateInit.Start): + motion_times = torch.zeros(num_envs, device=self.device) + else: + assert (False), "Unsupported state initialization strategy: {:s}".format(str(self._state_init)) + + if flags.test: + motion_times[:] = 0 + + if self.humanoid_type in ["smpl", "smplh", "smplx"] : + motion_res = self._get_state_from_motionlib_cache(self._sampled_motion_ids[env_ids], motion_times, self._global_offset[env_ids]) + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, smpl_params, limb_weights, pose_aa, ref_rb_pos, ref_rb_rot, ref_body_vel, ref_body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + else: + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos = self._motion_lib.get_motion_state(self._sampled_motion_ids[env_ids], motion_times) + rb_pos, rb_rot = None, None + + return self._sampled_motion_ids[env_ids], motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, ref_rb_pos, ref_rb_rot, ref_body_vel, ref_body_ang_vel + + def _hack_motion_sync(self): + if (not hasattr(self, "_hack_motion_time")): + self._hack_motion_time = self._motion_start_times + self._motion_start_times_offset + + num_motions = self._motion_lib.num_motions() + motion_ids = np.arange(self.num_envs, dtype=np.int) + motion_ids = np.mod(motion_ids, num_motions) + motion_ids = torch.from_numpy(motion_ids).to(self.device) + # motion_ids[:] = 2 + motion_times = self._hack_motion_time + if self.humanoid_type in ["smpl", "smplh", "smplx"] : + motion_res = self._get_state_from_motionlib_cache(motion_ids, motion_times, self._global_offset) + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, smpl_params, limb_weights, pose_aa, rb_pos, rb_rot, body_vel, body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + root_pos[..., -1] += 0.03 # ALways slightly above the ground to avoid issue + else: + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos \ + = self._motion_lib.get_motion_state(motion_ids, motion_times) + rb_pos, rb_rot = None, None + + env_ids = torch.arange(self.num_envs, dtype=torch.long, device=self.device) + + self._set_env_state( + env_ids=env_ids, + root_pos=root_pos, + root_rot=root_rot, + dof_pos=dof_pos, + root_vel=root_vel, + root_ang_vel=root_ang_vel, + dof_vel=dof_vel, + rigid_body_pos=rb_pos, + rigid_body_rot=rb_rot, + rigid_body_vel=body_vel, + rigid_body_ang_vel=body_ang_vel, + ) + + self._reset_env_tensors(env_ids) + motion_fps = self._motion_lib._motion_fps[0] + + motion_dur = self._motion_lib._motion_lengths[0] + if not self.paused: + self._hack_motion_time = (self._hack_motion_time + self._motion_sync_dt) # since the simulation is double + else: + pass + + # self.progress_buf[:] = (self._hack_motion_time * 2* motion_fps).long() # /2 is for simulation double speed... + + return + + def _update_cycle_count(self): + self._cycle_counter -= 1 + self._cycle_counter = torch.clamp_min(self._cycle_counter, 0) + return + + def _update_occl_training(self): + occu_training = torch.ones([self.num_envs, len(self._track_bodies)], device=self.device) * self._occl_training_prob + random_occlu_idx = torch.bernoulli(occu_training).bool() + random_occlu_idx[:, 0] = False + + self.random_occlu_count[random_occlu_idx] = torch.randint(30, 60, self.random_occlu_count[random_occlu_idx].shape).to(self.device) + self.random_occlu_count -= 1 + self.random_occlu_count = torch.clamp_min(self.random_occlu_count, 0) + self.random_occlu_idx = self.random_occlu_count > 0 + + self.random_occlu_idx[:] = True + self.random_occlu_idx[:, [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]] = False + + def step(self, actions): + if self.dr_randomizations.get('actions', None): + actions = self.dr_randomizations['actions']['noise_lambda'](actions) + # apply actions + self.pre_physics_step(actions) + + if self.save_kin_info: # this needs to happen after pre_physics_step to get the correctly scaled actions + self.update_kin_info() + + # step physics and render each frame + self._physics_step() + + # to fix! + if self.device == 'cpu': + self.gym.fetch_results(self.sim, True) + + # compute observations, rewards, resets, ... + self.post_physics_step() + + + if self.dr_randomizations.get('observations', None): + self.obs_buf = self.dr_randomizations['observations']['noise_lambda'](self.obs_buf) + + def update_kin_info(self): + root_pos = self._rigid_body_pos[..., 0, :] + root_rot = self._rigid_body_rot[..., 0, :] + self.kin_dict.update({ + "root_pos": root_pos.clone(), + "root_rot": root_rot.clone(), + "body_pos": self._rigid_body_pos.clone(), + "dof_pos": self._dof_pos.clone(), + "ref_body_pos": self.ref_body_pos.clone(), + "ref_body_vel": self.ref_body_vel.clone(), + "ref_body_rot": self.ref_body_rot.clone(), + }) # current root pos + root for future aggergration + + def _action_to_pd_targets(self, action): + if self._res_action: + pd_tar = self.ref_dof_pos + self._pd_action_scale * action + pd_lower = self._dof_pos - np.pi / 2 + pd_upper = self._dof_pos + np.pi / 2 + pd_tar = torch.maximum(torch.minimum(pd_tar, pd_upper), pd_lower) + + else: + pd_tar = self._pd_action_offset + self._pd_action_scale * action + + return pd_tar + + + def pre_physics_step(self, actions): + + super().pre_physics_step(actions) + self._update_cycle_count() + + if self._occl_training: + self._update_occl_training() + + return + + def _compute_reset(self): + time = (self.progress_buf) * self.dt + self._motion_start_times + self._motion_start_times_offset # Reset is also called after the progress_buf is updated. + + pass_time_max = self.progress_buf >= self.max_episode_length - 1 + pass_time_motion_len = time >= self._motion_lib._motion_lengths + + if self.cycle_motion: + pass_time = pass_time_max + if pass_time_motion_len.sum() > 0: + self._motion_start_times_offset[pass_time_motion_len] = -self.progress_buf[pass_time_motion_len] * self.dt # such that the proegress_buf will cancel out to 0. + self._motion_start_times[pass_time_motion_len] = self._sample_time(self._sampled_motion_ids[pass_time_motion_len]) + self._cycle_counter[pass_time_motion_len] = 60 + + root_res = self._motion_lib.get_root_pos_smpl(self._sampled_motion_ids[pass_time_motion_len], self._motion_start_times[pass_time_motion_len]) + if self.cycle_motion_xp: + self._global_offset[pass_time_motion_len, :2] = self._humanoid_root_states[pass_time_motion_len, :2] - root_res['root_pos'][:, :2] + torch.rand(pass_time_motion_len.sum(), 2).to(self.device) # one meter + elif self.zero_out_far and self.zero_out_far_train: + + max_distance = 5 + num_cycle_motion = pass_time_motion_len.sum() + rand_distance = torch.sqrt(torch.rand(num_cycle_motion).to(self.device)) * max_distance + rand_angle = torch.rand(num_cycle_motion).to(self.device) * np.pi * 2 + + self._global_offset[pass_time_motion_len, :2] = self._humanoid_root_states[pass_time_motion_len, :2] - root_res['root_pos'][:, :2] + torch.cat([(torch.cos(rand_angle) * rand_distance)[:, None], (torch.sin(rand_angle) * rand_distance)[:, None]], dim=-1) + else: + self._global_offset[pass_time_motion_len, :2] = self._humanoid_root_states[pass_time_motion_len, :2] - root_res['root_pos'][:, :2] + + time = self.progress_buf * self.dt + self._motion_start_times + self._motion_start_times_offset # update time + if flags.test: + print("cycling motion") + else: + pass_time = pass_time_motion_len + + motion_res = self._get_state_from_motionlib_cache(self._sampled_motion_ids, time, self._global_offset) + + ref_root_pos, ref_root_rot, ref_dof_pos, ref_root_vel, root_ang_vel, dof_vel, smpl_params, limb_weights, pose_aa, ref_rb_pos, ref_rb_rot, ref_body_vel, ref_body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + if self.zero_out_far and self.zero_out_far_train: + # zeros_subset = torch.norm(self._rigid_body_pos[..., 0, :] - ref_rb_pos[..., 0, :], dim=-1) > self._termination_distances[..., 0] + # zeros_subset = torch.norm(self._rigid_body_pos[..., 0, :] - ref_rb_pos[..., 0, :], dim=-1) > 0.1 + # self.reset_buf[zeros_subset], self._terminate_buf[zeros_subset] = compute_humanoid_traj_reset( + # self.reset_buf[zeros_subset], self.progress_buf[zeros_subset], self._contact_forces[zeros_subset], + # self._contact_body_ids, self._rigid_body_pos[zeros_subset], self.max_episode_length, self._enable_early_termination, + # 0.3, flags.no_collision_check) + + # self.reset_buf[~zeros_subset], self._terminate_buf[~zeros_subset] = compute_humanoid_reset( + # self.reset_buf[~zeros_subset], self.progress_buf[~zeros_subset], self._contact_forces[~zeros_subset], + # self._contact_body_ids, self._rigid_body_pos[~zeros_subset][..., self._reset_bodies_id, :], ref_rb_pos[~zeros_subset][..., self._reset_bodies_id, :], + # pass_time[~zeros_subset], self._enable_early_termination, + # self._termination_distances[..., self._reset_bodies_id], flags.no_collision_check) + + # self.reset_buf, self._terminate_buf = compute_humanoid_traj_reset( # traj reset + # self.reset_buf, self.progress_buf, self._contact_forces, self._contact_body_ids, self._rigid_body_pos, pass_time_max, self._enable_early_termination, 0.3, flags.no_collision_check) + self.reset_buf[:], self._terminate_buf[:] = compute_humanoid_im_reset( # Humanoid reset + self.reset_buf, self.progress_buf, self._contact_forces, self._contact_body_ids, self._rigid_body_pos[..., self._reset_bodies_id, :], ref_rb_pos[..., self._reset_bodies_id, :], pass_time, self._enable_early_termination, self._termination_distances[..., self._reset_bodies_id], + flags.no_collision_check, flags.im_eval and (not self.strict_eval)) + + else: + body_pos = self._rigid_body_pos[..., self._reset_bodies_id, :].clone() + ref_body_pos = ref_rb_pos[..., self._reset_bodies_id, :].clone() + + if self._occl_training: + ref_body_pos[self.random_occlu_idx[:, self._reset_bodies_id]] = body_pos[self.random_occlu_idx[:, self._reset_bodies_id]] + + self.reset_buf[:], self._terminate_buf[:] = compute_humanoid_im_reset(self.reset_buf, self.progress_buf, self._contact_forces, self._contact_body_ids, \ + body_pos, ref_body_pos, pass_time, self._enable_early_termination, + self._termination_distances[..., self._reset_bodies_id], flags.no_collision_check, flags.im_eval and (not self.strict_eval)) + is_recovery = torch.logical_and(~pass_time, self._cycle_counter > 0) # pass time should override the cycle counter. + self.reset_buf[is_recovery] = 0 + self._terminate_buf[is_recovery] = 0 + + return + + def _draw_task(self): + self._update_marker() + return + + +##################################################################### +###=========================jit functions=========================### +##################################################################### + + +@torch.jit.script +def compute_imitation_observations(root_pos, root_rot, body_pos, body_rot, body_vel, body_ang_vel, ref_body_pos, ref_body_rot, ref_body_vel, ref_body_ang_vel, time_steps, upright): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,Tensor, Tensor, int, bool) -> Tensor + # We do not use any dof in observation. + obs = [] + B, J, _ = body_pos.shape + + if not upright: + root_rot = remove_base_rot(root_rot) + + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot = torch_utils.calc_heading_quat(root_rot) + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + + diff_global_body_pos = ref_body_pos.view(B, time_steps, J, 3) - body_pos.view(B, 1, J, 3) + diff_global_body_rot = torch_utils.quat_mul(ref_body_rot.view(B, time_steps, J, 4), torch_utils.quat_conjugate(body_rot).repeat_interleave(time_steps, 0).view(B, time_steps, J, 4)) + + diff_local_body_pos_flat = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_body_pos.view(-1, 3)) + diff_local_body_rot_flat = torch_utils.quat_mul(torch_utils.quat_mul(heading_inv_rot_expand.view(-1, 4), diff_global_body_rot.view(-1, 4)), heading_rot_expand.view(-1, 4)) # Need to be change of basis + + obs.append(diff_local_body_pos_flat.view(B, -1)) # 1 * 10 * 3 * 3 + obs.append(torch_utils.quat_to_tan_norm(diff_local_body_rot_flat).view(B, -1)) # 1 * 10 * 3 * 6 + + ##### Velocities + diff_global_vel = ref_body_vel.view(B, time_steps, J, 3) - body_vel.view(B, 1, J, 3) + diff_global_ang_vel = ref_body_ang_vel.view(B, time_steps, J, 3) - body_ang_vel.view(B, 1, J, 3) + + diff_local_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_vel.view(-1, 3)) + diff_local_ang_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_ang_vel.view(-1, 3)) + obs.append(diff_local_vel.view(B, -1)) # 3 * 3 + obs.append(diff_local_ang_vel.view(B, -1)) # 3 * 3 + + obs = torch.cat(obs, dim=-1) + return obs + + +@torch.jit.script +def compute_imitation_observations_v2(root_pos, root_rot, body_pos, body_rot, body_vel, body_ang_vel, dof_pos, ref_body_pos, ref_body_rot, ref_body_vel, ref_body_ang_vel, ref_dof_pos, time_steps, upright): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,Tensor, Tensor,Tensor,Tensor, int, bool) -> Tensor + # Adding dof + obs = [] + B, J, _ = body_pos.shape + + if not upright: + root_rot = remove_base_rot(root_rot) + + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot = torch_utils.calc_heading_quat(root_rot) + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + + diff_global_body_pos = ref_body_pos.view(B, time_steps, J, 3) - body_pos.view(B, 1, J, 3) + diff_global_body_rot = torch_utils.quat_mul(ref_body_rot.view(B, time_steps, J, 4), torch_utils.quat_conjugate(body_rot).repeat_interleave(time_steps, 0).view(B, time_steps, J, 4)) + + diff_local_body_pos_flat = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_body_pos.view(-1, 3)) + diff_local_body_rot_flat = torch_utils.quat_mul(torch_utils.quat_mul(heading_inv_rot_expand.view(-1, 4), diff_global_body_rot.view(-1, 4)), heading_rot_expand.view(-1, 4)) # Need to be change of basis + + obs.append(diff_local_body_pos_flat.view(B, -1)) # 1 * 10 * 3 * 3 + obs.append(torch_utils.quat_to_tan_norm(diff_local_body_rot_flat).view(B, -1)) # 1 * 10 * 3 * 6 + + ##### Velocities + diff_global_vel = ref_body_vel.view(B, time_steps, J, 3) - body_vel.view(B, 1, J, 3) + diff_global_ang_vel = ref_body_ang_vel.view(B, time_steps, J, 3) - body_ang_vel.view(B, 1, J, 3) + + diff_local_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_vel.view(-1, 3)) + diff_local_ang_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_ang_vel.view(-1, 3)) + obs.append(diff_local_vel.view(B, -1)) # 3 * 3 + obs.append(diff_local_ang_vel.view(B, -1)) # 3 * 3 + + ##### Dof_pos diff + diff_dof_pos = ref_dof_pos.view(B, time_steps, -1) - dof_pos.view(B, time_steps, -1) + obs.append(diff_dof_pos.view(B, -1)) # 23 * 3 + + obs = torch.cat(obs, dim=-1) + return obs + + +@torch.jit.script +def compute_imitation_observations_v3(root_pos, root_rot, body_pos, body_rot, body_vel, body_ang_vel, ref_body_pos, ref_body_rot, ref_body_vel, ref_body_ang_vel, time_steps, upright): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,Tensor, Tensor, int, bool) -> Tensor + # No velocities + obs = [] + B, J, _ = body_pos.shape + + if not upright: + root_rot = remove_base_rot(root_rot) + + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot = torch_utils.calc_heading_quat(root_rot) + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + + diff_global_body_pos = ref_body_pos.view(B, time_steps, J, 3) - body_pos.view(B, 1, J, 3) + diff_local_body_pos_flat = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_body_pos.view(-1, 3)) + obs.append(diff_local_body_pos_flat.view(B, -1)) # 1 * 10 * 3 * 3 + + diff_global_body_rot = torch_utils.quat_mul(ref_body_rot.view(B, time_steps, J, 4), torch_utils.quat_conjugate(body_rot).repeat_interleave(time_steps, 0).view(B, time_steps, J, 4)) + diff_local_body_rot_flat = torch_utils.quat_mul(torch_utils.quat_mul(heading_inv_rot_expand.view(-1, 4), diff_global_body_rot.view(-1, 4)), heading_rot_expand.view(-1, 4)) # Need to be change of basis + obs.append(torch_utils.quat_to_tan_norm(diff_local_body_rot_flat).view(B, -1)) # 1 * 10 * 3 * 6 + + obs = torch.cat(obs, dim=-1) + + return obs + + +@torch.jit.script +def compute_imitation_observations_v6(root_pos, root_rot, body_pos, body_rot, body_vel, body_ang_vel, ref_body_pos, ref_body_rot, ref_body_vel, ref_body_ang_vel, time_steps, upright): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,Tensor, Tensor,Tensor,Tensor, int, bool) -> Tensor + # Adding pose information at the back + # Future tracks in this obs will not contain future diffs. + obs = [] + B, J, _ = body_pos.shape + + if not upright: + root_rot = remove_base_rot(root_rot) + + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot = torch_utils.calc_heading_quat(root_rot) + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + + + ##### Body position and rotation differences + diff_global_body_pos = ref_body_pos.view(B, time_steps, J, 3) - body_pos.view(B, 1, J, 3) + diff_local_body_pos_flat = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_body_pos.view(-1, 3)) + + body_rot[:, None].repeat_interleave(time_steps, 1) + diff_global_body_rot = torch_utils.quat_mul(ref_body_rot.view(B, time_steps, J, 4), torch_utils.quat_conjugate(body_rot[:, None].repeat_interleave(time_steps, 1))) + diff_local_body_rot_flat = torch_utils.quat_mul(torch_utils.quat_mul(heading_inv_rot_expand.view(-1, 4), diff_global_body_rot.view(-1, 4)), heading_rot_expand.view(-1, 4)) # Need to be change of basis + + ##### linear and angular Velocity differences + diff_global_vel = ref_body_vel.view(B, time_steps, J, 3) - body_vel.view(B, 1, J, 3) + diff_local_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_vel.view(-1, 3)) + + + diff_global_ang_vel = ref_body_ang_vel.view(B, time_steps, J, 3) - body_ang_vel.view(B, 1, J, 3) + diff_local_ang_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_ang_vel.view(-1, 3)) + + + ##### body pos + Dof_pos This part will have proper futuers. + local_ref_body_pos = ref_body_pos.view(B, time_steps, J, 3) - root_pos.view(B, 1, 1, 3) # preserves the body position + local_ref_body_pos = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), local_ref_body_pos.view(-1, 3)) + + local_ref_body_rot = torch_utils.quat_mul(heading_inv_rot_expand.view(-1, 4), ref_body_rot.view(-1, 4)) + local_ref_body_rot = torch_utils.quat_to_tan_norm(local_ref_body_rot) + + # make some changes to how futures are appended. + obs.append(diff_local_body_pos_flat.view(B, time_steps, -1)) # 1 * timestep * 24 * 3 + obs.append(torch_utils.quat_to_tan_norm(diff_local_body_rot_flat).view(B, time_steps, -1)) # 1 * timestep * 24 * 6 + obs.append(diff_local_vel.view(B, time_steps, -1)) # timestep * 24 * 3 + obs.append(diff_local_ang_vel.view(B, time_steps, -1)) # timestep * 24 * 3 + obs.append(local_ref_body_pos.view(B, time_steps, -1)) # timestep * 24 * 3 + obs.append(local_ref_body_rot.view(B, time_steps, -1)) # timestep * 24 * 6 + + obs = torch.cat(obs, dim=-1).view(B, -1) + return obs + + +@torch.jit.script +def compute_imitation_observations_v7(root_pos, root_rot, body_pos, body_vel, ref_body_pos, ref_body_vel, time_steps, upright): + # type: (Tensor, Tensor, Tensor,Tensor, Tensor, Tensor, int, bool) -> Tensor + # No rotation information. Leave IK for RL. + # Future tracks in this obs will not contain future diffs. + obs = [] + B, J, _ = body_pos.shape + + if not upright: + root_rot = remove_base_rot(root_rot) + + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + + ##### Body position differences + diff_global_body_pos = ref_body_pos.view(B, time_steps, J, 3) - body_pos.view(B, 1, J, 3) + diff_local_body_pos_flat = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_body_pos.view(-1, 3)) + + ##### Linear Velocity differences + diff_global_vel = ref_body_vel.view(B, time_steps, J, 3) - body_vel.view(B, 1, J, 3) + diff_local_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_vel.view(-1, 3)) + + ##### body pos + Dof_pos + local_ref_body_pos = ref_body_pos.view(B, time_steps, J, 3) - root_pos.view(B, 1, 1, 3) # preserves the body position + local_ref_body_pos = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), local_ref_body_pos.view(-1, 3)) + + # make some changes to how futures are appended. + obs.append(diff_local_body_pos_flat.view(B, time_steps, -1)) # 1 * 10 * 3 * 3 + obs.append(diff_local_vel.view(B, time_steps, -1)) # 3 * 3 + obs.append(local_ref_body_pos.view(B, time_steps, -1)) # 2 + + obs = torch.cat(obs, dim=-1).view(B, -1) + return obs + +@torch.jit.script +def compute_imitation_observations_v8(root_pos, root_rot, body_pos, body_rot, body_vel, body_ang_vel, ref_body_pos, ref_body_rot, ref_body_vel, ref_body_ang_vel, time_steps, upright): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,Tensor, Tensor,Tensor,Tensor, int, bool) -> Tensor + # Adding pose information at the back + # Future tracks in this obs will not contain future diffs. + obs = [] + B, J, _ = body_pos.shape + + if not upright: + root_rot = remove_base_rot(root_rot) + + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot = torch_utils.calc_heading_quat(root_rot) + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)) + heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)) + + diff_global_body_pos = ref_body_pos.view(B, time_steps, J, 3)[:, 0:1] - body_pos.view(B, 1, J, 3) + diff_global_body_rot = torch_utils.quat_mul(ref_body_rot.view(B, time_steps, J, 4)[:, 0:1], torch_utils.quat_conjugate(body_rot).view(B, 1, J, 4)) + + diff_local_body_pos_flat = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_body_pos.view(-1, 3)) + diff_local_body_rot_flat = torch_utils.quat_mul(torch_utils.quat_mul(heading_inv_rot_expand.view(-1, 4), diff_global_body_rot.view(-1, 4)), heading_rot_expand.view(-1, 4)) # Need to be change of basis + + ##### Body position differences + obs.append(diff_local_body_pos_flat.view(B, -1)) # 1 * 10 * J * 3 + obs.append(torch_utils.quat_to_tan_norm(diff_local_body_rot_flat).view(B, -1)) # 1 * 10 * J * 6 + + ##### Velocity differences + diff_global_vel = ref_body_vel.view(B, time_steps, J, 3)[:, 0:1] - body_vel.view(B, 1, J, 3) + diff_global_ang_vel = ref_body_ang_vel.view(B, time_steps, J, 3)[:, 0:1] - body_ang_vel.view(B, 1, J, 3) + + diff_local_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_vel.view(-1, 3)) + diff_local_ang_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_ang_vel.view(-1, 3)) + obs.append(diff_local_vel.view(B, -1)) # 24 * 3 + obs.append(diff_local_ang_vel.view(B, -1)) # 24 * 3 + + ##### body pos + Dof_pos This part will have proper futuers. + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + local_ref_body_pos = ref_body_pos.view(B, time_steps, J, 3) - root_pos.view(B, 1, 1, 3) # preserves the body position + local_ref_body_pos = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), local_ref_body_pos.view(-1, 3)) + + local_ref_body_rot = torch_utils.quat_mul(heading_inv_rot_expand.view(-1, 4), ref_body_rot.view(-1, 4)) + local_ref_body_rot = torch_utils.quat_to_tan_norm(local_ref_body_rot) + + local_ref_body_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), ref_body_vel.view(-1, 3)) + local_ref_body_ang_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), ref_body_ang_vel.view(-1, 3)) + + # make some changes to how futures are appended. + if time_steps > 1: + local_ref_body_pos = local_ref_body_pos.view(B, time_steps, -1) + local_ref_body_rot = local_ref_body_rot.view(B, time_steps, -1) + + obs.append(local_ref_body_pos[:, 0].view(B, -1)) # first append the current ones + obs.append(local_ref_body_rot[:, 0].view(B, -1)) + obs.append(local_ref_body_vel[:, 0].view(B, -1)) + obs.append(local_ref_body_ang_vel[:, 0].view(B, -1)) + + + obs.append(local_ref_body_pos[:, 1:].reshape(B, -1)) # then append the future ones + obs.append(local_ref_body_rot[:, 1:].reshape(B, -1)) + obs.append(local_ref_body_vel[:, 1:].view(B, -1)) + obs.append(local_ref_body_ang_vel[:, 1:].view(B, -1)) + else: + obs.append(local_ref_body_pos.view(B, -1)) # 24 * timestep * 3 + obs.append(local_ref_body_rot.view(B, -1)) # 24 * timestep * 6 + obs.append(local_ref_body_vel.view(B, -1)) # 24 * timestep * 3 + obs.append(local_ref_body_ang_vel.view(B, -1)) # 24 * timestep * 3 + # obs.append(ref_dof_pos.view(B, -1)) # 23 * 3 + + obs = torch.cat(obs, dim=-1) + return obs + + +@torch.jit.script +def compute_imitation_observations_v9(root_pos, root_rot, body_pos, body_rot, body_vel, body_ang_vel, ref_body_pos, ref_body_rot, ref_root_vel, ref_body_root_ang_vel, time_steps, upright): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,Tensor, Tensor,Tensor,Tensor, int, bool) -> Tensor + # Adding pose information at the back + # Future tracks in this obs will not contain future diffs. + obs = [] + B, J, _ = body_pos.shape + + if not upright: + root_rot = remove_base_rot(root_rot) + + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot = torch_utils.calc_heading_quat(root_rot) + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, body_pos.shape[1], 1)).repeat_interleave(time_steps, 0) + + + ##### Body position and rotation differences + diff_global_body_pos = ref_body_pos.view(B, time_steps, J, 3) - body_pos.view(B, 1, J, 3) + diff_local_body_pos_flat = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), diff_global_body_pos.view(-1, 3)) + + + diff_global_body_rot = torch_utils.quat_mul(ref_body_rot.view(B, time_steps, J, 4), torch_utils.quat_conjugate(body_rot[:, None].repeat_interleave(time_steps, 1))) + diff_local_body_rot_flat = torch_utils.quat_mul(torch_utils.quat_mul(heading_inv_rot_expand.view(-1, 4), diff_global_body_rot.view(-1, 4)), heading_rot_expand.view(-1, 4)) # Need to be change of basis + + + ##### linear and angular Velocity differences + heading_inv_rot_expand_root = heading_inv_rot.unsqueeze(-1).repeat_interleave(time_steps, 0) + root_vel, root_ang_vel = body_vel[:, 0], body_ang_vel[:, 0] + diff_global_root_vel = ref_root_vel.view(B, time_steps, 3) - root_vel.view(B, 1, 3) + diff_local_root_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand_root.view(-1, 4), diff_global_root_vel.view(-1, 3)) + + + diff_global_root_ang_vel = ref_body_root_ang_vel.view(B, time_steps, 3) - root_ang_vel.view(B, 1, 3) + diff_local_root_ang_vel = torch_utils.my_quat_rotate(heading_inv_rot_expand_root.view(-1, 4), diff_global_root_ang_vel.view(-1, 3)) + + + ##### body pos + Dof_pos This part will have proper futuers. + local_ref_body_pos = ref_body_pos.view(B, time_steps, J, 3) - root_pos.view(B, 1, 1, 3) # preserves the body position + local_ref_body_pos = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), local_ref_body_pos.view(-1, 3)) + + local_ref_body_rot = torch_utils.quat_mul(heading_inv_rot_expand.view(-1, 4), ref_body_rot.view(-1, 4)) + local_ref_body_rot = torch_utils.quat_to_tan_norm(local_ref_body_rot) + + # make some changes to how futures are appended. + obs.append(diff_local_body_pos_flat.view(B, time_steps, -1)) # 1 * 10 * 3 * 3 + obs.append(torch_utils.quat_to_tan_norm(diff_local_body_rot_flat).view(B, time_steps, -1)) # 1 * 10 * 3 * 6 + obs.append(diff_local_root_vel.view(B, time_steps, -1)) # 3 * 3 + obs.append(diff_local_root_ang_vel.view(B, time_steps, -1)) # 3 * 3 + obs.append(local_ref_body_pos.view(B, time_steps, -1)) # 24 * timestep * 3 + obs.append(local_ref_body_rot.view(B, time_steps, -1)) # 24 * timestep * 6 + + obs = torch.cat(obs, dim=-1).view(B, -1) + return obs + + +@torch.jit.script +def compute_imitation_reward(root_pos, root_rot, body_pos, body_rot, body_vel, body_ang_vel, ref_body_pos, ref_body_rot, ref_body_vel, ref_body_ang_vel, rwd_specs): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,Tensor, Tensor, Dict[str, float]) -> Tuple[Tensor, Tensor] + k_pos, k_rot, k_vel, k_ang_vel = rwd_specs["k_pos"], rwd_specs["k_rot"], rwd_specs["k_vel"], rwd_specs["k_ang_vel"] + w_pos, w_rot, w_vel, w_ang_vel = rwd_specs["w_pos"], rwd_specs["w_rot"], rwd_specs["w_vel"], rwd_specs["w_ang_vel"] + + # body position reward + diff_global_body_pos = ref_body_pos - body_pos + diff_body_pos_dist = (diff_global_body_pos**2).mean(dim=-1).mean(dim=-1) + r_body_pos = torch.exp(-k_pos * diff_body_pos_dist) + + # body rotation reward + diff_global_body_rot = torch_utils.quat_mul(ref_body_rot, torch_utils.quat_conjugate(body_rot)) + diff_global_body_angle = torch_utils.quat_to_angle_axis(diff_global_body_rot)[0] + diff_global_body_angle_dist = (diff_global_body_angle**2).mean(dim=-1) + r_body_rot = torch.exp(-k_rot * diff_global_body_angle_dist) + + # body linear velocity reward + diff_global_vel = ref_body_vel - body_vel + diff_global_vel_dist = (diff_global_vel**2).mean(dim=-1).mean(dim=-1) + r_vel = torch.exp(-k_vel * diff_global_vel_dist) + + # body angular velocity reward + diff_global_ang_vel = ref_body_ang_vel - body_ang_vel + diff_global_ang_vel_dist = (diff_global_ang_vel**2).mean(dim=-1).mean(dim=-1) + r_ang_vel = torch.exp(-k_ang_vel * diff_global_ang_vel_dist) + + reward = w_pos * r_body_pos + w_rot * r_body_rot + w_vel * r_vel + w_ang_vel * r_ang_vel + reward_raw = torch.stack([r_body_pos, r_body_rot, r_vel, r_ang_vel], dim=-1) + # import ipdb + # ipdb.set_trace() + return reward, reward_raw + + +@torch.jit.script +def compute_point_goal_reward(prev_dist, curr_dist): + # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor] + reward = torch.clamp(prev_dist - curr_dist, max=1 / 3) * 9 + + return reward, reward + + +@torch.jit.script +def compute_location_reward(root_pos, tar_pos): + # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor] + pos_err_scale = 1.0 + + pos_diff = tar_pos[..., 0:2] - root_pos[..., 0:2] + + pos_err = torch.sum(pos_diff * pos_diff, dim=-1) + pos_reward = torch.exp(-pos_err_scale * pos_err) + + reward = pos_reward + + return reward, reward + + +@torch.jit.script +def compute_humanoid_im_reset(reset_buf, progress_buf, contact_buf, contact_body_ids, rigid_body_pos, ref_body_pos, pass_time, enable_early_termination, termination_distance, disableCollision, use_mean): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, Tensor, bool, bool) -> Tuple[Tensor, Tensor] + terminated = torch.zeros_like(reset_buf) + if (enable_early_termination): + if use_mean: + has_fallen = torch.any(torch.norm(rigid_body_pos - ref_body_pos, dim=-1).mean(dim=-1, keepdim=True) > termination_distance[0], dim=-1) # using average, same as UHC"s termination condition + else: + has_fallen = torch.any(torch.norm(rigid_body_pos - ref_body_pos, dim=-1) > termination_distance, dim=-1) # using max + # first timestep can sometimes still have nonzero contact forces + # so only check after first couple of steps + has_fallen *= (progress_buf > 1) + if disableCollision: + has_fallen[:] = False + terminated = torch.where(has_fallen, torch.ones_like(reset_buf), terminated) + + # if (contact_buf.abs().sum(dim=-1)[0] > 0).sum() > 2: + # np.set_printoptions(precision=4, suppress=1) + # print(contact_buf.numpy(), contact_buf.abs().sum(dim=-1)[0].nonzero().squeeze()) + + # if terminated.sum() > 0: + # import ipdb; ipdb.set_trace() + # print("Fallen") + + reset = torch.where(pass_time, torch.ones_like(reset_buf), terminated) + # import ipdb + # ipdb.set_trace() + + return reset, terminated + + +@torch.jit.script +def compute_location_observations(root_pos, root_rot, target_pos, upright): + # type: (Tensor, Tensor, Tensor, bool) -> Tensor + + if not upright: + root_rot = remove_base_rot(root_rot) + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + + diff_global_body_pos = target_pos - root_pos + diff_local_body_pos_flat = torch_utils.my_quat_rotate(heading_inv_rot, diff_global_body_pos.view(-1, 3)) + max_distance = 7.5 + distances = torch.norm(diff_local_body_pos_flat, dim=-1) + smallers = distances < max_distance # 2.5 seconds, 5 time steps, + diff_locations = torch.zeros((smallers.shape[0], 5, 3)).to(diff_local_body_pos_flat) + diff_locations[smallers] = (diff_local_body_pos_flat[smallers, None] * torch.linspace(0.2, 1, 5)[None, :, None].repeat(smallers.sum(), 1, 1).to(diff_local_body_pos_flat)) # 5 time stpes, 2 seconds + modified_locals = diff_local_body_pos_flat[~smallers] * distances[~smallers, None] / max_distance + diff_locations[~smallers] = modified_locals[:, None] * torch.linspace(0.2, 1, 5)[None, :, None].repeat((~smallers).sum(), 1, 1).to(diff_local_body_pos_flat) + + local_traj_pos = diff_locations[..., 0:2] + + obs = torch.reshape(local_traj_pos, (local_traj_pos.shape[0], -1)) + return obs + + +@torch.jit.script +def compute_humanoid_traj_reset(reset_buf, progress_buf, contact_buf, contact_body_ids, rigid_body_pos, pass_time, enable_early_termination, termination_heights, disableCollision): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor,Tensor, Tensor, float, bool) -> Tuple[Tensor, Tensor] + terminated = torch.zeros_like(reset_buf) + + if (enable_early_termination): + masked_contact_buf = contact_buf.clone() + masked_contact_buf[:, contact_body_ids, :] = 0 + ## torch.sum to disable self-collision. + # force_threshold = 200 + force_threshold = 50 + body_contact_force = torch.sqrt(torch.square(torch.abs(masked_contact_buf.sum(dim=-2))).sum(dim=-1)) > force_threshold + + has_contacted_fall = body_contact_force + has_contacted_fall *= (progress_buf > 1) + + body_height = rigid_body_pos[..., 2] + fall_height = body_height < termination_heights + fall_height[:, contact_body_ids] = False + fall_height = torch.any(fall_height, dim=-1) + + has_failed = torch.logical_and(has_contacted_fall, fall_height) + + if disableCollision: + has_failed[:] = False + + ############################## Debug ############################## + # if torch.sum(has_fallen) > 0: + # import ipdb; ipdb.set_trace() + # print("???") + # mujoco_joint_names = np.array(['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand']) + # print( mujoco_joint_names[masked_contact_buf[0, :, 0].nonzero().cpu().numpy()]) + ############################## Debug ############################## + + terminated = torch.where(has_failed, torch.ones_like(reset_buf), terminated) + + reset = torch.where(pass_time, torch.ones_like(reset_buf), terminated) + + return reset, terminated \ No newline at end of file diff --git a/phc/env/tasks/humanoid_im_demo.py b/phc/env/tasks/humanoid_im_demo.py new file mode 100644 index 0000000..32e0338 --- /dev/null +++ b/phc/env/tasks/humanoid_im_demo.py @@ -0,0 +1,169 @@ + +from typing import OrderedDict +import torch +import numpy as np +from phc.utils.torch_utils import quat_to_tan_norm +import phc.env.tasks.humanoid_im as humanoid_im +from phc.env.tasks.humanoid_amp import HumanoidAMP, remove_base_rot +from phc.utils.motion_lib_smpl import MotionLibSMPL + +from phc.utils import torch_utils + +from isaacgym import gymapi +from isaacgym import gymtorch +from isaacgym.torch_utils import * +from phc.utils.flags import flags +import joblib +import gc +from collections import defaultdict +import aiohttp, cv2, asyncio, json + + +class HumanoidImDemo(humanoid_im.HumanoidIm): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + self.j3d = torch.zeros([1, 24, 3]).to(self.device).float() + self.j3d_vel = torch.zeros([1, 24, 3]).to(self.device).float() + + async def talk(self): + URL = 'http://0.0.0.0:8081/ws' + print("Starting websocket client") + session = aiohttp.ClientSession() + async with session.ws_connect(URL) as ws: + self.ws = ws + await ws.send_str("get_pose") + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if msg.data == 'close cmd': + await ws.close() + break + else: + json_data = json.loads(msg.data) + self.j3d = torch.tensor(json_data["j3d_curr"]).to(self.device).float() + self.j3d_vel = torch.tensor(json_data["j3d_curr_vel"]).to(self.device).float() + + await ws.send_str("get_pose") + + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + def _update_marker(self): + if flags.show_traj: + self._marker_pos[:] = 0 + else: + self._marker_pos[:] = self.ref_body_pos + + # ######### Heading debug ####### + # points = self.init_root_points() + # base_quat = self._rigid_body_rot[0, 0:1] + # base_quat = remove_base_rot(base_quat) + # heading_rot = torch_utils.calc_heading_quat(base_quat) + # show_points = quat_apply(heading_rot.repeat(1, points.shape[0]).reshape(-1, 4), points) + (self._rigid_body_pos[0, 0:1]).unsqueeze(1) + # self._marker_pos[:] = show_points[:, :self._marker_pos.shape[1]] + # ######### Heading debug ####### + + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), gymtorch.unwrap_tensor(self._marker_actor_ids), len(self._marker_actor_ids)) + + return + + def _reset_ref_state_init(self, env_ids): + num_envs = env_ids.shape[0] + motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel = self._sample_ref_state(env_ids) + + from scipy.spatial.transform import Rotation as sRot + random_heading_quat = torch.from_numpy(sRot.from_euler("xyz", [0, 0, np.pi]).as_quat())[None,].float().to(self.device) + random_heading_quat_repeat = random_heading_quat[:, None].repeat(1, 24, 1) + root_rot = quat_mul(random_heading_quat, root_rot).clone() + rb_pos = quat_apply(random_heading_quat_repeat, rb_pos - root_pos[:, None, :]).clone() + rb_rot = quat_mul(random_heading_quat_repeat, rb_rot).clone() + root_ang_vel = quat_apply(random_heading_quat, root_ang_vel).clone() + rb_pos = rb_pos + (self.j3d[0, 0:1, :] - root_pos) + root_pos = self.j3d[0, 0:1, :] + root_pos[..., 2] = 0.93 + + self._set_env_state(env_ids=env_ids, root_pos=root_pos, root_rot=root_rot, dof_pos=dof_pos, root_vel=root_vel, root_ang_vel=root_ang_vel, dof_vel=dof_vel, rigid_body_pos=rb_pos, rigid_body_rot=rb_rot, rigid_body_vel=body_vel, rigid_body_ang_vel=body_ang_vel) + + self._reset_ref_env_ids = env_ids + self._reset_ref_motion_ids = motion_ids + self._reset_ref_motion_times = motion_times + self._motion_start_times[env_ids] = motion_times + self._sampled_motion_ids[env_ids] = motion_ids + if flags.follow: + self.start = True ## Updating camera when reset + return + + def _compute_observations(self, env_ids=None): + # env_ids is used for resetting + if env_ids is None: + env_ids = torch.arange(self.num_envs).to(self.device) + + self_obs = self._compute_humanoid_obs(env_ids) + self.self_obs_buf[env_ids] = self_obs + + if (self._enable_task_obs): + task_obs = self._compute_task_obs_demo(env_ids) + obs = torch.cat([self_obs, task_obs], dim=-1) + else: + obs = self_obs + + if self.obs_v == 4: + # Double sub will return a copy. + B, N = obs.shape + sums = self.obs_buf[env_ids, 0:10].abs().sum(dim=1) + zeros = sums == 0 + nonzero = ~zeros + obs_slice = self.obs_buf[env_ids] + obs_slice[zeros] = torch.tile(obs[zeros], (1, 5)) + obs_slice[nonzero] = torch.cat([obs_slice[nonzero, N:], obs[nonzero]], dim=-1) + self.obs_buf[env_ids] = obs_slice + else: + self.obs_buf[env_ids] = obs + return obs + + def _compute_task_obs_demo(self, env_ids=None): + if (env_ids is None): + body_pos = self._rigid_body_pos + body_rot = self._rigid_body_rot + body_vel = self._rigid_body_vel + body_ang_vel = self._rigid_body_ang_vel + env_ids = torch.arange(self.num_envs, dtype=torch.long, device=self.device) + else: + body_pos = self._rigid_body_pos[env_ids] + body_rot = self._rigid_body_rot[env_ids] + body_vel = self._rigid_body_vel[env_ids] + body_ang_vel = self._rigid_body_ang_vel[env_ids] + + root_pos = body_pos[..., 0, :] + root_rot = body_rot[..., 0, :] + + body_pos_subset = body_pos[..., self._track_bodies_id, :] + body_vel_subset = body_vel[..., self._track_bodies_id, :] + + # ref_rb_pos = self.j3d[((self.progress_buf[env_ids] + 1) / 2).long() % self.j3d.shape[0]] + # ref_body_vel = self.j3d_vel[((self.progress_buf[env_ids] + 1) / 2).long() % self.j3d_vel.shape[0]] + ref_rb_pos = self.j3d + ref_body_vel = self.j3d_vel + time_steps = 1 + + ref_rb_pos_subset = ref_rb_pos[..., self._track_bodies_id, :] + ref_body_vel_subset = ref_body_vel[..., self._track_bodies_id, :] + + if self.zero_out_far: + close_distance = 0.25 + distance = torch.norm(root_pos - ref_rb_pos_subset[..., 0, :], dim=-1) + + zeros_subset = distance > close_distance + ref_rb_pos_subset[zeros_subset, 1:] = body_pos_subset[zeros_subset, 1:] + ref_body_vel_subset[zeros_subset, :] = body_vel_subset[zeros_subset, :] + + obs = humanoid_im.compute_imitation_observations_v7(root_pos, root_rot, body_pos_subset, body_vel_subset, ref_rb_pos_subset, ref_body_vel_subset, time_steps, self._has_upright_start) + + if len(env_ids) == self.num_envs: + self.ref_body_pos = ref_rb_pos + self.ref_body_pos_subset = ref_rb_pos_subset + self.ref_pose_aa = None + + return obs diff --git a/phc/env/tasks/humanoid_im_distill.py b/phc/env/tasks/humanoid_im_distill.py new file mode 100644 index 0000000..573f1e0 --- /dev/null +++ b/phc/env/tasks/humanoid_im_distill.py @@ -0,0 +1,233 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import time +import torch +import phc.env.tasks.humanoid_im as humanoid_im +from phc.env.tasks.humanoid_amp import remove_base_rot +from phc.utils import torch_utils +from typing import OrderedDict + +from isaacgym.torch_utils import * +from phc.utils.flags import flags +from rl_games.algos_torch import torch_ext +import torch.nn as nn +from phc.learning.pnn import PNN +from collections import deque +from phc.utils.torch_utils import project_to_norm +from phc.learning.network_loader import load_z_encoder, load_z_decoder, load_pnn, load_mcp_mlp + +class HumanoidImDistill(humanoid_im.HumanoidIm): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + + + # if True: + if self.distill and not flags.test: + check_points = [torch_ext.load_checkpoint(ck_path) for ck_path in self.models_path] + self.distill_z_model = self.cfg['env'].get("distill_z_model", False) + self.distill_model_config = self.cfg['env']['distill_model_config'] + self.fut_tracks_distill = self.distill_model_config.get("fut_tracks", False) + self.num_traj_samples_distill = self.distill_model_config.get("numTrajSamples", -1) + self.traj_sample_timestep_distill = self.distill_model_config.get("trajSampleTimestepInv", -1) + self.fut_tracks_dropout_distill = self.distill_model_config.get('fut_tracks_dropout', False) + self.z_activation = self.distill_model_config['z_activation'] + self.root_height_obs_distill = self.distill_model_config.get('root_height_obs', True) + ### Loading Distill Model ### + + if self.distill_z_model: + self.embedding_size_distill = self.distill_model_config['embedding_size'] + self.embedding_norm_distill = self.distill_model_config['embedding_norm'] + self.z_all_distill = self.distill_model_config.get('z_all', False) + self.distill_z_type = self.distill_model_config.get("z_type", "sphere") + self.use_vae_prior_loss = self.cfg['env'].get("use_vae_prior_loss", False) + self.use_vae_prior = self.cfg['env'].get("use_vae_prior", False) + self.decoder = load_z_decoder(check_points[0], activation = self.z_activation, z_type = self.distill_z_type, device = self.device) + self.encoder = load_z_encoder(check_points[0], activation = self.z_activation, z_type = self.distill_z_type, device = self.device) + else: + self.has_pnn_distill = self.distill_model_config.get("has_pnn", False) + self.has_lateral_distill = self.distill_model_config.get("has_lateral", False) + self.num_prim_distill = self.distill_model_config.get("num_prim", 3) + self.discrete_moe_distill = self.distill_model_config.get("discrete_moe", False) + if self.has_pnn_distill: + assert (len(self.models_path) == 2) + self.pnn = load_pnn(check_points[0], num_prim = self.num_prim_distill, has_lateral = self.has_lateral_distill, activation = self.z_activation, device = self.device) + self.running_mean, self.running_var = check_points[0]['running_mean_std']['running_mean'], check_points[0]['running_mean_std']['running_var'] + self.composer = load_mcp_mlp(check_points[1], activation = self.z_activation, device = self.device, mlp_name = "composer") + else: + self.encoder = load_mcp_mlp(check_points[0], activation = self.z_activation, device = self.device) + # else: + # self.actors = [self.load_moe_actor(ck) for ck in check_points] + # composer_cp = torch_ext.load_checkpoint("output/klab/smpl_im_comp_10/Humanoid_00282500.pth") + # self.composer = self.load_moe_composer(composer_cp) + + self.running_mean, self.running_var = check_points[-1]['running_mean_std']['running_mean'], check_points[-1]['running_mean_std']['running_var'] + + + if self.save_kin_info: + self.kin_dict = OrderedDict() + self.kin_dict.update({ + "gt_action": torch.zeros([self.num_envs, self._num_actions]), + "progress_buf": self.progress_buf.clone(), + }) # current root pos + root for future aggergration + return + + def _setup_character_props(self, key_bodies): + super()._setup_character_props(key_bodies) + return + + def load_pnn(self, pnn_ck): + mlp_args = {'input_size': pnn_ck['model']['a2c_network.pnn.actors.0.0.weight'].shape[1], 'units': pnn_ck['model']['a2c_network.pnn.actors.0.2.weight'].shape[::-1], 'activation': "relu", 'dense_func': torch.nn.Linear} + pnn = PNN(mlp_args, output_size=69, numCols=self.num_prim_distill, has_lateral=self.has_lateral_distill) + state_dict = pnn.state_dict() + for k in pnn_ck['model'].keys(): + if "pnn" in k: + pnn_dict_key = k.split("pnn.")[1] + state_dict[pnn_dict_key].copy_(pnn_ck['model'][k]) + pnn.freeze_pnn(self.num_prim_distill) + pnn.to(self.device) + return pnn + + def load_moe_actor(self, checkpoint): + actvation_func = torch_utils.activation_facotry(self.z_activation) + key_name = "a2c_network.actor_mlp" + + loading_keys = [k for k in checkpoint['model'].keys() if k.startswith(key_name)] + ["a2c_network.mu.weight", 'a2c_network.mu.bias'] + loading_keys_linear = [k for k in loading_keys if k.endswith('weight')] + + nn_modules = [] + for idx, key in enumerate(loading_keys_linear): + layer = nn.Linear(*checkpoint['model'][key].shape[::-1]) + nn_modules.append(layer) + if idx < len(loading_keys_linear) - 1: + nn_modules.append(actvation_func()) + actor = nn.Sequential(*nn_modules) + + state_dict = actor.state_dict() + + for idx, key_affix in enumerate(state_dict.keys()): + state_dict[key_affix].copy_(checkpoint['model'][loading_keys[idx]]) + + for param in actor.parameters(): + param.requires_grad = False + actor.to(self.device) + return actor + + + def load_moe_composer(self, checkpoint): + actvation_func = torch_utils.activation_facotry(self.z_activation) + composer = nn.Sequential(nn.Linear(*checkpoint['model']['a2c_network.composer.0.weight'].shape[::-1]), actvation_func(), + nn.Linear(*checkpoint['model']['a2c_network.composer.2.weight'].shape[::-1]), actvation_func(), + nn.Linear(*checkpoint['model']['a2c_network.composer.4.weight'].shape[::-1]), + actvation_func()) ###### This final activation function.............. if silu, does not make any sense. + + state_dict = composer.state_dict() + state_dict['0.weight'].copy_(checkpoint['model']['a2c_network.composer.0.weight']) + state_dict['0.bias'].copy_(checkpoint['model']['a2c_network.composer.0.bias']) + state_dict['2.weight'].copy_(checkpoint['model']['a2c_network.composer.2.weight']) + state_dict['2.bias'].copy_(checkpoint['model']['a2c_network.composer.2.bias']) + state_dict['4.weight'].copy_(checkpoint['model']['a2c_network.composer.4.weight']) + state_dict['4.bias'].copy_(checkpoint['model']['a2c_network.composer.4.bias']) + + for param in composer.parameters(): + param.requires_grad = False + composer.to(self.device) + return composer + + + def step(self, actions): + + + # if self.dr_randomizations.get('actions', None): + # actions = self.dr_randomizations['actions']['noise_lambda'](actions) + # if flags.server_mode: + # t_s = time.time() + # t_s = time.time() + # if True: + if not flags.test and self.save_kin_info: + with torch.no_grad(): + # Apply trained Model. + + ################ GT-Action ################ + temp_tracks = self._track_bodies_id + self._track_bodies_id = self._full_track_bodies_id + temp_fut, temp_fut_drop, temp_timestep, temp_num_steps, temp_root_height_obs = self._fut_tracks, self._fut_tracks_dropout, self._traj_sample_timestep, self._num_traj_samples, self._root_height_obs + self._fut_tracks, self._fut_tracks_dropout, self._traj_sample_timestep, self._num_traj_samples, self._root_height_obs = self.fut_tracks_distill, self.fut_tracks_dropout_distill, 1/self.traj_sample_timestep_distill, self.num_traj_samples_distill, self.root_height_obs_distill + + if self.root_height_obs_distill != temp_root_height_obs: + self_obs = self.obs_buf[:, :self.get_self_obs_size()] + self_obs = torch.cat([self._rigid_body_pos[:, 0, 2:3], self_obs], dim = -1) + # self_obs = self._compute_humanoid_obs() # torch.cat([self._rigid_body_pos[:, 0, 2:3], self_obs], dim = -1) - self._compute_humanoid_obs() + self_obs_size = self_obs.shape[-1] + self_obs = ((self_obs - self.running_mean.float()[:self_obs_size]) / torch.sqrt(self.running_var.float()[:self_obs_size] + 1e-05)) + else: + self_obs_size = self.get_self_obs_size() + self_obs = ((self.obs_buf[:, :self_obs_size] - self.running_mean.float()[:self_obs_size]) / torch.sqrt(self.running_var.float()[:self_obs_size] + 1e-05)) + + if temp_fut == self.fut_tracks_distill and temp_fut_drop == self.fut_tracks_dropout_distill and temp_timestep == 1/self.traj_sample_timestep_distill and temp_num_steps == self.num_traj_samples_distill\ + and temp_root_height_obs == self.root_height_obs_distill: + task_obs = self.obs_buf[:, self.get_self_obs_size():] + else: + task_obs = self._compute_task_obs(save_buffer = False) + + self._track_bodies_id = temp_tracks + self._fut_tracks, self._fut_tracks_dropout, self._traj_sample_timestep, self._num_traj_samples, self._root_height_obs = temp_fut, temp_fut_drop, temp_timestep, temp_num_steps, temp_root_height_obs + + + task_obs = ((task_obs - self.running_mean.float()[self_obs_size:]) / torch.sqrt(self.running_var.float()[self_obs_size:] + 1e-05)) + full_obs = torch.cat([self_obs, task_obs], dim = -1) + full_obs = torch.clamp(full_obs, min=-5.0, max=5.0) + + if self.distill_z_model: + gt_z = self.encoder.encoder(full_obs) + gt_z = project_to_norm(gt_z, self.embedding_norm_distill) + if self.z_all_distill: + gt_action = self.decoder.decoder(gt_z) + else: + gt_action = self.decoder.decoder(torch.cat([self_obs, gt_z], dim = -1)) + else: + if self.has_pnn_distill: + _, pnn_actions = self.pnn(full_obs) + x_all = torch.stack(pnn_actions, dim=1) + weights = self.composer(full_obs) + gt_action = torch.sum(weights[:, :, None] * x_all, dim=1) + else: + gt_action = self.encoder(full_obs) + # x_all = torch.stack([net(full_obs) for net in self.actors], dim=1) + + if self.save_kin_info: + self.kin_dict['gt_action'] = gt_action.squeeze() + self.kin_dict['progress_buf'] = self.progress_buf.clone() + + ################ GT-Action ################ + # actions = gt_action; print("using gt action") # Debugging + + # apply actions + self.pre_physics_step(actions) + + # step physics and render each frame + self._physics_step() + + # to fix! + if self.device == 'cpu': + self.gym.fetch_results(self.sim, True) + + # compute observations, rewards, resets, ... + self.post_physics_step() + if flags.server_mode: + dt = time.time() - t_s + print(f'\r {1/dt:.2f} fps', end='') + + # dt = time.time() - t_s + # self.fps.append(1/dt) + # print(f'\r {np.mean(self.fps):.2f} fps', end='') + + + if self.dr_randomizations.get('observations', None): + self.obs_buf = self.dr_randomizations['observations']['noise_lambda'](self.obs_buf) + diff --git a/phc/env/tasks/humanoid_im_distill_getup.py b/phc/env/tasks/humanoid_im_distill_getup.py new file mode 100644 index 0000000..8fe0c22 --- /dev/null +++ b/phc/env/tasks/humanoid_im_distill_getup.py @@ -0,0 +1,36 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from typing import OrderedDict +import torch +import numpy as np +from phc.utils.torch_utils import quat_to_tan_norm +import phc.env.tasks.humanoid_im_getup as humanoid_im_getup +import phc.env.tasks.humanoid_im_distill as humanoid_im_distill +from phc.env.tasks.humanoid_amp import HumanoidAMP, remove_base_rot +from phc.utils.motion_lib_smpl import MotionLibSMPL + +from phc.utils import torch_utils + +from isaacgym import gymapi +from isaacgym import gymtorch +from isaacgym.torch_utils import * +from phc.utils.flags import flags +import joblib +import gc +from poselib.poselib.skeleton.skeleton3d import SkeletonMotion, SkeletonState +from rl_games.algos_torch import torch_ext +import torch.nn as nn +from phc.learning.network_loader import load_mcp_mlp, load_pnn +from collections import deque + +class HumanoidImDistillGetup(humanoid_im_getup.HumanoidImGetup, humanoid_im_distill.HumanoidImDistill): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + return + diff --git a/phc/env/tasks/humanoid_im_getup.py b/phc/env/tasks/humanoid_im_getup.py new file mode 100644 index 0000000..a6084bc --- /dev/null +++ b/phc/env/tasks/humanoid_im_getup.py @@ -0,0 +1,210 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import torch + +from isaacgym import gymapi +from isaacgym import gymtorch + +from phc.env.util import gym_util +from phc.env.tasks.humanoid_im import HumanoidIm +from isaacgym.torch_utils import * + +from utils import torch_utils +from phc.utils.flags import flags + + +class HumanoidImGetup(HumanoidIm): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + + self._recovery_episode_prob_tgt = self._recovery_episode_prob = cfg["env"]["recoveryEpisodeProb"] + self._recovery_steps_tgt = self._recovery_steps = cfg["env"]["recoverySteps"] + self._fall_init_prob_tgt = self._fall_init_prob = cfg["env"]["fallInitProb"] + if flags.server_mode: + self._recovery_episode_prob_tgt = self._recovery_episode_prob = 1 + self._fall_init_prob_tgt = self._fall_init_prob = 0 + + self._reset_fall_env_ids = [] + + self.availalbe_fall_states = torch.zeros(cfg["env"]['num_envs']).long().to(device_id) + self.fall_id_assignments = torch.zeros(cfg["env"]['num_envs']).long().to(device_id) + self.getup_udpate_epoch = cfg['env'].get("getup_udpate_epoch", 10000) + + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + + self._recovery_counter = torch.zeros(self.num_envs, device=self.device, dtype=torch.int) + + self._generate_fall_states() + + return + + def update_getup_schedule(self, epoch_num, getup_udpate_epoch=5000): + ## Need to add aneal + if epoch_num > getup_udpate_epoch: + self._recovery_episode_prob = self._recovery_episode_prob_tgt + self._fall_init_prob = self._fall_init_prob_tgt + else: + self._recovery_episode_prob = 0 + self._fall_init_prob = 1 + + def pre_physics_step(self, actions): + super().pre_physics_step(actions) + self._update_recovery_count() + + return + + def _generate_fall_states(self): + print("#################### Generating Fall State ####################") + max_steps = 150 + # max_steps = 50000 + + env_ids = to_torch(np.arange(self.num_envs), device=self.device, dtype=torch.long) + root_states = self._initial_humanoid_root_states[env_ids].clone() + + root_states[..., 3:7] = torch.randn_like(root_states[..., 3:7]) ## Random root rotation + root_states[..., 3:7] = torch.nn.functional.normalize(root_states[..., 3:7], dim=-1) + self._humanoid_root_states[env_ids] = root_states + + env_ids_int32 = self._humanoid_actor_ids[env_ids] + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + # _dof_state: from the currently simulated states + self.gym.set_dof_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(torch.zeros_like(self._dof_state)), gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + + rand_actions = np.random.uniform(-0.5, 0.5, size=[self.num_envs, self.get_dof_action_size()]) + rand_actions = to_torch(rand_actions, device=self.device) + self.pre_physics_step(rand_actions) + + # step physics and render each frame + for i in range(max_steps): + self.render() + self.gym.simulate(self.sim) + + self._refresh_sim_tensors() + + self._fall_root_states = self._humanoid_root_states.clone() + self._fall_root_states[:, 7:13] = 0 + + # if flags.im_eval: + # print("im eval fall state!!!!!") + # self._fall_root_states[:, :2] = 0 + # if self.zero_out_far and self.zero_out_far_train: + # self._fall_root_states[:, 1] = -3 + + self._fall_dof_pos = self._dof_pos.clone() + self._fall_dof_vel = torch.zeros_like(self._dof_vel, device=self.device, dtype=torch.float) + + self.availalbe_fall_states[:] = 0 + self.fall_id_assignments[:] = 0 + + return + + def resample_motions(self): + super().resample_motions() + if not flags.test: + self._generate_fall_states() + self.reset() # Reset here should not cause the model to have collopsing episode lengths + + return + + def _reset_actors(self, env_ids): + self.availalbe_fall_states[self.fall_id_assignments[env_ids]] = 0 # Clear out the assignment counters for these ones + num_envs = env_ids.shape[0] # For these enviorments + recovery_probs = to_torch(np.array([self._recovery_episode_prob] * num_envs), device=self.device) + recovery_mask = torch.bernoulli(recovery_probs) == 1.0 + terminated_mask = (self._terminate_buf[env_ids] == 1) # If the env is terminated + recovery_mask = torch.logical_and(recovery_mask, terminated_mask) # for those env that have failed, with prob, turns them into recovery envs (this is harnessing episodes that natraully creates fall states) + + # reset the recovery counter for these envs. These env has the 150 steps for recovery from fall + recovery_ids = env_ids[recovery_mask] + if (len(recovery_ids) > 0): + self._reset_recovery_episode(recovery_ids) # These are bonus recovery episodes + + # For the rest of the envs (terminated and not set to recovery), with probability self._fall_init_prob, make them to fall state + nonrecovery_ids = env_ids[torch.logical_not(recovery_mask)] + fall_probs = to_torch(np.array([self._fall_init_prob] * nonrecovery_ids.shape[0]), device=self.device) + fall_mask = torch.bernoulli(fall_probs) == 1.0 + fall_ids = nonrecovery_ids[fall_mask] + if (len(fall_ids) > 0): + self._reset_fall_episode(fall_ids) # these automatically have recovery counter set to 60 + + # These envs, are the normal ones with ref state init. + nonfall_ids = nonrecovery_ids[torch.logical_not(fall_mask)] + if (len(nonfall_ids) > 0): + super()._reset_actors(nonfall_ids) + self._recovery_counter[nonfall_ids] = 0 + + return + + def _reset_recovery_episode(self, env_ids): + self._recovery_counter[env_ids] = self._recovery_steps + return + + def _reset_fall_episode(self, env_ids): + # fall_state_ids = torch.randperm(self._fall_root_states.shape[0])[:env_ids.shape[0]] + self.availalbe_fall_states[self.fall_id_assignments[env_ids]] = 0 # ZL: Clear out the assignment counters for these ones. Clean out self. Should not need to do this? + available_fall_ids = (self.availalbe_fall_states == 0).nonzero() + assert (available_fall_ids.shape[0] >= env_ids.shape[0]) + fall_state_ids = available_fall_ids[torch.randperm(available_fall_ids.shape[0])][:env_ids.shape[0]].squeeze(-1) + self._humanoid_root_states[env_ids] = self._fall_root_states[fall_state_ids] + self._dof_pos[env_ids] = self._fall_dof_pos[fall_state_ids] + self._dof_vel[env_ids] = self._fall_dof_vel[fall_state_ids] + self._recovery_counter[env_ids] = self._recovery_steps + self._reset_fall_env_ids = env_ids + + self.availalbe_fall_states[fall_state_ids] = 1 + self.fall_id_assignments[env_ids] = fall_state_ids + return + + def _reset_envs(self, env_ids): + self._reset_fall_env_ids = [] + super()._reset_envs(env_ids) + + return + + def _init_amp_obs(self, env_ids): + super()._init_amp_obs(env_ids) + + if (len(self._reset_fall_env_ids) > 0): + self._init_amp_obs_default(self._reset_fall_env_ids) + + return + + def _update_recovery_count(self): + self._recovery_counter -= 1 + self._recovery_counter = torch.clamp_min(self._recovery_counter, 0) + return + + def _compute_reset(self): + super()._compute_reset() + + is_recovery = self._recovery_counter > 0 + self.reset_buf[is_recovery] = 0 + self._terminate_buf[is_recovery] = 0 + self.progress_buf[is_recovery] -= 1 # ZL: do not advance progress buffer for these. + return diff --git a/phc/env/tasks/humanoid_im_mcp.py b/phc/env/tasks/humanoid_im_mcp.py new file mode 100644 index 0000000..268108f --- /dev/null +++ b/phc/env/tasks/humanoid_im_mcp.py @@ -0,0 +1,92 @@ +import time +import torch +import phc.env.tasks.humanoid_im as humanoid_im + +from isaacgym.torch_utils import * +from phc.utils.flags import flags +from rl_games.algos_torch import torch_ext +import torch.nn as nn +from phc.learning.pnn import PNN +from collections import deque +from phc.learning.network_loader import load_mcp_mlp, load_pnn + +class HumanoidImMCP(humanoid_im.HumanoidIm): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + self.num_prim = cfg["env"].get("num_prim", 3) + self.discrete_mcp = cfg["env"].get("discrete_moe", False) + self.has_pnn = cfg["env"].get("has_pnn", False) + self.has_lateral = cfg["env"].get("has_lateral", False) + self.z_activation = cfg["env"].get("z_activation", "relu") + + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + + if self.has_pnn: + assert (len(self.models_path) == 1) + pnn_ck = torch_ext.load_checkpoint(self.models_path[0]) + self.pnn = load_pnn(pnn_ck, num_prim = self.num_prim, has_lateral = self.has_lateral, activation = self.z_activation, device = self.device) + self.running_mean, self.running_var = pnn_ck['running_mean_std']['running_mean'], pnn_ck['running_mean_std']['running_var'] + + self.fps = deque(maxlen=90) + + return + + def _setup_character_props(self, key_bodies): + super()._setup_character_props(key_bodies) + self._num_actions = self.num_prim + return + + def get_task_obs_size_detail(self): + task_obs_detail = super().get_task_obs_size_detail() + task_obs_detail['num_prim'] = self.num_prim + return task_obs_detail + + def step(self, weights): + + # if self.dr_randomizations.get('actions', None): + # actions = self.dr_randomizations['actions']['noise_lambda'](actions) + # if flags.server_mode: + # t_s = time.time() + + with torch.no_grad(): + # Apply trained Model. + curr_obs = ((self.obs_buf - self.running_mean.float()) / torch.sqrt(self.running_var.float() + 1e-05)) + + curr_obs = torch.clamp(curr_obs, min=-5.0, max=5.0) + if self.discrete_mcp: + max_idx = torch.argmax(weights, dim=1) + weights = torch.nn.functional.one_hot(max_idx, num_classes=self.num_prim).float() + + if self.has_pnn: + _, actions = self.pnn(curr_obs) + + x_all = torch.stack(actions, dim=1) + else: + x_all = torch.stack([net(curr_obs) for net in self.actors], dim=1) + # print(weights) + actions = torch.sum(weights[:, :, None] * x_all, dim=1) + + # actions = x_all[:, 3] # Debugging + # apply actions + self.pre_physics_step(actions) + + # step physics and render each frame + self._physics_step() + + # to fix! + if self.device == 'cpu': + self.gym.fetch_results(self.sim, True) + + # compute observations, rewards, resets, ... + self.post_physics_step() + # if flags.server_mode: + # dt = time.time() - t_s + # print(f'\r {1/dt:.2f} fps', end='') + + # dt = time.time() - t_s + # self.fps.append(1/dt) + # print(f'\r {np.mean(self.fps):.2f} fps', end='') + + + if self.dr_randomizations.get('observations', None): + self.obs_buf = self.dr_randomizations['observations']['noise_lambda'](self.obs_buf) diff --git a/phc/env/tasks/humanoid_im_mcp_demo.py b/phc/env/tasks/humanoid_im_mcp_demo.py new file mode 100644 index 0000000..099c5af --- /dev/null +++ b/phc/env/tasks/humanoid_im_mcp_demo.py @@ -0,0 +1,322 @@ + +import os +import torch +import numpy as np +from phc.utils.torch_utils import quat_to_tan_norm +import phc.env.tasks.humanoid_im_mcp as humanoid_im_mcp +import phc.env.tasks.humanoid_im as humanoid_im +from phc.env.tasks.humanoid_amp import HumanoidAMP, remove_base_rot +from phc.utils.motion_lib_smpl import MotionLibSMPL + +from phc.utils import torch_utils + +from isaacgym import gymapi +from isaacgym import gymtorch +from isaacgym.torch_utils import * +from phc.utils.flags import flags +import joblib +import gc +from collections import defaultdict +from scipy.spatial.transform import Rotation as sRot +import phc.utils.pytorch3d_transforms as ptr +from poselib.poselib.skeleton.skeleton3d import SkeletonMotion, SkeletonState + +import aiohttp, cv2, asyncio, json +import requests +from collections import deque +import scipy.ndimage.filters as filters +from smpl_sim.utils.transform_utils import quat_correct_two_batch +import subprocess + +SERVER = "0.0.0.0" +smpl_2_mujoco = [0, 1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12, 15, 13, 16, 18, 20, 22, 14, 17, 19, 21, 23] + + +class HumanoidImMCPDemo(humanoid_im_mcp.HumanoidImMCP): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + + ## Debugging + # self.res_data = joblib.load("/home/zhengyiluo5/dev/meta/HybrIK/ik_res.pkl") + # self.rot_mat_ref = torch.from_numpy(sRot.from_rotvec(np.array(self.res_data['pose_aa']).reshape(-1, 3)).as_matrix().reshape(-1, 24, 3, 3)).float().to(self.device) + ## Debugging + + self.local_translation_batch = self.skeleton_trees[0].local_translation[None,] + self.parent_indices = self.skeleton_trees[0].parent_indices + self.pose_mat = torch.eye(3).repeat(self.num_envs, 24, 1, 1).to(self.device) + self.trans = torch.zeros(self.num_envs, 3).to(self.device) + + self.prev_ref_body_pos = torch.zeros(self.num_envs, 24, 3).to(self.device) + self.prev_ref_body_rot = torch.zeros(self.num_envs, 24, 4).to(self.device) + + self.zero_trans = torch.zeros([self.num_envs, 3]) + self.s_dt = 1 / 30 + + self.to_isaac_mat = torch.from_numpy(sRot.from_euler('xyz', np.array([-np.pi / 2, 0, 0]), degrees=False).as_matrix()).float() + self.to_global = torch.from_numpy(sRot.from_quat([0.5, 0.5, 0.5, 0.5]).inv().as_matrix()).float() + + self.root_pos_acc = deque(maxlen=30) + self.body_rot_acc = deque(maxlen=30) + self.body_pos_acc = deque(maxlen=30) + + flags.no_collision_check = True + flags.show_traj = True + self.close_distance = 0.5 + self.mean_limb_lengths = np.array([0.1061, 0.3624, 0.4015, 0.1384, 0.1132], dtype=np.float32)[None, :] + + async def talk(self): + URL = f'http://{SERVER}:8080/ws' + print("Starting websocket client") + session = aiohttp.ClientSession() + async with session.ws_connect(URL) as ws: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if msg.data == 'close cmd': + await ws.close() + break + else: + print(msg.data) + try: + msg = json.loads(msg.data) + if msg['action'] == 'reset': + self.reset() + elif msg['action'] == 'start_record': + subprocess.Popen(["simplescreenrecorder", "--start-recording"]) + print("start recording!!!!") + # self.recording = True + elif msg['action'] == 'end_record': + print("end_recording!!!!") + if not self.recording: + print("Not recording") + else: + self.recording = False + self.recording_state_change = True + elif msg['action'] == 'set_env': + query = msg['query'] + env_id = query['env'] + self.viewing_env_idx = int(env_id) + print("view env idx: ", self.viewing_env_idx) + except: + import ipdb + ipdb.set_trace() + print("error parsing server message") + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + def _update_marker(self): + if flags.show_traj: + self._marker_pos[:] = self.ref_body_pos + else: + self._marker_pos[:] = 0 + + # ######### Heading debug ####### + # points = self.init_root_points() + # base_quat = self._rigid_body_rot[0, 0:1] + # base_quat = remove_base_rot(base_quat) + # heading_rot = torch_utils.calc_heading_quat(base_quat) + # show_points = quat_apply(heading_rot.repeat(1, points.shape[0]).reshape(-1, 4), points) + (self._rigid_body_pos[0, 0:1]).unsqueeze(1) + # self._marker_pos[:] = show_points[:, :self._marker_pos.shape[1]] + # ######### Heading debug ####### + + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), gymtorch.unwrap_tensor(self._marker_actor_ids), len(self._marker_actor_ids)) + + return + + def _compute_observations(self, env_ids=None): + # env_ids is used for resetting + if env_ids is None: + env_ids = torch.arange(self.num_envs).to(self.device) + + self_obs = self._compute_humanoid_obs(env_ids) + self.self_obs_buf[env_ids] = self_obs + + if (self._enable_task_obs): + task_obs = self._compute_task_obs_demo(env_ids) + obs = torch.cat([self_obs, task_obs], dim=-1) + else: + obs = self_obs + + if self.obs_v == 4: + # Double sub will return a copy. + B, N = obs.shape + sums = self.obs_buf[env_ids, 0:10].abs().sum(dim=1) + zeros = sums == 0 + nonzero = ~zeros + obs_slice = self.obs_buf[env_ids] + obs_slice[zeros] = torch.tile(obs[zeros], (1, 5)) + obs_slice[nonzero] = torch.cat([obs_slice[nonzero, N:], obs[nonzero]], dim=-1) + self.obs_buf[env_ids] = obs_slice + else: + self.obs_buf[env_ids] = obs + return obs + + def _compute_task_obs_demo(self, env_ids=None): + if (env_ids is None): + body_pos = self._rigid_body_pos + body_rot = self._rigid_body_rot + body_vel = self._rigid_body_vel + body_ang_vel = self._rigid_body_ang_vel + env_ids = torch.arange(self.num_envs, dtype=torch.long, device=self.device) + else: + body_pos = self._rigid_body_pos[env_ids] + body_rot = self._rigid_body_rot[env_ids] + body_vel = self._rigid_body_vel[env_ids] + body_ang_vel = self._rigid_body_ang_vel[env_ids] + + root_pos = body_pos[..., 0, :] + root_rot = body_rot[..., 0, :] + + body_pos_subset = body_pos[..., self._track_bodies_id, :] + body_rot_subset = body_rot[..., self._track_bodies_id, :] + body_vel_subset = body_vel[..., self._track_bodies_id, :] + body_ang_vel_subset = body_ang_vel[..., self._track_bodies_id, :] + + if self.obs_v == 6: + raise NotImplementedError + # This part is not as good. use obs_v == 7 instead. + # ref_rb_pos = self.j3d[((self.progress_buf[env_ids] + 1) / 2).long() % self.j3d.shape[0]] + # ref_body_vel = self.j3d_vel[((self.progress_buf[env_ids] + 1) / 2).long() % self.j3d_vel.shape[0]] + # pose_mat = self.pose_mat.clone() + # trans = self.trans.clone() + + # pose_mat = self.rot_mat_ref[((self.progress_buf[env_ids] + 1) / 2).long() % self.rot_mat_ref.shape[0]] # debugging + pose_res = requests.get(f'http://{SERVER}:8080/get_pose') + json_data = pose_res.json() + pose_mat = torch.tensor(json_data["pose_mat"])[None,].float() + # trans = torch.tensor(json_data["trans"]).to(self.device).float() + + trans = np.array(json_data["trans"]).squeeze() + s_dt = json_data['dt'] + self.root_pos_acc.append(trans) + filtered_trans = filters.gaussian_filter1d(self.root_pos_acc, 3, axis=0, mode="mirror") + trans = torch.tensor(filtered_trans[-1]).float() + + new_root = self.to_isaac_mat.matmul(pose_mat[:, 0]) + pose_mat[:, 0] = new_root + trans = trans.matmul(self.to_isaac_mat.T) + _, global_rotation = humanoid_kin.forward_kinematics_batch(pose_mat[:, smpl_2_mujoco], self.zero_trans, self.local_translation_batch, self.parent_indices) + + ref_rb_rot = ptr.matrix_to_quaternion_ijkr(global_rotation.matmul(self.to_global)) + + ################## ################## + ref_rb_rot_np = ref_rb_rot.numpy()[0] + + if len(self.body_rot_acc) > 0: + ref_rb_rot_np = quat_correct_two_batch(self.body_rot_acc[-1], ref_rb_rot_np) + filtered_quats = filters.gaussian_filter1d(np.concatenate([self.body_rot_acc, ref_rb_rot_np[None,]], axis=0), 1, axis=0, mode="mirror") + new_quat = filtered_quats[-1] / np.linalg.norm(filtered_quats[-1], axis=1)[:, None] + self.body_rot_acc.append(new_quat) # add the filtered quat. + + # pose_quat_global = np.array(self.body_rot_acc) + # select_quats = np.linalg.norm(pose_quat_global[:-1, :] - pose_quat_global[1:, :], axis=2) > np.linalg.norm(pose_quat_global[:-1, :] + pose_quat_global[1:, :], axis=2) + ref_rb_rot = torch.tensor(new_quat[None,]).float() + else: + self.body_rot_acc.append(ref_rb_rot_np) + + ################## ################## + + ref_rb_pos = SkeletonState.from_rotation_and_root_translation(self.skeleton_trees[0], ref_rb_rot, trans, is_local=False).global_translation.to(self.device) # SLOWWWWWWW + ref_rb_rot = ref_rb_rot.to(self.device) + ref_rb_pos = ref_rb_pos.to(self.device) + ref_body_ang_vel = SkeletonMotion._compute_angular_velocity(torch.stack([self.prev_ref_body_rot, ref_rb_rot], dim=1), time_delta=s_dt, guassian_filter=False)[:, 0] + ref_body_vel = SkeletonMotion._compute_velocity(torch.stack([self.prev_ref_body_pos, ref_rb_pos], dim=1), time_delta=s_dt, guassian_filter=False)[:, 0] # this is slow! + + + time_steps = 1 + ref_rb_pos_subset = ref_rb_pos[..., self._track_bodies_id, :] + ref_body_vel_subset = ref_body_vel[..., self._track_bodies_id, :] + ref_rb_rot_subset = ref_rb_rot[..., self._track_bodies_id, :] + ref_body_ang_vel_subset = ref_body_ang_vel[..., self._track_bodies_id, :] + + if self.zero_out_far: + close_distance = self.close_distance + distance = torch.norm(root_pos - ref_rb_pos_subset[..., 0, :], dim=-1) + + zeros_subset = distance > close_distance + ref_rb_pos_subset[zeros_subset, 1:] = body_pos_subset[zeros_subset, 1:] + ref_rb_rot_subset[zeros_subset, 1:] = body_rot_subset[zeros_subset, 1:] + ref_body_vel_subset[zeros_subset, :] = body_vel_subset[zeros_subset, :] + ref_body_ang_vel_subset[zeros_subset, :] = body_ang_vel_subset[zeros_subset, :] + + far_distance = 3 # does not seem to need this in particular... + vector_zero_subset = distance > far_distance # > 5 meters, it become just a direction + ref_rb_pos_subset[vector_zero_subset, 0] = ((ref_rb_pos_subset[vector_zero_subset, 0] - body_pos_subset[vector_zero_subset, 0]) / distance[vector_zero_subset, None] * far_distance) + body_pos_subset[vector_zero_subset, 0] + + obs = humanoid_im.compute_imitation_observations_v6(root_pos, root_rot, body_pos_subset, body_rot_subset, body_vel_subset, body_ang_vel_subset, ref_rb_pos_subset, ref_rb_rot_subset, ref_body_vel_subset, ref_body_ang_vel_subset, time_steps, self._has_upright_start) + + self.prev_ref_body_pos = ref_rb_pos + self.prev_ref_body_rot = ref_rb_rot + elif self.obs_v == 7: + pose_res = requests.get(f'http://{SERVER}:8080/get_pose') + json_data = pose_res.json() + ref_rb_pos = np.array(json_data["j3d"])[:self.num_envs, smpl_2_mujoco] + trans = ref_rb_pos[:, [0]] + + # if len(self.root_pos_acc) > 0 and np.linalg.norm(trans - self.root_pos_acc[-1]) > 1: + # import ipdb; ipdb.set_trace() + # print("juping!!") + ref_rb_pos_orig = ref_rb_pos.copy() + + ref_rb_pos = ref_rb_pos - trans + ############################## Limb Length ############################## + limb_lengths = [] + for i in range(6): + parent = self.skeleton_trees[0].parent_indices[i] + if parent != -1: + limb_lengths.append(np.linalg.norm(ref_rb_pos[:, parent] - ref_rb_pos[:, i], axis = -1)) + limb_lengths = np.array(limb_lengths).transpose(1, 0) + scale = (limb_lengths/self.mean_limb_lengths).mean(axis = -1) + ref_rb_pos /= scale[:, None, None] + ############################## Limb Length ############################## + s_dt = 1/30 + + self.root_pos_acc.append(trans) + filtered_root_trans = np.array(self.root_pos_acc) + filtered_root_trans[..., 2] = filters.gaussian_filter1d(filtered_root_trans[..., 2], 10, axis=0, mode="mirror") # More filtering on the root translation + filtered_root_trans[..., :2] = filters.gaussian_filter1d(filtered_root_trans[..., :2], 5, axis=0, mode="mirror") + trans = filtered_root_trans[-1] + + self.body_pos_acc.append(ref_rb_pos) + body_pos = np.array(self.body_pos_acc) + filtered_ref_rb_pos = filters.gaussian_filter1d(body_pos, 2, axis=0, mode="mirror") + ref_rb_pos = filtered_ref_rb_pos[-1] + + ref_rb_pos = torch.from_numpy(ref_rb_pos + trans).float() + ref_rb_pos = ref_rb_pos.matmul(self.to_isaac_mat.T).cuda() + + ref_body_vel = SkeletonMotion._compute_velocity(torch.stack([self.prev_ref_body_pos, ref_rb_pos], dim=1), time_delta=s_dt, guassian_filter=False)[:, 0] # + + time_steps = 1 + ref_rb_pos_subset = ref_rb_pos[..., self._track_bodies_id, :] + ref_body_vel_subset = ref_body_vel[..., self._track_bodies_id, :] + + if self.zero_out_far: + close_distance = self.close_distance + distance = torch.norm(root_pos - ref_rb_pos_subset[..., 0, :], dim=-1) + + zeros_subset = distance > close_distance + ref_rb_pos_subset[zeros_subset, 1:] = body_pos_subset[zeros_subset, 1:] + ref_body_vel_subset[zeros_subset, :] = body_vel_subset[zeros_subset, :] + + far_distance = self.far_distance # does not seem to need this in particular... + vector_zero_subset = distance > far_distance # > 5 meters, it become just a direction + ref_rb_pos_subset[vector_zero_subset, 0] = ((ref_rb_pos_subset[vector_zero_subset, 0] - body_pos_subset[vector_zero_subset, 0]) / distance[vector_zero_subset, None] * far_distance) + body_pos_subset[vector_zero_subset, 0] + + obs = humanoid_im.compute_imitation_observations_v7(root_pos, root_rot, body_pos_subset, body_vel_subset, ref_rb_pos_subset, ref_body_vel_subset, time_steps, self._has_upright_start) + + self.prev_ref_body_pos = ref_rb_pos + + if len(env_ids) == self.num_envs: + self.ref_body_pos = ref_rb_pos + self.ref_body_pos_subset = torch.from_numpy(ref_rb_pos_orig) + self.ref_pose_aa = None + + return obs + + def _compute_reset(self): + self.reset_buf[:] = 0 + self._terminate_buf[:] = 0 + diff --git a/phc/env/tasks/humanoid_im_mcp_getup.py b/phc/env/tasks/humanoid_im_mcp_getup.py new file mode 100644 index 0000000..d11933f --- /dev/null +++ b/phc/env/tasks/humanoid_im_mcp_getup.py @@ -0,0 +1,31 @@ + + +from typing import OrderedDict +import torch +import numpy as np +from phc.utils.torch_utils import quat_to_tan_norm +import phc.env.tasks.humanoid_im_getup as humanoid_im_getup +import phc.env.tasks.humanoid_im_mcp as humanoid_im_mcp +from phc.env.tasks.humanoid_amp import HumanoidAMP, remove_base_rot +from phc.utils.motion_lib_smpl import MotionLibSMPL + +from phc.utils import torch_utils + +from isaacgym import gymapi +from isaacgym import gymtorch +from isaacgym.torch_utils import * +from phc.utils.flags import flags +import joblib +import gc +from poselib.poselib.skeleton.skeleton3d import SkeletonMotion, SkeletonState +from rl_games.algos_torch import torch_ext +import torch.nn as nn +from phc.learning.network_loader import load_mcp_mlp, load_pnn +from collections import deque + +class HumanoidImMCPGetup(humanoid_im_getup.HumanoidImGetup, humanoid_im_mcp.HumanoidImMCP): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + return + diff --git a/phc/env/tasks/humanoid_pedestrain_terrain_z.py b/phc/env/tasks/humanoid_pedestrain_terrain_z.py new file mode 100644 index 0000000..42446cc --- /dev/null +++ b/phc/env/tasks/humanoid_pedestrain_terrain_z.py @@ -0,0 +1,162 @@ +import time +import torch +import phc.env.tasks.humanoid_pedestrian_terrain as humanoid_pedestrain_terrain +from phc.env.tasks.humanoid_amp import remove_base_rot +from phc.utils import torch_utils +from typing import OrderedDict + +from isaacgym.torch_utils import * +from phc.utils.flags import flags +from rl_games.algos_torch import torch_ext +import torch.nn as nn +from collections import deque +from phc.learning.network_loader import load_z_encoder, load_z_decoder +from phc.utils.torch_utils import project_to_norm + +ENABLE_MAX_COORD_OBS = True + +class HumanoidPedestrianTerrainZ(humanoid_pedestrain_terrain.HumanoidPedestrianTerrain): + + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + super().__init__(cfg=cfg, sim_params=sim_params, physics_engine=physics_engine, device_type=device_type, device_id=device_id, headless=headless) + + self.models_path = cfg["env"].get("models", ['output/dgx/smpl_im_fit_3_1/Humanoid_00185000.pth']) + check_points = [torch_ext.load_checkpoint(ck_path) for ck_path in self.models_path] + + ### Loading Distill Model ### + self.distill_model_config = self.cfg['env']['distill_model_config'] + self.embedding_size_distill = self.distill_model_config['embedding_size'] + self.embedding_norm_distill = self.distill_model_config['embedding_norm'] + self.fut_tracks_distill = self.distill_model_config['fut_tracks'] + self.num_traj_samples_distill = self.distill_model_config['numTrajSamples'] + self.traj_sample_timestep_distill = self.distill_model_config['trajSampleTimestepInv'] + self.fut_tracks_dropout_distill = self.distill_model_config['fut_tracks_dropout'] + self.z_activation = self.distill_model_config['z_activation'] + self.distill_z_type = self.distill_model_config.get("z_type", "sphere") + + self.embedding_partition_distill = self.distill_model_config.get("embedding_partion", 1) + self.dict_size_distill = self.distill_model_config.get("dict_size", 1) + ### Loading Distill Model ### + + self.z_all = self.cfg['env'].get("z_all", False) + self.use_vae_prior = self.cfg['env'].get("use_vae_prior", False) + self.running_mean, self.running_var = check_points[-1]['running_mean_std']['running_mean'], check_points[-1]['running_mean_std']['running_var'] + + self.decoder = load_z_decoder(check_points[0], activation = self.z_activation, z_type = self.distill_z_type, device = self.device) + self.encoder = load_z_encoder(check_points[0], activation = self.z_activation, z_type = self.distill_z_type, device = self.device) + self.power_acc = torch.zeros((self.num_envs, 2 )).to(self.device) + + return + + def get_task_obs_size_detail(self): + task_obs_detail = super().get_task_obs_size_detail() + + ### For Z + task_obs_detail['proj_norm'] = self.cfg['env'].get("proj_norm", True) + task_obs_detail['embedding_norm'] = self.cfg['env'].get("embedding_norm", 3) + task_obs_detail['embedding_size'] = self.cfg['env'].get("embedding_size", 256) + task_obs_detail['z_readout'] = self.cfg['env'].get("z_readout", False) + task_obs_detail['z_type'] = self.cfg['env'].get("z_type", "sphere") + task_obs_detail['num_unique_motions'] = self._motion_lib._num_unique_motions + + + return task_obs_detail + + def _setup_character_props(self, key_bodies): + super()._setup_character_props(key_bodies) + self._num_actions = self.cfg['env'].get("embedding_size", 256) + return + + def step(self, action_z): + + # if self.dr_randomizations.get('actions', None): + # actions = self.dr_randomizations['actions']['noise_lambda'](actions) + # if flags.server_mode: + # t_s = time.time() + # t_s = time.time() + with torch.no_grad(): + # Apply trained Model. + + ################ GT-Z ################ + + self_obs_size = self.get_self_obs_size() + self_obs = (self.obs_buf[:, :self_obs_size] - self.running_mean.float()[:self_obs_size]) / torch.sqrt(self.running_var.float()[:self_obs_size] + 1e-05) + if self.distill_z_type == "hyper": + action_z = self.decoder.hyper_layer(action_z) + if self.distill_z_type == "vq_vae": + + if self.is_discrete: + indexes = action_z + else: + B, F = action_z.shape + indexes = action_z.reshape(B, -1, self.embedding_size_distill).argmax(dim = -1) + task_out_proj = self.decoder.quantizer.embedding.weight[indexes.view(-1)] + print(f"\r {indexes.numpy()[0]}", end = '') + action_z = task_out_proj.view(-1, self.embedding_size_distill) + + elif self.distill_z_type == "vae": + if self.use_vae_prior: + z_prior_out = self.decoder.z_prior(self_obs) + prior_mu, prior_log_var = self.decoder.z_prior_mu(z_prior_out), self.decoder.z_prior_logvar(z_prior_out) + action_z = prior_mu + action_z + else: + pass + else: + action_z = project_to_norm(action_z, self.cfg['env'].get("embedding_norm", 5), self.distill_z_type) + + if self.z_all: + x_all = self.decoder.decoder(action_z) + else: + self_obs = torch.clamp(self_obs, min=-5.0, max=5.0) + x_all = self.decoder.decoder(torch.cat([self_obs, action_z], dim = -1)) + + # z_prior_out = self.decoder.z_prior(self_obs); prior_mu, prior_log_var = self.decoder.z_prior_mu(z_prior_out), self.decoder.z_prior_logvar(z_prior_out); print(prior_mu.max(), prior_mu.min()) + # print('....') + actions = x_all + + + # actions = x_all[:, 3] # Debugging + # apply actions + self.pre_physics_step(actions) + + # step physics and render each frame + self._physics_step() + + # to fix! + if self.device == 'cpu': + self.gym.fetch_results(self.sim, True) + + # compute observations, rewards, resets, ... + self.post_physics_step() + if flags.server_mode: + dt = time.time() - t_s + print(f'\r {1/dt:.2f} fps', end='') + + # dt = time.time() - t_s + # self.fps.append(1/dt) + # print(f'\r {np.mean(self.fps):.2f} fps', end='') + + + if self.dr_randomizations.get('observations', None): + self.obs_buf = self.dr_randomizations['observations']['noise_lambda'](self.obs_buf) + + +@torch.jit.script +def compute_z_target(root_pos, root_rot, ref_body_pos, ref_body_vel, time_steps, upright): + # type: (Tensor, Tensor, Tensor, Tensor, int, bool) -> Tensor + # No rotation information. Leave IK for RL. + # Future tracks in this obs will not contain future diffs. + obs = [] + B, J, _ = ref_body_pos.shape + + if not upright: + root_rot = remove_base_rot(root_rot) + + heading_inv_rot = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot = torch_utils.calc_heading_quat(root_rot) + heading_inv_rot_expand = heading_inv_rot.unsqueeze(-2).repeat((1, J, 1)).repeat_interleave(time_steps, 0) + local_ref_body_pos = ref_body_pos.view(B, time_steps, J, 3) - root_pos.view(B, 1, 1, 3) # preserves the body position + local_ref_body_pos = torch_utils.my_quat_rotate(heading_inv_rot_expand.view(-1, 4), local_ref_body_pos.view(-1, 3)) + + + return local_ref_body_pos.view(B, J, -1) \ No newline at end of file diff --git a/phc/env/tasks/humanoid_pedestrian_terrain.py b/phc/env/tasks/humanoid_pedestrian_terrain.py new file mode 100644 index 0000000..46010e0 --- /dev/null +++ b/phc/env/tasks/humanoid_pedestrian_terrain.py @@ -0,0 +1,1737 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from shutil import ExecError +import torch +import numpy as np + +import env.tasks.humanoid_traj as humanoid_traj +from isaacgym import gymapi +from isaacgym.torch_utils import * +from env.tasks.humanoid import dof_to_obs +from env.tasks.humanoid_amp import HumanoidAMP, remove_base_rot +from phc.utils.flags import flags +from utils import torch_utils +from isaacgym import gymtorch +import joblib +from poselib.poselib.core.rotation3d import quat_inverse, quat_mul +from tqdm import tqdm +from scipy.spatial.transform import Rotation as sRot +import matplotlib.pyplot as plt +from typing import OrderedDict +from phc.utils.draw_utils import agt_color +from phc.env.tasks.humanoid import compute_humanoid_observations_smpl_max, compute_humanoid_observations_smpl, ENABLE_MAX_COORD_OBS + +HACK_MOTION_SYNC = False + +class HumanoidPedestrianTerrain(humanoid_traj.HumanoidTraj): + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, + headless): + ## ZL Hack to get the height map to load. + self.velocity_map = cfg["env"].get("velocity_map", False) + self.device_type = cfg.get("device_type", "cuda") + self.device_id = cfg.get("device_id", 0) + self.device = "cpu" + if self.device_type == "cuda" or self.device_type == "GPU": + self.device = "cuda" + ":" + str(self.device_id) + + # self.real_mesh = cfg['args'].real_mesh + self.real_mesh = False + self.load_smpl_configs(cfg) + self.cfg = cfg + self.num_envs = cfg["env"]["num_envs"] + self.headless = cfg["headless"] + self.sensor_extent = cfg["env"].get("sensor_extent", 2) + self.sensor_res = cfg["env"].get("sensor_res", 32) + self.power_reward = cfg["env"].get("power_reward", False) + self.power_coefficient = cfg["env"].get("power_coefficient", 0.0005) + self.fuzzy_target = cfg["env"].get("fuzzy_target", False) + + + self.square_height_points = self.init_square_height_points() + self.terrain_obs_type = self.cfg['env'].get("terrain_obs_type", + "square") + self.terrain_obs = self.cfg['env'].get("terrain_obs", False) + self.terrain_obs_root = self.cfg['env'].get("terrain_obs_root", + "pelvis") + if self.terrain_obs_type == "fov": + self.height_points = self.init_fov_height_points() + elif self.terrain_obs_type == "square_fov": + self.height_points = self.init_square_fov_height_points() + elif self.terrain_obs_type == "square": + self.height_points = self.square_height_points + self.root_points = self.init_root_points() + + self.center_height_points = self.init_center_height_points() + self.height_meas_scale = 5 + + self.show_sensors = self.cfg['args'].show_sensors + if (not self.headless) and self.show_sensors: + self._sensor_handles = [[] for _ in range(self.num_envs)] + + super().__init__(cfg=cfg, + sim_params=sim_params, + physics_engine=physics_engine, + device_type=device_type, + device_id=device_id, + headless=headless) + + self.reward_raw = torch.zeros((self.num_envs, 2)).to(self.device) + + if (not self.headless) and self.show_sensors: + self._build_sensor_state_tensors() + + return + + def _build_env(self, env_id, env_ptr, humanoid_asset): + super()._build_env(env_id, env_ptr, humanoid_asset) + + if (not self.headless) and self.show_sensors: + self._load_sensor_asset() + self._build_sensor(env_id, env_ptr) + + return + + def _build_sensor(self, env_id, env_ptr): + default_pose = gymapi.Transform() + + for i in range(self.num_height_points): + marker_handle = self.gym.create_actor(env_ptr, self._sensor_asset, + default_pose, "marker", + self.num_envs + 1, 0, 0) + self.gym.set_rigid_body_color(env_ptr, marker_handle, 0, + gymapi.MESH_VISUAL, + gymapi.Vec3(*agt_color(env_id))) + self._sensor_handles[env_id].append(marker_handle) + + return + + def _build_sensor_state_tensors(self): + num_actors = self._root_states.shape[0] // self.num_envs + self._sensor_states = self._root_states.view(self.num_envs, num_actors, self._root_states.shape[-1])[..., 11:(11 + self.num_height_points), :] + self._sensor_pos = self._sensor_states[..., :3] + self._sensor_actor_ids = self._humanoid_actor_ids.unsqueeze(-1) + to_torch(self._sensor_handles, dtype=torch.int32, device=self.device) + self._sensor_actor_ids = self._sensor_actor_ids.flatten() + return + + def _load_sensor_asset(self): + asset_root = "amp/data/assets/mjcf/" + asset_file = "sensor_marker.urdf" + + asset_options = gymapi.AssetOptions() + asset_options.angular_damping = 0.01 + asset_options.linear_damping = 0.01 + asset_options.max_angular_velocity = 100.0 + asset_options.density = 1.0 + asset_options.fix_base_link = True + asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + + self._sensor_asset = self.gym.load_asset(self.sim, asset_root, + asset_file, asset_options) + + return + + def _draw_task(self): + # cols = np.array([[1.0, 0.0, 0.0]], dtype=np.float32) + + norm_states = self.get_head_pose() + base_quat = norm_states[:, 3:7] + if not self._has_upright_start: + base_quat = remove_base_rot(base_quat) + heading_rot = torch_utils.calc_heading_quat(base_quat) + + points = quat_apply( + heading_rot.repeat(1, self.num_height_points).reshape(-1, 4), + self.height_points) + (norm_states[:, :3]).unsqueeze(1) + + if (not self.headless) and self.show_sensors: + self._sensor_pos[:] = points + # self._sensor_pos[..., 2] += 0.3 + # self._sensor_pos[..., 2] -= 5 + + traj_samples = self._fetch_traj_samples() + + self._marker_pos[:] = traj_samples + self._marker_pos[..., 2] = self._humanoid_root_states[..., 2:3] # jp hack # ZL hack + # self._marker_pos[..., 2] = 0.89 + # self._marker_pos[..., 2] = 0 + + if (not self.headless) and self.show_sensors: + comb_idx = torch.cat([self._sensor_actor_ids, self._marker_actor_ids]) + else: + comb_idx = torch.cat([self._marker_actor_ids]) + + if flags.show_traj: + self.gym.set_actor_root_state_tensor_indexed( + self.sim, gymtorch.unwrap_tensor(self._root_states), + gymtorch.unwrap_tensor(comb_idx), len(comb_idx)) + + self.gym.clear_lines(self.viewer) + + for i, env_ptr in enumerate(self.envs): + verts = self._traj_gen.get_traj_verts(i) + verts[..., 2] = self._humanoid_root_states[i, 2] # ZL Hack + # verts[..., 2] = 0.89 + # verts[..., 2] = 0 + lines = torch.cat([verts[:-1], verts[1:]], dim=-1).cpu().numpy() + # cols = np.array([[1.0, 0.0, 0.0]], dtype=np.float32) + cols = np.array(agt_color(i), dtype=np.float32)[None, ] + curr_cols = np.broadcast_to(cols, [lines.shape[0], cols.shape[-1]]) + self.gym.add_lines(self.viewer, env_ptr, lines.shape[0], lines, curr_cols) + else: + self._marker_pos[:] = 0 + self.gym.set_actor_root_state_tensor_indexed( + self.sim, gymtorch.unwrap_tensor(self._root_states), + gymtorch.unwrap_tensor(comb_idx), len(comb_idx)) + + self.gym.clear_lines(self.viewer) + + + return + + def _compute_humanoid_obs(self, env_ids=None): + if (ENABLE_MAX_COORD_OBS): + if (env_ids is None): + body_pos = self._rigid_body_pos.clone() + body_rot = self._rigid_body_rot + body_vel = self._rigid_body_vel + body_ang_vel = self._rigid_body_ang_vel + else: + body_pos = self._rigid_body_pos[env_ids] + body_rot = self._rigid_body_rot[env_ids] + body_vel = self._rigid_body_vel[env_ids] + body_ang_vel = self._rigid_body_ang_vel[env_ids] + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + if (env_ids is None): + smpl_params = self.humanoid_shapes + limb_weights = self.humanoid_limb_and_weights + else: + smpl_params = self.humanoid_shapes[env_ids] + limb_weights = self.humanoid_limb_and_weights[env_ids] + + if self._root_height_obs: + center_height = self.get_center_heights(torch.cat([body_pos[:, 0], body_rot[:, 0]], dim=-1), env_ids=env_ids).mean(dim=-1, keepdim=True) + body_pos[:, :, 2] = body_pos[:, :, 2] - center_height + + obs = compute_humanoid_observations_smpl_max(body_pos, body_rot, body_vel, body_ang_vel, smpl_params, limb_weights, self._local_root_obs, self._root_height_obs, self._has_upright_start, self._has_shape_obs, self._has_limb_weight_obs) + else: + raise NotImplementedError + # obs = compute_humanoid_observations_max(body_pos, body_rot, body_vel, body_ang_vel, self._local_root_obs, self._root_height_obs) + + else: + if (env_ids is None): + root_pos = self._rigid_body_pos[:, 0, :] + root_rot = self._rigid_body_rot[:, 0, :] + root_vel = self._rigid_body_vel[:, 0, :] + root_ang_vel = self._rigid_body_ang_vel[:, 0, :] + dof_pos = self._dof_pos + dof_vel = self._dof_vel + key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :] + else: + root_pos = self._rigid_body_pos[env_ids][:, 0, :] + root_rot = self._rigid_body_rot[env_ids][:, 0, :] + root_vel = self._rigid_body_vel[env_ids][:, 0, :] + root_ang_vel = self._rigid_body_ang_vel[env_ids][:, 0, :] + dof_pos = self._dof_pos[env_ids] + dof_vel = self._dof_vel[env_ids] + key_body_pos = self._rigid_body_pos[env_ids][:, self._key_body_ids, :] + + if self.humanoid_type in ["smpl", "smplh", "smplx"] and self.self.has_shape_obs: + if (env_ids is None): + smpl_params = self.humanoid_shapes + else: + smpl_params = self.humanoid_shapes[env_ids] + obs = compute_humanoid_observations_smpl(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, self._dof_obs_size, self._dof_offsets, smpl_params, self._local_root_obs, self._root_height_obs, self._has_upright_start, self._has_shape_obs) + else: + raise NotImplementedError + # obs = compute_humanoid_observations(root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos, self._local_root_obs, self._root_height_obs, self._dof_obs_size, self._dof_offsets) + return obs + + def get_task_obs_size(self): + obs_size = 0 + if (self._enable_task_obs): + + obs_size = 2 * self._num_traj_samples + + if self.terrain_obs: + if self.velocity_map: + obs_size += self.num_height_points * 3 + else: + obs_size += self.num_height_points + + if self._divide_group and self._group_obs: + obs_size += 5 * 11 * 3 + + return obs_size + + def get_self_obs_size(self): + return self._num_self_obs + + def get_task_obs_size_detail(self): + task_obs_detail = OrderedDict() + + + if (self._enable_task_obs): + task_obs_detail['traj'] = 2 * self._num_traj_samples + # task_obs_detail.append(["traj", 2 * self._num_traj_samples]) + + if self.terrain_obs: + if self.velocity_map: + task_obs_detail['heightmap_velocity'] = self.num_height_points * 3 + # task_obs_detail.append(["heightmap_velocity", self.num_height_points * 3]) + else: + task_obs_detail['heightmap'] = self.num_height_points + # task_obs_detail.append(["heightmap", self.num_height_points]) + + if self._divide_group and self._group_obs: + task_obs_detail['people'] = 5 * 11 * 3 + # task_obs_detail.append(["people", 5 * 11 * 3]) + + return task_obs_detail + + + def get_head_pose(self, env_ids=None): + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + head_idx = self._body_names.index("Head") + else: + head_idx = 2 + head_pose = torch.cat([ + self._rigid_body_pos[:, head_idx], self._rigid_body_rot[:, + head_idx] + ], + dim=1) + if (env_ids is None): + return head_pose + else: + return head_pose[env_ids] + + # # ZL: Dev purposes only, will remove. + # def _fetch_traj_samples(self, env_ids=None): + # # 5 seconds with 0.5 second intervals, 10 samples. + # if (env_ids is None): + # env_ids = torch.arange(self.num_envs, + # device=self.device, + # dtype=torch.long) + + # timestep_beg = self.progress_buf[env_ids] * self.dt + # timesteps = torch.arange(self._num_traj_samples, + # device=self.device, + # dtype=torch.float) + # timesteps = timesteps * self._traj_sample_timestep + # traj_timesteps = timestep_beg.unsqueeze(-1) + timesteps + + # env_ids_tiled = torch.broadcast_to(env_ids.unsqueeze(-1), + # traj_timesteps.shape) + + # traj_samples_flat = self._traj_gen.calc_pos(env_ids_tiled.flatten(), + # traj_timesteps.flatten()) + # traj_samples = torch.reshape(traj_samples_flat, + # shape=(env_ids.shape[0], + # self._num_traj_samples, + # traj_samples_flat.shape[-1])) + + # traj_samples_flat = self._traj_gen.mock_calc_pos(env_ids, env_ids_tiled.flatten(), traj_timesteps.flatten(), self.query_value_gradient) + + # return traj_samples + + def update_value_func(self, eval_value_func, actor_func): + self.eval_value_func = eval_value_func + self.actor_func = actor_func + + def query_value_gradient(self, env_ids, new_traj): + # new_traj would be the same as self._fetch_traj_samples(env_ids) + # return callable value function and update_obs (processed with mean and std) + # value_func(obs) + # new_traj of shape (num_envs, 10, 3) + # TODO: implement this + if "eval_value_func" in self.__dict__: + sim_obs_size = self.get_self_obs_size() + task_obs_detal = self.get_task_obs_size_detail() + assert(task_obs_detal[0][0] == "traj") + + if (env_ids is None): + root_states = self._humanoid_root_states + else: + root_states = self._humanoid_root_states[env_ids] + + new_traj_obs = compute_location_observations(root_states, new_traj.view(env_ids.shape[0], 10, -1), self._has_upright_start) + buffered_obs = self.obs_buf[env_ids].clone() + buffered_obs[:, sim_obs_size:(task_obs_detal[0][1] + sim_obs_size)] = new_traj_obs + + return buffered_obs, self.eval_value_func + return None, None + + def live_plotter(self, img, identifier='', pause_time=0.00000001): + if not hasattr(self, 'imshow_obj'): + plt.ion() + + self.fig = plt.figure(figsize=(1, 1), dpi = 350) + ax = self.fig.add_subplot(111) + self.imshow_obj = ax.imshow(img) + # create a variable for the line so we can later update it + # update plot label/title + + plt.title('{}'.format(identifier)) + plt.show() + if not img is None: + self.imshow_obj.set_data(img) + + # plt.pause(pause_time) + self.fig.canvas.start_event_loop(0.001) + + def _compute_task_obs(self, env_ids=None): + if (env_ids is None): + root_states = self._humanoid_root_states + else: + root_states = self._humanoid_root_states[env_ids] + num_envs = self.num_envs if env_ids is None else len(env_ids) + + traj_samples = self._fetch_traj_samples(env_ids) + + obs = compute_location_observations(root_states, traj_samples, self._has_upright_start) + if self.terrain_obs: + + if self.terrain_obs_root == "head": + head_pose = self.get_head_pose(env_ids=env_ids) + self.measured_heights = self.get_heights(root_states=head_pose, env_ids=env_ids) + else: + self.measured_heights = self.get_heights(root_states=root_states, env_ids=env_ids) + + + # if flags.height_debug: + # # joblib.dump(self.measured_heights, "heights.pkl") + # if env_ids is None or len(env_ids) == self.num_envs: + # heights = self.measured_heights.view(num_envs, -1, 3 if self.velocity_map else 1) + # sensor_size = int(np.sqrt(self.num_height_points)) + # heights_show = heights.cpu().numpy()[self.viewing_env_idx, :, 0].reshape(sensor_size, sensor_size) + # if heights_show.min() < 0: + # heights_show -= heights_show.min() + # self.live_plotter(heights_show) + + if self.cfg['env'].get("use_center_height", False): + center_heights = self.get_center_heights(root_states=root_states, env_ids=env_ids) + center_heights = center_heights.mean(dim=-1, keepdim=True) + + if self.velocity_map: + measured_heights = self.measured_heights.view(num_envs, -1, 3) + measured_heights[..., 0] = center_heights - measured_heights[..., 0] + heights = measured_heights.view(num_envs, -1) + else: + heights = center_heights - self.measured_heights + heights = torch.clip(heights, -3, 3.) * self.height_meas_scale # + + else: + heights = torch.clip(root_states[:, 2:3] - self.measured_heights, -3, 3.) * self.height_meas_scale # + + obs = torch.cat([obs, heights], dim=1) + + if self._divide_group and self._group_obs: + group_obs = compute_group_observation(self._rigid_body_pos, self._rigid_body_rot, self._rigid_body_vel, self.selected_group_jts, self._group_num_people, self._has_upright_start) + # Group obs has to be computed as a whole. otherwise, the grouping breaks. + if not (env_ids is None): + group_obs = group_obs[env_ids] + + obs = torch.cat([obs, group_obs], dim=1) + + return obs + + + def _compute_flip_task_obs(self, normal_task_obs, env_ids): + + # location_obs 20 + # Terrain obs: self.num_terrain_obs + # group obs + B, D = normal_task_obs.shape + traj_samples_dim = 20 + obs_acc = [] + normal_task_obs = normal_task_obs.clone() + traj_samples = normal_task_obs[:, :traj_samples_dim].view(B, 10, 2) + traj_samples[..., 1] *= -1 + obs_acc.append(traj_samples.view(B, -1)) + if self.terrain_obs: + if self.velocity_map: + height_samples = normal_task_obs[..., traj_samples_dim: traj_samples_dim + self.num_height_points * 3] + height_samples = height_samples.view(B, int(np.sqrt(self.num_height_points)), int(np.sqrt(self.num_height_points)), 3) + height_samples[..., 0].flip(2) + height_samples = height_samples.flip(2) + obs_acc.append(height_samples.view(B, -1)) + else: + height_samples = normal_task_obs[..., traj_samples_dim: traj_samples_dim + self.num_height_points].view(B, int(np.sqrt(self.num_height_points)), int(np.sqrt(self.num_height_points))) + height_samples = height_samples.flip(2) + obs_acc.append(height_samples.view(B, -1)) + + obs = torch.cat(obs_acc, dim=1) + + if self._divide_group and self._group_obs: + group_obs = normal_task_obs[..., traj_samples_dim + self.num_height_points: ].view(B, -1, 3) + group_obs[..., 1] *= -1 + obs_acc.append(group_obs.view(B, -1)) + + + obs = torch.cat(obs_acc, dim=1) + + del obs_acc + + return obs + + def _reset_task(self, env_ids): + # super()._reset_task(env_ids) # Commented out to disable traj resetting + if not flags.server_mode: + root_pos = self._humanoid_root_states[env_ids, 0:3] + self._traj_gen.reset(env_ids, root_pos) + return + + + def _sample_ref_state(self, env_ids, vel_min=1, vel_range=0.5): + num_envs = env_ids.shape[0] + motion_ids = self._motion_lib.sample_motions(num_envs) + # if (self._state_init == HumanoidAMP.StateInit.Random or self._state_init == HumanoidAMP.StateInit.Hybrid): + # motion_times = self._sample_time(motion_ids) + # elif (self._state_init == HumanoidAMP.StateInit.Start): + # motion_times = torch.zeros(num_envs, device=self.device) + # else: + # assert ( + # False + # ), "Unsupported state initialization strategy: {:s}".format( + # str(self._state_init)) + motion_times = self._sample_time(motion_ids) + + if self.humanoid_type in ["smpl", "smplh", "smplx"]: + curr_gender_betas = self.humanoid_shapes[env_ids] + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel = self._get_fixed_smpl_state_from_motionlib( + motion_ids, motion_times, curr_gender_betas) + else: + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel = self._motion_lib.get_motion_state( + motion_ids, motion_times) + rb_pos, rb_rot = None, None + + key_pos = rb_pos[:, self._key_body_ids] + + # if flags.random_heading: + # random_rot = np.zeros([num_envs, 3]) + # random_rot[:, 2] = np.pi * (2 * np.random.random([num_envs]) - 1.0) + # random_heading_quat = torch.from_numpy(sRot.from_euler("xyz", random_rot).as_quat()).float().to(self.device) + # random_heading_quat_repeat = random_heading_quat[:, None].repeat(1, 24, 1) + # root_rot = quat_mul(random_heading_quat, root_rot).clone() + # rb_pos = quat_apply(random_heading_quat_repeat, rb_pos - root_pos[:, None, :]).clone() + root_pos[:, None, :] + # key_pos = quat_apply(random_heading_quat_repeat[:, :4, :], (key_pos - root_pos[:, None, :])).clone() + root_pos[:, None, :] + # rb_rot = quat_mul(random_heading_quat_repeat, rb_rot).clone() + # root_ang_vel = quat_apply(random_heading_quat, root_ang_vel).clone() + # root_vel = quat_apply(random_heading_quat, root_vel).clone() + + return motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos, rb_pos, rb_rot, body_vel, body_ang_vel + + def _reset_ref_state_init(self, env_ids): + num_envs = env_ids.shape[0] + motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos, rb_pos, rb_rot, body_vel, body_ang_vel= self._sample_ref_state(env_ids) + ## Randomrized location setting + new_root_xy = self.terrain.sample_valid_locations(self.num_envs, env_ids) + # joblib.dump(self.terrain.sample_valid_locations(100000, torch.arange(100000)).detach().cpu(), "new_root_xy.pkl") + # import ipdb; ipdb.set_trace() + + if flags.fixed: + # new_root_xy[:, 0], new_root_xy[:, 1] = 0 , 0 + # new_root_xy[:, 0], new_root_xy[:, 1] = 134.8434 + env_ids , -28.9593 + # new_root_xy[:, 0], new_root_xy[:, 1] = 30 + env_ids * 4, 240 + new_root_xy[:, 0], new_root_xy[:, 1] = 84 + env_ids * 3, 143 + # new_root_xy[:, 0], new_root_xy[:, 1] = 95 + env_ids * 5, 307 + # new_root_xy[:, 0], new_root_xy[:, 1] = 27, 1 + env_ids * 2 + # x_grid, y_grid = torch.meshgrid(torch.arange(64), torch.arange(64)) + # new_root_xy[:, 0], new_root_xy[:, 1] = x_grid.flatten()[env_ids] * 2, y_grid.flatten()[env_ids] * 2 + # if env_ids[0] == 0: + # new_root_xy[0, 0], new_root_xy[0, 1] = 34 , -81 + + if flags.server_mode: + new_traj = self._traj_gen.input_new_trajs(env_ids) + new_root_xy[:, 0], new_root_xy[:, 1] = new_traj[:, 0, 0], new_traj[:, 0, 1] + + + + diff_xy = new_root_xy - root_pos[:, 0:2] + root_pos[:, 0:2] = new_root_xy + + root_states = torch.cat([root_pos, root_rot], dim=1) + + center_height = self.get_center_heights(root_states, env_ids=env_ids).mean(dim=-1) + + if self.big_ankle: # Big ankle needs a bit more room. + center_height += 0.05 + + root_pos[:, 2] += center_height + key_pos[..., 0:2] += diff_xy[:, None, :] + key_pos[..., 2] += center_height[:, None] + rb_pos[..., 0:2] += diff_xy[:, None, :] + key_pos[..., 2] += center_height[:, None] + + self._set_env_state(env_ids=env_ids, + root_pos=root_pos, + root_rot=root_rot, + dof_pos=dof_pos, + root_vel=root_vel, + root_ang_vel=root_ang_vel, + dof_vel=dof_vel, + rigid_body_pos=rb_pos, + rigid_body_rot=rb_rot, + rigid_body_vel=body_vel, + rigid_body_ang_vel=body_ang_vel, + + ) + + self._reset_ref_env_ids = env_ids + self._reset_ref_motion_ids = motion_ids + self._reset_ref_motion_times = motion_times + if flags.follow: + self.start = True ## Updating camera when reset + + return + + def init_center_height_points(self): + # center_height_points + y = torch.tensor(np.linspace(-0.2, 0.2, 3),device=self.device,requires_grad=False) + x = torch.tensor(np.linspace(-0.1, 0.1, 3),device=self.device,requires_grad=False) + grid_x, grid_y = torch.meshgrid(x, y) + grid_x, grid_y = torch.meshgrid(x, y) + + self.num_center_height_points = grid_x.numel() + points = torch.zeros(self.num_envs, + self.num_center_height_points, + 3, + device=self.device, + requires_grad=False) + points[:, :, 0] = grid_x.flatten() + points[:, :, 1] = grid_y.flatten() + return points + + def init_square_height_points(self): + # 4mx4m square + y = torch.tensor(np.linspace(-self.sensor_extent, self.sensor_extent, self.sensor_res),device=self.device,requires_grad=False) + x = torch.tensor(np.linspace(-self.sensor_extent, self.sensor_extent, + self.sensor_res), + device=self.device, + requires_grad=False) + grid_x, grid_y = torch.meshgrid(x, y) + grid_x, grid_y = torch.meshgrid(x, y) + + self.num_height_points = grid_x.numel() + points = torch.zeros(self.num_envs, + self.num_height_points, + 3, + device=self.device, + requires_grad=False) + points[:, :, 0] = grid_x.flatten() + points[:, :, 1] = grid_y.flatten() + return points + + def init_square_fov_height_points(self): + y = torch.tensor(np.linspace(-1, 1, 20),device=self.device,requires_grad=False) + x = torch.tensor(np.linspace(-0.02, 1.98, 20),device=self.device,requires_grad=False) + grid_x, grid_y = torch.meshgrid(x, y) + + self.num_height_points = grid_x.numel() + points = torch.zeros(self.num_envs, + self.num_height_points, + 3, + device=self.device, + requires_grad=False) + points[:, :, 0] = grid_x.flatten() + points[:, :, 1] = grid_y.flatten() + return points + + def init_root_points(self): + y = torch.tensor(np.linspace(-0.5, 0.5, 20), + device=self.device, + requires_grad=False) + x = torch.tensor(np.linspace(-0.25, 0.25, 10), + device=self.device, + requires_grad=False) + grid_x, grid_y = torch.meshgrid(x, y) + + self.num_root_points = grid_x.numel() + points = torch.zeros(self.num_envs, + self.num_root_points, + 3, + device=self.device, + requires_grad=False) + points[:, :, 0] = grid_x.flatten() + points[:, :, 1] = grid_y.flatten() + return points + + def init_fov_height_points(self): + # 3m x 3m fan shaped area + rs = np.exp(np.arange(0.2, 2, 0.1)) - 0.9 + rs = rs/rs.max() * 2 + + max_angle = 110 + phi = np.exp(np.linspace(0.1, 1.5, 12)) - 1 + phi = phi/phi.max() * max_angle + phi = np.concatenate([-phi[::-1],[0], phi]) * np.pi/180 + xs, ys = [], [] + for r in rs: + xs.append(r * np.cos(phi)); ys.append(r * np.sin(phi)) + + xs, ys = np.concatenate(xs), np.concatenate(ys) + + xs, ys = torch.from_numpy(xs).to(self.device), torch.from_numpy(ys).to( + self.device) + + self.num_height_points = xs.shape[0] + points = torch.zeros(self.num_envs, + self.num_height_points, + 3, + device=self.device, + requires_grad=False) + points[:, :, 0] = xs + points[:, :, 1] = ys + return points + + def get_center_heights(self, root_states, env_ids=None): + base_quat = root_states[:, 3:7] + if self.cfg["env"]["terrain"]["terrainType"] == 'plane': + return torch.zeros(self.num_envs, + self.num_center_height_points, + device=self.device, + requires_grad=False) + elif self.cfg["env"]["terrain"]["terrainType"] == 'none': + raise NameError("Can't measure height with terrain type 'none'") + + if self.humanoid_type in ["smpl", "smplh", "smplx"] and not self._has_upright_start: + base_quat = remove_base_rot(base_quat) + + if env_ids is None: + points = quat_apply_yaw( + base_quat.repeat(1, self.num_center_height_points,), + self.center_height_points) + (root_states[:, :3]).unsqueeze(1) + else: + points = quat_apply_yaw( + base_quat.repeat(1, self.num_center_height_points,), + self.center_height_points[env_ids]) + ( + root_states[:, :3]).unsqueeze(1) + + heights = self.terrain.sample_height_points(points.clone(), env_ids=env_ids) + num_envs = self.num_envs if env_ids is None else len(env_ids) + + return heights.view(num_envs, -1) + + def get_heights(self, root_states, env_ids=None): + + base_quat = root_states[:, 3:7] + if self.cfg["env"]["terrain"]["terrainType"] == 'plane': + return torch.zeros(self.num_envs, + self.num_height_points, + device=self.device, + requires_grad=False) + elif self.cfg["env"]["terrain"]["terrainType"] == 'none': + raise NameError("Can't measure height with terrain type 'none'") + + if self.humanoid_type in ["smpl", "smplh", "smplx"] and not self._has_upright_start: + base_quat = remove_base_rot(base_quat) + + heading_rot = torch_utils.calc_heading_quat(base_quat) + + if env_ids is None: + points = quat_apply( + heading_rot.repeat(1, self.num_height_points).reshape(-1, 4), + self.height_points) + (root_states[:, :3]).unsqueeze(1) + else: + points = quat_apply( + heading_rot.repeat(1, self.num_height_points).reshape(-1, 4), + self.height_points[env_ids]) + ( + root_states[:, :3]).unsqueeze(1) + + if self.velocity_map: + root_states_all = self._humanoid_root_states + else: + root_states_all = None + + if (self._divide_group or flags.divide_group) and not self._group_obs and not self._disable_group_obs: + heading_rot_all = torch_utils.calc_heading_quat(self._humanoid_root_states[:, 3:7]) + root_points = quat_apply( + heading_rot_all.repeat(1, self.num_root_points).reshape(-1, 4), + self.root_points) + (self._humanoid_root_states[:, :3]).unsqueeze(1) + # update heights with root points + heights = self.terrain.sample_height_points( + points.clone(), + root_states = root_states_all, + root_points = root_points, + env_ids=env_ids, + num_group_people=self._group_num_people, + group_ids = self._group_ids) + else: + heights = self.terrain.sample_height_points( + points.clone(), + root_states=root_states_all, + root_points=None, + env_ids=env_ids, + ) + # heights = self.terrain.sample_height_points(points.clone(), None) + num_envs = self.num_envs if env_ids is None else len(env_ids) + + return heights.view(num_envs, -1) + + def _create_ground_plane(self): + if self.real_mesh: + self.create_mesh_ground() + else: + self.create_training_ground() + + def create_mesh_ground(self): + # plane_params = gymapi.PlaneParams() + # plane_params.normal = gymapi.Vec3(0.0, 0.0, 1.0) + # plane_params.distance = 7.001 + # plane_params.static_friction = self.plane_static_friction + # plane_params.dynamic_friction = self.plane_dynamic_friction + # plane_params.restitution = self.plane_restitution + # self.gym.add_ground(self.sim, plane_params) + + # scene_name = "parking" + # scene_name = "parking_with_cars" + # scene_name = "with_less_car" + scene_name = "mesh-downtown-san-jose-mapaligned-cropped-bottom-part-global" + self.mesh_data = joblib.load(f"data/mesh/{scene_name}.pkl") + + mesh_vertices = self.mesh_data["vertices"] + mesh_triangles = self.mesh_data["faces"].astype(np.uint32) + + tm_params = gymapi.TriangleMeshParams() + tm_params.nb_vertices = mesh_vertices.shape[0] + tm_params.nb_triangles = mesh_triangles.shape[0] + tm_params.transform.p.x = 0.0 + tm_params.transform.p.y = 0.0 + tm_params.transform.p.z = 0.0 + tm_params.static_friction = self.plane_static_friction + tm_params.dynamic_friction = self.plane_dynamic_friction + tm_params.restitution = self.plane_restitution + self.gym.add_triangle_mesh(self.sim, mesh_vertices.flatten(order='C'), + mesh_triangles.flatten(order='C'), + tm_params) + + self.terrain = MeshTerrain(self.mesh_data, self.device) + self.height_samples = torch.tensor(self.terrain.heightsamples).to( + self.device) + return + + def create_training_ground(self): + if flags.small_terrain: + self.cfg["env"]["terrain"]['mapLength'] = 8 + self.cfg["env"]["terrain"]['mapWidth'] = 8 + + self.terrain = Terrain(self.cfg["env"]["terrain"], + num_robots=self.num_envs, + device=self.device) + + tm_params = gymapi.TriangleMeshParams() + tm_params.nb_vertices = self.terrain.vertices.shape[0] + tm_params.nb_triangles = self.terrain.triangles.shape[0] + tm_params.transform.p.x = 0 + tm_params.transform.p.y = 0 + tm_params.transform.p.z = 0.0 + tm_params.static_friction = self.cfg["env"]["terrain"]["staticFriction"] + tm_params.dynamic_friction = self.cfg["env"]["terrain"]["dynamicFriction"] + tm_params.restitution = self.cfg["env"]["terrain"]["restitution"] + self.gym.add_triangle_mesh(self.sim, + self.terrain.vertices.flatten(order='C'), + self.terrain.triangles.flatten(order='C'), + tm_params) + self.height_samples = torch.tensor(self.terrain.heightsamples).view(self.terrain.tot_rows, self.terrain.tot_cols).to(self.device) + + # plane_params = gymapi.PlaneParams() + # plane_params.normal = gymapi.Vec3(0.0, 0.0, 1.0) + # plane_params.distance = 0 + # plane_params.static_friction = self.plane_static_friction + # plane_params.dynamic_friction = self.plane_dynamic_friction + # plane_params.restitution = self.plane_restitution + # self.gym.add_ground(self.sim, plane_params) + # print("using plain ground");print("using plain ground");print("using plain ground");print("using plain ground");print("using plain ground"); + + def _compute_reset(self): + time = self.progress_buf * self.dt + env_ids = torch.arange(self.num_envs, + device=self.device, + dtype=torch.long) + tar_pos = self._traj_gen.calc_pos(env_ids, time) + ### ZL: entry point + # self._traj_gen.update_sim_pos(self._humanoid_root_states[) + + root_states = self._humanoid_root_states + center_height = self.get_center_heights( + root_states, env_ids=None).mean(dim=-1, keepdim=True) + + # import ipdb + # ipdb.set_trace() + self.reset_buf[:], self._terminate_buf[:] = compute_humanoid_reset( + self.reset_buf, self.progress_buf, self._contact_forces, + self._contact_body_ids, center_height, self._rigid_body_pos, + tar_pos, self.max_episode_length, self._fail_dist, + self._enable_early_termination, self._termination_heights, flags.no_collision_check) + return + + def _compute_reward(self, actions): + root_pos = self._humanoid_root_states[..., 0:3] + + time = self.progress_buf * self.dt + env_ids = torch.arange(self.num_envs, device=self.device, dtype=torch.long) + tar_pos = self._traj_gen.calc_pos(env_ids, time) + if self.fuzzy_target: + location_reward = compute_location_reward_fuzzy(root_pos, tar_pos) + else: + location_reward = compute_location_reward(root_pos, tar_pos) + + power = torch.abs(torch.multiply(self.dof_force_tensor, self._dof_vel)).sum(dim = -1) + # power_reward = -0.00005 * (power ** 2) + power_reward = -self.power_coefficient * power + + if self.power_reward: + self.rew_buf[:] = location_reward + power_reward + else: + self.rew_buf[:] = location_reward + self.reward_raw[:] = torch.cat([location_reward[:, None], power_reward[:, None]], dim = -1) + + return + + +from isaacgym.terrain_utils import * +from phc.utils.draw_utils import * + + +def poles_terrain(terrain, difficulty=1): + """ + Generate stairs + + Parameters: + terrain (terrain): the terrain + step_width (float): the width of the step [meters] + step_height (float): the step_height [meters] + platform_size (float): size of the flat platform at the center of the terrain [meters] + Returns: + terrain (SubTerrain): update terrain + """ + # switch parameters to discrete units + height = 0 + start_x = 0 + stop_x = terrain.width + start_y = 0 + stop_y = terrain.length + + img = np.zeros((terrain.width, terrain.length), dtype=int) + # disk, circle, curve, poly, ellipse + base_prob = 1 / 2 + # probs = np.array([0.7, 0.7, 0.4, 0.5, 0.5]) * ((1 - base_prob) * difficulty + base_prob) + probs = np.array([0.9, 0.4, 0.5, 0.5]) * ((1 - base_prob) * difficulty + base_prob) + low, high = 200, 500 + num_mult = int(stop_x // 80) + + for i in range(len(probs)): + p = probs[i] + if i == 0: + for _ in range(10 * num_mult): + if np.random.binomial(1, p): + img += draw_disk(img_size=terrain.width, max_r = 7) * int(np.random.uniform(low, high)) + # elif i == 1 and np.random.binomial(1, p): + # for _ in range(5 * num_mult): + # if np.random.binomial(1, p): + + # img += draw_circle(img_size=terrain.width, max_r=5) * int( + # np.random.uniform(low, high)) + elif i == 1 and np.random.binomial(1, p): + for _ in range(3 * num_mult): + if np.random.binomial(1, p): + img += draw_curve(img_size=terrain.width) * int(np.random.uniform(low, high)) + elif i == 2 and np.random.binomial(1, p): + for _ in range(1 * num_mult): + if np.random.binomial(1, p): + img += draw_polygon(img_size=terrain.width, max_sides=5) * int(np.random.uniform(low, high)) + elif i == 3 and np.random.binomial(1, p): + for _ in range(5 * num_mult): + if np.random.binomial(1, p): + img += draw_ellipse(img_size=terrain.width, + max_size=5) * int( + np.random.uniform(low, high)) + + terrain.height_field_raw[start_x:stop_x, start_y:stop_y] = img + + return terrain + + +class MeshTerrain: + def __init__(self, heigthmap_data, device): + self.border_size = 20 + self.border = 500 + self.sample_extent_x = 300 + self.sample_extent_y = 300 + self.vertical_scale = 1 + self.device = device + + self.heightsamples = torch.from_numpy(heigthmap_data['heigthmap']).to(device) + self.walkable_map = torch.from_numpy(heigthmap_data['walkable_map']).to(device) + self.cam_pos = torch.from_numpy(heigthmap_data['cam_pos']) + self.x_scale = heigthmap_data['x_scale'] + self.y_scale = heigthmap_data['y_scale'] + + self.x_shape, self.y_shape = self.heightsamples.shape + self.x_c = (self.x_shape / 2) / self.x_scale + self.y_c = (self.y_shape / 2) / self.y_scale + + coord_x, coord_y = torch.where(self.walkable_map == 1) # Image coordinates, need to flip y and x + coord_x, coord_y = coord_x.float(), coord_y.float() + self.coord_x_scale = coord_x / self.x_scale - self.x_c + self.coord_y_scale = coord_y / self.y_scale - self.y_c + + self.coord_x_scale += self.cam_pos[0] + self.coord_y_scale += self.cam_pos[1] + + self.num_samples = self.coord_x_scale.shape[0] + + + def sample_valid_locations(self, num_envs, env_ids): + num_envs = env_ids.shape[0] + idxes = np.random.randint(0, self.num_samples, size=num_envs) + valid_locs = torch.stack([self.coord_x_scale[idxes], self.coord_y_scale[idxes]], dim = -1) + return valid_locs + + def world_points_to_map(self, points): + points[..., 0] -= self.cam_pos[0] - self.x_c + points[..., 1] -= self.cam_pos[1] - self.y_c + points[..., 0] *= self.x_scale + points[..., 1] *= self.y_scale + points = (points).long() + + px = points[:, :, 0].view(-1) + py = points[:, :, 1].view(-1) + + px = torch.clip(px, 0, self.heightsamples.shape[0] - + 2) # image, so sampling 1 is for x + py = torch.clip(py, 0, self.heightsamples.shape[1] - 2) + return px, py + + def sample_height_points(self, + points, + root_states = None, + root_points=None, + env_ids=None, + num_group_people=512, + group_ids=None): + + B, N, C = points.shape + px, py = self.world_points_to_map(points) + heightsamples = self.heightsamples.clone() + device = points.device + if env_ids is None: + env_ids = torch.arange(B).to(points).long() + + if not root_points is None: + # Adding human root points to the height field + max_num_envs, num_root_points, _ = root_points.shape + root_px, root_py = self.world_points_to_map(root_points) + num_groups = int(root_points.shape[0] / num_group_people) + heightsamples_group = heightsamples[None, ].repeat( + num_groups, 1, 1) + + + root_px, root_py = root_px.view(-1, num_group_people * num_root_points), root_py.view(-1, num_group_people * num_root_points) + px, py = px.view(-1, N), py.view(-1, N) + heights = torch.zeros(px.shape).to(px.device) + + if not root_states is None: + linear_vel = root_states[:, 7:10] # This contains ALL the linear velocities + root_rot = root_states[:, 3:7] + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + velocity_map = torch.zeros([px.shape[0], px.shape[1], + 2]).to(root_states) + velocity_map_group = torch.zeros(heightsamples_group.shape + + (3, )).to(points) + + for idx in range(num_groups): + heightsamples_group[idx][root_px[idx],root_py[idx]] += torch.tensor(1.7 / self.vertical_scale) + # heightsamples_group[idx][root_px[idx] + 1,root_py[idx] + 1] += torch.tensor(1.7 / self.vertical_scale) + group_mask_env_ids = group_ids[env_ids] == idx # agents to select for this group from the current env_ids + # if sum(group_mask) == 0: + # continue + group_px, group_py = px[group_mask_env_ids].view(-1), py[group_mask_env_ids].view(-1) + heights1 = heightsamples_group[idx][group_px, group_py] + heights2 = heightsamples_group[idx][group_px + 1, group_py + 1] + + heights_group = torch.min(heights1, heights2) + heights[group_mask_env_ids] = heights_group.view(-1, N) + + if not root_states is None: + # First update the map with the velocity + group_mask_all = group_ids == idx + env_ids_in_group = env_ids[group_mask_env_ids] + group_linear_vel = linear_vel[group_mask_all] + + velocity_map_group[ + idx, root_px[idx], + root_py[idx], :] = group_linear_vel.repeat( + 1, root_points.shape[1]).view(-1, 3) # Make sure that the order is correct. + + # Then sampling the points + vel_group = velocity_map_group[idx][group_px, group_py] + vel_group = vel_group.view(-1, N, 3) + vel_group -= linear_vel[env_ids_in_group, None] # this is one-to-one substraction of the agents in the group to mark the static terrain with relative velocity + group_heading_rot = heading_rot[env_ids_in_group] + + group_vel_idv = torch_utils.my_quat_rotate( + group_heading_rot.repeat(1, N).view(-1, 4), + vel_group.view(-1, 3) + ) # Global velocity transform. for ALL of the elements in the group. + group_vel_idv = group_vel_idv.view(-1, N, 3)[..., :2] + velocity_map[group_mask_env_ids] = group_vel_idv + # import matplotlib.pyplot as plt; plt.imshow(heights[0].reshape(32, 32).cpu().numpy()); plt.show() + if root_states is None: + return heights * self.vertical_scale + else: + heights = (heights * self.vertical_scale).view(B, -1, 1) + return torch.cat([heights, velocity_map], dim=-1) + else: + heights1 = heightsamples[px, py] + heights2 = heightsamples[px + 1, py + 1] + heights = torch.min(heights1, heights2) + + + return heights * self.vertical_scale + + +class Terrain: + def __init__(self, cfg, num_robots, device) -> None: + + self.type = cfg["terrainType"] + self.device = device + if self.type in ["none", 'plane']: + return + self.horizontal_scale = 0.1 # resolution 0.1 + self.vertical_scale = 0.005 + self.border_size = 50 + self.env_length = cfg["mapLength"] + self.env_width = cfg["mapWidth"] + self.proportions = [ + np.sum(cfg["terrainProportions"][:i + 1]) + for i in range(len(cfg["terrainProportions"])) + ] + + self.env_rows = cfg["numLevels"] + self.env_cols = cfg["numTerrains"] + self.num_maps = self.env_rows * self.env_cols + self.env_origins = np.zeros((self.env_rows, self.env_cols, 3)) + + self.width_per_env_pixels = int(self.env_width / self.horizontal_scale) + self.length_per_env_pixels = int(self.env_length / + self.horizontal_scale) + + self.border = int(self.border_size / self.horizontal_scale) + self.tot_cols = int( + self.env_cols * self.width_per_env_pixels) + 2 * self.border + self.tot_rows = int( + self.env_rows * self.length_per_env_pixels) + 2 * self.border + + self.height_field_raw = np.zeros((self.tot_rows, self.tot_cols), dtype=np.int16) + self.walkable_field_raw = np.zeros((self.tot_rows, self.tot_cols), dtype=np.int16) + if cfg["curriculum"]: + self.curiculum(num_robots, + num_terrains=self.env_cols, + num_levels=self.env_rows) + else: + self.randomized_terrain() + self.heightsamples = torch.from_numpy(self.height_field_raw).to(self.device) # ZL: raw height field, first dimension is x, second is y + self.walkable_field = torch.from_numpy(self.walkable_field_raw).to(self.device) + self.vertices, self.triangles = convert_heightfield_to_trimesh(self.height_field_raw, self.horizontal_scale, self.vertical_scale,cfg["slopeTreshold"]) + self.sample_extent_x = int((self.tot_rows - self.border * 2) * self.horizontal_scale) + self.sample_extent_y = int((self.tot_cols - self.border * 2) * self.horizontal_scale) + + coord_x, coord_y = torch.where(self.walkable_field == 0) + coord_x_scale = coord_x * self.horizontal_scale + coord_y_scale = coord_y * self.horizontal_scale + walkable_subset = torch.logical_and( + torch.logical_and(coord_y_scale < coord_y_scale.max() - self.border * self.horizontal_scale, coord_x_scale < coord_x_scale.max() - self.border * self.horizontal_scale), + torch.logical_and(coord_y_scale > coord_y_scale.min() + self.border * self.horizontal_scale, coord_x_scale > coord_x_scale.min() + self.border * self.horizontal_scale) + ) + # import ipdb; ipdb.set_trace() + # joblib.dump(self.walkable_field_raw, "walkable_field.pkl") + + self.coord_x_scale = coord_x_scale[walkable_subset] + self.coord_y_scale = coord_y_scale[walkable_subset] + self.num_samples = self.coord_x_scale.shape[0] + + + def sample_valid_locations(self, max_num_envs, env_ids, group_num_people = 16, sample_groups = False): + if sample_groups: + num_groups = max_num_envs// group_num_people + group_centers = torch.stack([torch_rand_float(0., self.sample_extent_x, (num_groups, 1),device=self.device).squeeze(1), torch_rand_float(0., self.sample_extent_y, (num_groups, 1),device=self.device).squeeze(1)], dim = -1) + group_diffs = torch.stack([torch_rand_float(-8., 8, (num_groups, group_num_people) ,device=self.device), torch_rand_float(8., -8, (num_groups, group_num_people),device=self.device)], dim = -1) + valid_locs = (group_centers[:, None, ] + group_diffs).reshape(max_num_envs, -1) + + if not env_ids is None: + valid_locs = valid_locs[env_ids] + else: + num_envs = env_ids.shape[0] + idxes = np.random.randint(0, self.num_samples, size=num_envs) + valid_locs = torch.stack([self.coord_x_scale[idxes], self.coord_y_scale[idxes]], dim = -1) + + return valid_locs + + def world_points_to_map(self, points): + points = (points / self.horizontal_scale).long() + px = points[:, :, 0].view(-1) + py = points[:, :, 1].view(-1) + px = torch.clip(px, 0, self.heightsamples.shape[0] - 2) + py = torch.clip(py, 0, self.heightsamples.shape[1] - 2) + return px, py + + + def sample_height_points(self, points, root_states = None, root_points=None, env_ids = None, num_group_people = 512, group_ids = None): + B, N, C = points.shape + px, py = self.world_points_to_map(points) + heightsamples = self.heightsamples.clone() + if env_ids is None: + env_ids = torch.arange(B).to(points).long() + + if not root_points is None: + # Adding human root points to the height field + max_num_envs, num_root_points, _ = root_points.shape + root_px, root_py = self.world_points_to_map(root_points) + num_groups = int(root_points.shape[0]/num_group_people) + heightsamples_group = heightsamples[None, ].repeat(num_groups, 1, 1) + + root_px, root_py = root_px.view(-1, num_group_people * num_root_points), root_py.view(-1, num_group_people * num_root_points) + px, py = px.view(-1, N), py.view(-1, N) + heights = torch.zeros_like(px) + + if not root_states is None: + linear_vel = root_states[:, 7:10] # This contains ALL the linear velocities + root_rot = root_states[:, 3:7] + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + velocity_map = torch.zeros([px.shape[0], px.shape[1], 2]).to(root_states) + velocity_map_group = torch.zeros(heightsamples_group.shape + (3,)).to(points) + + for idx in range(num_groups): + heightsamples_group[idx][root_px[idx], root_py[idx]] += torch.tensor(1.7 / self.vertical_scale).short() + group_mask_env_ids = group_ids[env_ids] == idx # agents to select for this group from the current env_ids + # if sum(group_mask) == 0: + # continue + group_px, group_py = px[group_mask_env_ids].view(-1), py[group_mask_env_ids].view(-1) + heights1 = heightsamples_group[idx][group_px, group_py] + heights2 = heightsamples_group[idx][group_px + 1, group_py + 1] + heights_group = torch.min(heights1, heights2) + heights[group_mask_env_ids] = heights_group.view(-1, N).long() + + if not root_states is None: + # First update the map with the velocity + group_mask_all = group_ids == idx + env_ids_in_group = env_ids[group_mask_env_ids] + group_linear_vel = linear_vel[group_mask_all] + velocity_map_group[idx, root_px[idx], root_py[idx], :] = group_linear_vel.repeat(1, root_points.shape[1]).view(-1, 3) + + # Sample the points for each agent's px and py + vel_group = velocity_map_group[idx][group_px, group_py] + vel_group = vel_group.view(-1, N, 3) + vel_group -= linear_vel[env_ids_in_group, None] # for each agent's velocity map, minus it's own velocity to get the relative velocity + group_heading_rot = heading_rot[env_ids_in_group] + + group_vel_idv = torch_utils.my_quat_rotate( + group_heading_rot.repeat(1, N).view(-1, 4), + vel_group.view(-1, 3) + ) # Global velocity transform. for ALL of the elements in the group. + group_vel_idv = group_vel_idv.view(-1, N, 3)[..., :2] + velocity_map[group_mask_env_ids] = group_vel_idv + if root_states is None: + return heights * self.vertical_scale + else: + heights = (heights * self.vertical_scale).view(B, -1, 1) + return torch.cat([heights, velocity_map], dim = -1) + + else: + heights1 = heightsamples[px, py] + heights2 = heightsamples[px + 1, py + 1] + heights = torch.min(heights1, heights2) + + if root_states is None: + return heights * self.vertical_scale + else: + velocity_map = torch.zeros((B, N, 2)).to(points) + linear_vel = root_states[env_ids, 7:10] + root_rot = root_states[env_ids, 3:7] + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + linear_vel_ego = torch_utils.my_quat_rotate(heading_rot, linear_vel) + velocity_map[:] = velocity_map[:] - linear_vel_ego[:, None, :2] # Flip velocity to be in agent's point of view + heights = (heights * self.vertical_scale).view(B, -1, 1) + return torch.cat([heights, velocity_map], dim = -1) + + # def sample_height_points(self, points, root_points=None): + # # Ugly but correct solution + # B, N, _ = points.shape + # px, py = self.world_points_to_map(points) + + # if not root_points is None: + # # Adding human root points to the height field + # root_px, root_py = self.world_points_to_map(root_points) + # root_px, root_py = root_px.view(B, -1), root_py.view(B, -1) + # heights_acc = [] + # for curr_agent in range(B): + # px, py = px.view(B, -1), py.view(B, -1) + # heightsamples = self.heightsamples.clone() + # mask = torch.ones(B).bool() + # mask[curr_agent] = False + # heightsamples[root_px[mask].flatten(), root_py[mask].flatten()] += torch.tensor(1.7 / self.vertical_scale).short() + # heights1 = heightsamples[px[curr_agent], py[curr_agent]] + # heights2 = heightsamples[px[curr_agent] + 1, py[curr_agent] + 1] + # heights = torch.min(heights1, heights2) + # heights_acc.append(heights) + # heights = torch.stack(heights_acc, dim=0) + + # else: + # heightsamples = self.heightsamples.clone() + # heights1 = heightsamples[px, py] + # heights2 = heightsamples[px + 1, py + 1] + # heights = torch.min(heights1, heights2) + # return heights + + def randomized_terrain(self): + for k in range(self.num_maps): + # Env coordinates in the world + (i, j) = np.unravel_index(k, (self.env_rows, self.env_cols)) + + # Heightfield coordinate system from now on + start_x = self.border + i * self.length_per_env_pixels + end_x = self.border + (i + 1) * self.length_per_env_pixels + start_y = self.border + j * self.width_per_env_pixels + end_y = self.border + (j + 1) * self.width_per_env_pixels + + terrain = SubTerrain("terrain", + width=self.width_per_env_pixels, + length=self.width_per_env_pixels, + vertical_scale=self.vertical_scale, + horizontal_scale=self.horizontal_scale) + choice = np.random.uniform(0, 1) + difficulty = np.random.uniform(0.1, 1) + slope = difficulty * 0.7 + discrete_obstacles_height = 0.025 + difficulty * 0.15 + stepping_stones_size = 2 - 1.8 * difficulty + step_height = 0.05 + 0.175 * difficulty + if choice < self.proportions[0]: + if choice < 0.05: + slope *= -1 + pyramid_sloped_terrain(terrain, slope=slope, platform_size=3.) + elif choice < self.proportions[1]: + if choice < 0.15: + slope *= -1 + pyramid_sloped_terrain(terrain, slope=slope, platform_size=3.) + random_uniform_terrain(terrain, + min_height=-0.1, + max_height=0.1, + step=0.025, + downsampled_scale=0.2) + elif choice < self.proportions[3]: + if choice < self.proportions[2]: + step_height *= -1 + pyramid_stairs_terrain(terrain, + step_width=0.31, + step_height=step_height, + platform_size=3.) + elif choice < self.proportions[4]: + discrete_obstacles_terrain(terrain, + discrete_obstacles_height, + 1., + 2., + 40, + platform_size=3.) + elif choice < self.proportions[5]: + stepping_stones_terrain(terrain, + stone_size=stepping_stones_size, + stone_distance=0.1, + max_height=0., + platform_size=3.) + elif choice < self.proportions[6]: + poles_terrain(terrain=terrain, difficulty=difficulty) + self.walkable_field_raw[start_x:end_x, start_y:end_y] = (terrain.height_field_raw != 0) + + elif choice < self.proportions[7]: + # plain walking terrain + pass + + self.height_field_raw[start_x:end_x, start_y:end_y] = terrain.height_field_raw + + env_origin_x = (i + 0.5) * self.env_length + env_origin_y = (j + 0.5) * self.env_width + x1 = int((self.env_length / 2. - 1) / self.horizontal_scale) + x2 = int((self.env_length / 2. + 1) / self.horizontal_scale) + y1 = int((self.env_width / 2. - 1) / self.horizontal_scale) + y2 = int((self.env_width / 2. + 1) / self.horizontal_scale) + env_origin_z = np.max(terrain.height_field_raw[x1:x2, y1:y2]) * self.vertical_scale + self.env_origins[i, j] = [env_origin_x, env_origin_y, env_origin_z] + self.walkable_field_raw = ndimage.binary_dilation(self.walkable_field_raw, iterations=3).astype(int) + + def curiculum(self, num_robots, num_terrains, num_levels): + num_robots_per_map = int(num_robots / num_terrains) + left_over = num_robots % num_terrains + idx = 0 + for j in tqdm(range(num_terrains)): + for i in range(num_levels): + terrain = SubTerrain("terrain", + width=self.width_per_env_pixels, + length=self.width_per_env_pixels, + vertical_scale=self.vertical_scale, + horizontal_scale=self.horizontal_scale) + difficulty = i / num_levels + choice = j / num_terrains + + slope = difficulty * 0.7 + step_height = 0.05 + 0.175 * difficulty + discrete_obstacles_height = 0.025 + difficulty * 0.15 + stepping_stones_size = 2 - 1.8 * difficulty + + start_x = self.border + i * self.length_per_env_pixels + end_x = self.border + (i + 1) * self.length_per_env_pixels + start_y = self.border + j * self.width_per_env_pixels + end_y = self.border + (j + 1) * self.width_per_env_pixels + + if choice < self.proportions[0]: + if choice < 0.05: + slope *= -1 + pyramid_sloped_terrain(terrain, + slope=slope, + platform_size=3.) + elif choice < self.proportions[1]: + if choice < 0.15: + slope *= -1 + pyramid_sloped_terrain(terrain, + slope=slope, + platform_size=3.) + random_uniform_terrain(terrain, + min_height=-0.1, + max_height=0.1, + step=0.025, + downsampled_scale=0.2) + elif choice < self.proportions[3]: + if choice < self.proportions[2]: + step_height *= -1 + pyramid_stairs_terrain(terrain, + step_width=0.31, + step_height=step_height, + platform_size=3.) + elif choice < self.proportions[4]: + discrete_obstacles_terrain(terrain, + discrete_obstacles_height, + 1., + 2., + 40, + platform_size=3.) + elif choice < self.proportions[5]: + stepping_stones_terrain(terrain, + stone_size=stepping_stones_size, + stone_distance=0.1, + max_height=0., + platform_size=3.) + elif choice < self.proportions[6]: + poles_terrain(terrain=terrain, difficulty=difficulty) + self.walkable_field_raw[start_x:end_x, start_y:end_y] = (terrain.height_field_raw != 0) + # self.walkable_field_raw[start_x:end_x, start_y:end_y] = 1 + elif choice < self.proportions[7]: + # plain walking terrain + pass + + # Heightfield coordinate system + self.height_field_raw[start_x:end_x, start_y:end_y] = terrain.height_field_raw + + robots_in_map = num_robots_per_map + if j < left_over: + robots_in_map += 1 + + env_origin_x = (i + 0.5) * self.env_length + env_origin_y = (j + 0.5) * self.env_width + x1 = int((self.env_length / 2. - 1) / self.horizontal_scale) + x2 = int((self.env_length / 2. + 1) / self.horizontal_scale) + y1 = int((self.env_width / 2. - 1) / self.horizontal_scale) + y2 = int((self.env_width / 2. + 1) / self.horizontal_scale) + env_origin_z = np.max( + terrain.height_field_raw[x1:x2, + y1:y2]) * self.vertical_scale + self.env_origins[i, j] = [ + env_origin_x, env_origin_y, env_origin_z + ] + + self.walkable_field_raw = ndimage.binary_dilation(self.walkable_field_raw, iterations=3).astype(int) + + + + +@torch.jit.script +def compute_humanoid_reset(reset_buf, progress_buf, contact_buf, + contact_body_ids, center_height, rigid_body_pos, + tar_pos, max_episode_length, fail_dist, + enable_early_termination, termination_heights, disableCollision): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, float, float, bool, Tensor, bool) -> Tuple[Tensor, Tensor] + terminated = torch.zeros_like(reset_buf) + + if (enable_early_termination): + masked_contact_buf = contact_buf.clone() + masked_contact_buf[:, contact_body_ids, :] = 0 + ## torch.sum to disable self-collision. + # force_threshold = 200 + force_threshold = 50 + body_contact_force = torch.sqrt(torch.square(torch.abs(masked_contact_buf.sum(dim=-2))).sum(dim=-1)) > force_threshold + + # has_fallen = torch.logical_and(body_contact_force, fall_height) + has_fallen = body_contact_force + # first timestep can sometimes still have nonzero contact forces + # so only check after first couple of steps + has_fallen *= (progress_buf > 1) + + root_pos = rigid_body_pos[..., 0, :] + tar_delta = tar_pos[..., 0:2] - root_pos[...,0:2] # also reset if toooo far away from the target trajectory + tar_dist_sq = torch.sum(tar_delta * tar_delta, dim=-1) + tar_fail = tar_dist_sq > fail_dist * fail_dist + + has_failed = torch.logical_or(has_fallen, tar_fail) + # if has_fallen.any(): + # import ipdb + # ipdb.set_trace() + + if disableCollision: + has_failed[:] = False + + ############################## Debug ############################## + # if torch.sum(has_fallen) > 0: + # import ipdb; ipdb.set_trace() + # print("???") + # mujoco_joint_names = np.array(['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand']) + # print( mujoco_joint_names[masked_contact_buf[0, :, 0].nonzero().cpu().numpy()]) + ############################## Debug ############################## + + + # has_failed[:] = False + + terminated = torch.where(has_failed, torch.ones_like(reset_buf), terminated) + + + # if torch.sum(terminated) > 0: + # termianted_progress = progress_buf[torch.where(terminated)] + # print(torch.where(termianted_progress < 30), termianted_progress[termianted_progress < 30]) + + reset = torch.where(progress_buf >= max_episode_length - 1, torch.ones_like(reset_buf), terminated) + + return reset, terminated + +# @torch.jit.script +# def compute_humanoid_reset(reset_buf, progress_buf, contact_buf, contact_body_ids, center_height, rigid_body_pos, +# tar_pos, max_episode_length, fail_dist, +# enable_early_termination, termination_heights): +# # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, float, float, bool, Tensor) -> Tuple[Tensor, Tensor] +# # print("using plain reset") +# terminated = torch.zeros_like(reset_buf) + +# if (enable_early_termination): +# masked_contact_buf = contact_buf.clone() +# masked_contact_buf[:, contact_body_ids, :] = 0 +# fall_contact = torch.any(torch.abs(masked_contact_buf) > 0.1, dim=-1) +# fall_contact = torch.any(fall_contact, dim=-1) + +# body_height = rigid_body_pos[..., 2] +# fall_height = body_height < termination_heights +# fall_height[:, contact_body_ids] = False +# fall_height = torch.any(fall_height, dim=-1) + +# has_fallen = torch.logical_and(fall_contact, fall_height) +# # first timestep can sometimes still have nonzero contact forces +# # so only check after first couple of steps +# has_fallen *= (progress_buf > 1) + +# root_pos = rigid_body_pos[..., 0, :] +# tar_delta = tar_pos[..., 0:2] - root_pos[..., 0:2] +# tar_dist_sq = torch.sum(tar_delta * tar_delta, dim=-1) +# tar_fail = tar_dist_sq > fail_dist * fail_dist + +# has_failed = torch.logical_or(has_fallen, tar_fail) + +# terminated = torch.where(has_failed, torch.ones_like(reset_buf), terminated) + +# reset = torch.where(progress_buf >= max_episode_length - 1, torch.ones_like(reset_buf), terminated) + +# return reset, terminated + + +@torch.jit.script +def quat_apply_yaw(quat, vec): + quat_yaw = quat.clone().view(-1, 4) + quat_yaw[:, :2] = 0. + quat_yaw = normalize(quat_yaw) + return quat_apply(quat_yaw, vec) + + +@torch.jit.script +def wrap_to_pi(angles): + angles %= 2 * np.pi + angles -= 2 * np.pi * (angles > np.pi) + return angles + + +# Task-location +@torch.jit.script +def compute_location_observations(root_states, traj_samples, upright): + # type: (Tensor, Tensor, bool) -> Tensor + root_pos = root_states[:, 0:3] + root_rot = root_states[:, 3:7] + if not upright: + root_rot = remove_base_rot(root_rot) + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + + heading_rot_exp = torch.broadcast_to( + heading_rot.unsqueeze(-2), + (heading_rot.shape[0], traj_samples.shape[1], heading_rot.shape[1])) + heading_rot_exp = torch.reshape( + heading_rot_exp, (heading_rot_exp.shape[0] * heading_rot_exp.shape[1], + heading_rot_exp.shape[2])) + traj_samples_delta = traj_samples - root_pos.unsqueeze(-2) + traj_samples_delta_flat = torch.reshape( + traj_samples_delta, + (traj_samples_delta.shape[0] * traj_samples_delta.shape[1], + traj_samples_delta.shape[2])) + + local_traj_pos = torch_utils.my_quat_rotate(heading_rot_exp, + traj_samples_delta_flat) + local_traj_pos = local_traj_pos[..., 0:2] + + obs = torch.reshape(local_traj_pos, + (traj_samples.shape[0], + traj_samples.shape[1] * local_traj_pos.shape[1])) + return obs + + + +@torch.jit.script +def compute_location_reward(root_pos, tar_pos): + # type: (Tensor, Tensor) -> Tensor + pos_err_scale = 2.0 + + pos_diff = tar_pos[..., 0:2] - root_pos[..., 0:2] + pos_err = torch.sum(pos_diff * pos_diff, dim=-1) + pos_reward = torch.exp(-pos_err_scale * pos_err) + + reward = pos_reward + + return reward + +@torch.jit.script +def compute_location_reward_fuzzy(root_pos, tar_pos): + # type: (Tensor, Tensor) -> Tensor + pos_err_scale = 2.0 + radius = 0.0025 + pos_diff = tar_pos[..., 0:2] - root_pos[..., 0:2] + + pos_err = torch.sum(pos_diff * pos_diff, dim=-1) + pos_err[pos_err < radius] = 0 # 5cm radius around target is perfect. + + pos_reward = torch.exp(-pos_err_scale * pos_err) + + reward = pos_reward + + return reward + + +# @torch.jit.script +# def compute_group_observation(body_pos, body_rot, body_vel, selected_jts, num_group_people, upright): +# # type: (Tensor, Tensor, Tensor, Tensor, int, bool) -> Tensor +# root_pos = body_pos[:, 0, :] +# root_rot = body_rot[:, 0, :] + +# root_h = root_pos[:, 2:3] +# if not upright: +# root_rot = remove_base_rot(root_rot) +# heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + +# top_k = 5 +# num_selected_jts = len(selected_jts) + +# B, J, _ = body_pos.shape + +# repeated_root_pos = root_pos.repeat(B, 1).view(B, B, -1) + +# dist = torch.norm(root_pos[..., None, :] - repeated_root_pos, dim = -1) + +# topk_dist, topk_idx = torch.topk(dist, top_k + 1, dim = -1, largest = False) +# topk_dist, topk_idx = topk_dist[..., 1:], topk_idx[..., 1:] ## ZL should cap distance +# topk_mask = (topk_dist > 10).view(-1) + +# selected_idxes = topk_idx.flatten() +# selected_pos = body_pos[selected_idxes][:, selected_jts].view(B, -1, 3) +# selected_vel = body_vel[selected_idxes][:, [0]].view(B, -1, 3) + +# heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, selected_pos.shape[1], 1)) +# flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], heading_rot_expand.shape[2]) + +# root_pos_expand = root_pos.unsqueeze(-2) +# local_body_pos = selected_pos - root_pos_expand +# flat_local_body_pos = local_body_pos.view(local_body_pos.shape[0] * local_body_pos.shape[1], local_body_pos.shape[2]) +# flat_local_body_pos = torch_utils.my_quat_rotate(flat_heading_rot, flat_local_body_pos) + +# flat_body_vel = selected_vel.view(selected_vel.shape[0] * selected_vel.shape[1], selected_vel.shape[2]) +# heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, selected_vel.shape[1], 1)) +# flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], heading_rot_expand.shape[2]) +# flat_local_body_vel = torch_utils.my_quat_rotate(flat_heading_rot, flat_body_vel) + +# local_body_pos = flat_local_body_pos.view(-1, num_selected_jts, 3) +# local_body_vel = flat_local_body_vel.view(-1, 1, 3) + +# local_body_pos[topk_mask], local_body_vel[topk_mask] = 0, 0 +# local_body_pos, local_body_vel = local_body_pos.view(B, -1), local_body_vel.view(B, -1) +# group_obs = torch.cat((local_body_pos, local_body_vel), dim = -1) +# return group_obs + + + +@torch.jit.script +def compute_group_observation(body_pos, body_rot, body_vel, selected_jts, num_group_people, upright): + # type: (Tensor, Tensor, Tensor, Tensor, int, bool) -> Tensor + # joints + root velocities + root_pos = body_pos[:, 0, :] + root_rot = body_rot[:, 0, :] + + root_h = root_pos[:, 2:3] + if not upright: + root_rot = remove_base_rot(root_rot) + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + + top_k = 5 + selected_jts = [0, 1, 5, 9, 3, 7, 16, 21, 18, 23] + num_selected_jts = len(selected_jts) + + B, J, _ = body_pos.shape + group_pos = body_pos.view(-1, num_group_people, J, 3) + group_vel = body_vel.view(-1, num_group_people, J, 3) + group_root_pos = group_pos[..., 0, :] + repeated_root_pos = group_root_pos.repeat(1, num_group_people, 1).view(-1, num_group_people, num_group_people, 3) + + indexes = torch.arange(B).to(body_pos.device) + dist = torch.norm(group_root_pos[..., None, :] - repeated_root_pos, dim = -1) + topk_dist, topk_idx = torch.topk(dist, top_k + 1, dim = -1, largest = False) + topk_dist, topk_idx = topk_dist[..., 1:], topk_idx[..., 1:] + topk_mask = (topk_dist > 10).view(-1) + + repeated_indexes = indexes.view(-1, num_group_people).repeat(1, num_group_people).view(-1, num_group_people, num_group_people) + selected_idxes = torch.gather(repeated_indexes, -1, topk_idx).flatten() + + selected_pos = body_pos[selected_idxes][:, selected_jts].view(B, -1, 3) + selected_vel = body_vel[selected_idxes][:, [0]].view(B, -1, 3) + + heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, selected_pos.shape[1], 1)) + flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], heading_rot_expand.shape[2]) + + root_pos_expand = root_pos.unsqueeze(-2) + local_body_pos = selected_pos - root_pos_expand + flat_local_body_pos = local_body_pos.view(local_body_pos.shape[0] * local_body_pos.shape[1], local_body_pos.shape[2]) + flat_local_body_pos = torch_utils.my_quat_rotate(flat_heading_rot, flat_local_body_pos) + + flat_body_vel = selected_vel.view(selected_vel.shape[0] * selected_vel.shape[1], selected_vel.shape[2]) + heading_rot_expand = heading_rot.unsqueeze(-2).repeat((1, selected_vel.shape[1], 1)) + flat_heading_rot = heading_rot_expand.view(heading_rot_expand.shape[0] * heading_rot_expand.shape[1], heading_rot_expand.shape[2]) + flat_local_body_vel = torch_utils.my_quat_rotate(flat_heading_rot, flat_body_vel) + + local_body_pos = flat_local_body_pos.view(-1, num_selected_jts, 3) + local_body_vel = flat_local_body_vel.view(-1, 1, 3) + + local_body_pos[topk_mask], local_body_vel[topk_mask] = 0, 0 + group_obs = torch.cat([local_body_pos.view(B, -1, top_k, 3), local_body_vel.view(B, -1, top_k, 3)], dim = 1).view(B, -1) + + return group_obs diff --git a/phc/env/tasks/humanoid_reach.py b/phc/env/tasks/humanoid_reach.py new file mode 100644 index 0000000..411f368 --- /dev/null +++ b/phc/env/tasks/humanoid_reach.py @@ -0,0 +1,417 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import torch + +import env.tasks.humanoid as humanoid +import env.tasks.humanoid_amp as humanoid_amp +import env.tasks.humanoid_amp_task as humanoid_amp_task +from utils import torch_utils + +from isaacgym import gymapi +from isaacgym import gymtorch +from isaacgym.torch_utils import * + +class HumanoidReach(humanoid_amp_task.HumanoidAMPTask): + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + self._tar_speed = cfg["env"]["tarSpeed"] + self._tar_change_steps_min = cfg["env"]["tarChangeStepsMin"] + self._tar_change_steps_max = cfg["env"]["tarChangeStepsMax"] + self._tar_dist_max = cfg["env"]["tarDistMax"] + self._tar_height_min = cfg["env"]["tarHeightMin"] + self._tar_height_max = cfg["env"]["tarHeightMax"] + + super().__init__(cfg=cfg, + sim_params=sim_params, + physics_engine=physics_engine, + device_type=device_type, + device_id=device_id, + headless=headless) + + self._tar_change_steps = torch.zeros([self.num_envs], device=self.device, dtype=torch.int64) + self._tar_pos = torch.zeros([self.num_envs, 3], device=self.device, dtype=torch.float) + + reach_body_name = cfg["env"]["reachBodyName"] + self._reach_body_id = self._build_reach_body_id_tensor(self.envs[0], self.humanoid_handles[0], reach_body_name) + + if (not self.headless): + self._build_marker_state_tensors() + + return + + def _sample_ref_state(self, env_ids): + motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel = super()._sample_ref_state(env_ids) + root_pos[..., :2] = 0.0 # Set the root position to be zero + return motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel + + def get_task_obs_size(self): + obs_size = 0 + if (self._enable_task_obs): + obs_size = 3 + return obs_size + + def post_physics_step(self): + super().post_physics_step() + + if (humanoid_amp.HACK_OUTPUT_MOTION): + self._hack_output_motion_target() + + return + + def _update_marker(self): + self._marker_pos[..., :] = self._tar_pos + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), + gymtorch.unwrap_tensor(self._marker_actor_ids), len(self._marker_actor_ids)) + return + + def _create_envs(self, num_envs, spacing, num_per_row): + if (not self.headless): + self._marker_handles = [] + self._load_marker_asset() + + super()._create_envs(num_envs, spacing, num_per_row) + return + + def _load_marker_asset(self): + asset_root = "pulse/data/assets/mjcf/" + asset_file = "location_marker.urdf" + + asset_options = gymapi.AssetOptions() + asset_options.angular_damping = 0.01 + asset_options.linear_damping = 0.01 + asset_options.max_angular_velocity = 100.0 + asset_options.density = 1.0 + asset_options.fix_base_link = True + asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + + self._marker_asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options) + + return + + def _build_env(self, env_id, env_ptr, humanoid_asset): + super()._build_env(env_id, env_ptr, humanoid_asset) + + if (not self.headless): + self._build_marker(env_id, env_ptr) + + return + + def _build_marker(self, env_id, env_ptr): + default_pose = gymapi.Transform() + + marker_handle = self.gym.create_actor(env_ptr, self._marker_asset, default_pose, "marker", env_id, 2, 2) + self.gym.set_rigid_body_color(env_ptr, marker_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.8, 0.0, 0.0)) + self._marker_handles.append(marker_handle) + + return + + def _build_marker_state_tensors(self): + num_actors = self._root_states.shape[0] // self.num_envs + self._marker_states = self._root_states.view(self.num_envs, num_actors, self._root_states.shape[-1])[..., 1, :] + self._marker_pos = self._marker_states[..., :3] + + self._marker_actor_ids = self._humanoid_actor_ids + 1 + + return + + def _build_reach_body_id_tensor(self, env_ptr, actor_handle, body_name): + body_id = self.gym.find_actor_rigid_body_handle(env_ptr, actor_handle, body_name) + assert(body_id != -1) + body_id = to_torch(body_id, device=self.device, dtype=torch.long) + return body_id + + def _update_task(self): + reset_task_mask = self.progress_buf >= self._tar_change_steps + rest_env_ids = reset_task_mask.nonzero(as_tuple=False).flatten() + if len(rest_env_ids) > 0: + self._reset_task(rest_env_ids) + return + + def _reset_task(self, env_ids): + n = len(env_ids) + + rand_pos = torch.rand([n, 3], device=self.device) + rand_pos[..., 0:2] = self._tar_dist_max * (2.0 * rand_pos[..., 0:2] - 1.0) + rand_pos[..., 2] = (self._tar_height_max - self._tar_height_min) * rand_pos[..., 2] + self._tar_height_min + + change_steps = torch.randint(low=self._tar_change_steps_min, high=self._tar_change_steps_max, + size=(n,), device=self.device, dtype=torch.int64) + + self._tar_pos[env_ids, :] = rand_pos + self._tar_change_steps[env_ids] = self.progress_buf[env_ids] + change_steps + return + + def _compute_task_obs(self, env_ids=None): + if (env_ids is None): + root_states = self._humanoid_root_states + tar_pos = self._tar_pos + else: + root_states = self._humanoid_root_states[env_ids] + tar_pos = self._tar_pos[env_ids] + + obs = compute_location_observations(root_states, tar_pos) + return obs + + def _compute_reward(self, actions): + reach_body_pos = self._rigid_body_pos[:, self._reach_body_id, :] + root_rot = self._humanoid_root_states[..., 3:7] + + self.rew_buf[:] = compute_reach_reward(reach_body_pos, root_rot, + self._tar_pos, self._tar_speed, + self.dt) + return + + def _draw_task(self): + self._update_marker() + + cols = np.array([[0.0, 1.0, 0.0]], dtype=np.float32) + + self.gym.clear_lines(self.viewer) + + starts = self._rigid_body_pos[:, self._reach_body_id, :] + ends = self._tar_pos + verts = torch.cat([starts, ends], dim=-1).cpu().numpy() + + for i, env_ptr in enumerate(self.envs): + curr_verts = verts[i] + curr_verts = curr_verts.reshape([1, 6]) + self.gym.add_lines(self.viewer, env_ptr, curr_verts.shape[0], curr_verts, cols) + + return + + def _hack_output_motion_target(self): + if (not hasattr(self, '_output_motion_target_pos')): + self._output_motion_target_pos = [] + + tar_pos = self._tar_pos[0].cpu().numpy() + self._output_motion_target_pos.append(tar_pos) + + reset = self.reset_buf[0].cpu().numpy() == 1 + + if (reset and len(self._output_motion_target_pos) > 1): + output_data = np.array(self._output_motion_target_pos) + np.save('output/record_tar_motion.npy', output_data) + + self._output_motion_target_pos = [] + + return + +class HumanoidReachZ(humanoid_amp_task.HumanoidAMPZTask): + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + self._tar_speed = cfg["env"]["tarSpeed"] + self._tar_change_steps_min = cfg["env"]["tarChangeStepsMin"] + self._tar_change_steps_max = cfg["env"]["tarChangeStepsMax"] + self._tar_dist_max = cfg["env"]["tarDistMax"] + self._tar_height_min = cfg["env"]["tarHeightMin"] + self._tar_height_max = cfg["env"]["tarHeightMax"] + + super().__init__(cfg=cfg, + sim_params=sim_params, + physics_engine=physics_engine, + device_type=device_type, + device_id=device_id, + headless=headless) + + self._tar_change_steps = torch.zeros([self.num_envs], device=self.device, dtype=torch.int64) + self._tar_pos = torch.zeros([self.num_envs, 3], device=self.device, dtype=torch.float) + + reach_body_name = cfg["env"]["reachBodyName"] + self._reach_body_id = self._build_reach_body_id_tensor(self.envs[0], self.humanoid_handles[0], reach_body_name) + + if (not self.headless): + self._build_marker_state_tensors() + + return + + def get_task_obs_size(self): + obs_size = 0 + if (self._enable_task_obs): + obs_size = 3 + return obs_size + + def post_physics_step(self): + super().post_physics_step() + + if (humanoid_amp.HACK_OUTPUT_MOTION): + self._hack_output_motion_target() + + return + + def _sample_ref_state(self, env_ids): + motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel = super()._sample_ref_state(env_ids) + root_pos[..., :2] = 0.0 # Set the root position to be zero + return motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel + + def _update_marker(self): + self._marker_pos[..., :] = self._tar_pos + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), + gymtorch.unwrap_tensor(self._marker_actor_ids), len(self._marker_actor_ids)) + return + + def _create_envs(self, num_envs, spacing, num_per_row): + if (not self.headless): + self._marker_handles = [] + self._load_marker_asset() + + super()._create_envs(num_envs, spacing, num_per_row) + return + + def _load_marker_asset(self): + asset_root = "pulse/data/assets/mjcf/" + asset_file = "location_marker.urdf" + + asset_options = gymapi.AssetOptions() + asset_options.angular_damping = 0.01 + asset_options.linear_damping = 0.01 + asset_options.max_angular_velocity = 100.0 + asset_options.density = 1.0 + asset_options.fix_base_link = True + asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + + self._marker_asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options) + + return + + def _build_env(self, env_id, env_ptr, humanoid_asset): + super()._build_env(env_id, env_ptr, humanoid_asset) + + if (not self.headless): + self._build_marker(env_id, env_ptr) + + return + + def _build_marker(self, env_id, env_ptr): + default_pose = gymapi.Transform() + + marker_handle = self.gym.create_actor(env_ptr, self._marker_asset, default_pose, "marker", env_id, 2, 2) + self.gym.set_rigid_body_color(env_ptr, marker_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.8, 0.0, 0.0)) + self._marker_handles.append(marker_handle) + + return + + def _build_marker_state_tensors(self): + num_actors = self._root_states.shape[0] // self.num_envs + self._marker_states = self._root_states.view(self.num_envs, num_actors, self._root_states.shape[-1])[..., 1, :] + self._marker_pos = self._marker_states[..., :3] + + self._marker_actor_ids = self._humanoid_actor_ids + 1 + + return + + def _build_reach_body_id_tensor(self, env_ptr, actor_handle, body_name): + body_id = self.gym.find_actor_rigid_body_handle(env_ptr, actor_handle, body_name) + assert(body_id != -1) + body_id = to_torch(body_id, device=self.device, dtype=torch.long) + return body_id + + def _update_task(self): + reset_task_mask = self.progress_buf >= self._tar_change_steps + rest_env_ids = reset_task_mask.nonzero(as_tuple=False).flatten() + if len(rest_env_ids) > 0: + self._reset_task(rest_env_ids) + return + + def _reset_task(self, env_ids): + n = len(env_ids) + + rand_pos = torch.rand([n, 3], device=self.device) + rand_pos[..., 0:2] = self._tar_dist_max * (2.0 * rand_pos[..., 0:2] - 1.0) + rand_pos[..., 2] = (self._tar_height_max - self._tar_height_min) * rand_pos[..., 2] + self._tar_height_min + + change_steps = torch.randint(low=self._tar_change_steps_min, high=self._tar_change_steps_max, + size=(n,), device=self.device, dtype=torch.int64) + + self._tar_pos[env_ids, :] = rand_pos + self._tar_change_steps[env_ids] = self.progress_buf[env_ids] + change_steps + return + + def _compute_task_obs(self, env_ids=None): + if (env_ids is None): + root_states = self._humanoid_root_states + tar_pos = self._tar_pos + else: + root_states = self._humanoid_root_states[env_ids] + tar_pos = self._tar_pos[env_ids] + + obs = compute_location_observations(root_states, tar_pos) + return obs + + def _compute_reward(self, actions): + reach_body_pos = self._rigid_body_pos[:, self._reach_body_id, :] + root_rot = self._humanoid_root_states[..., 3:7] + + self.rew_buf[:] = compute_reach_reward(reach_body_pos, root_rot, + self._tar_pos, self._tar_speed, + self.dt) + return + + def _draw_task(self): + self._update_marker() + + cols = np.array([[0.0, 1.0, 0.0]], dtype=np.float32) + + self.gym.clear_lines(self.viewer) + + starts = self._rigid_body_pos[:, self._reach_body_id, :] + ends = self._tar_pos + + verts = torch.cat([starts, ends], dim=-1).cpu().numpy() + + for i, env_ptr in enumerate(self.envs): + curr_verts = verts[i] + curr_verts = curr_verts.reshape([1, 6]) + self.gym.add_lines(self.viewer, env_ptr, curr_verts.shape[0], curr_verts, cols) + + return + + def _hack_output_motion_target(self): + if (not hasattr(self, '_output_motion_target_pos')): + self._output_motion_target_pos = [] + + tar_pos = self._tar_pos[0].cpu().numpy() + self._output_motion_target_pos.append(tar_pos) + + reset = self.reset_buf[0].cpu().numpy() == 1 + + if (reset and len(self._output_motion_target_pos) > 1): + output_data = np.array(self._output_motion_target_pos) + np.save('output/record_tar_motion.npy', output_data) + + self._output_motion_target_pos = [] + + return + +##################################################################### +###=========================jit functions=========================### +##################################################################### + +@torch.jit.script +def compute_location_observations(root_states, tar_pos): + # type: (Tensor, Tensor) -> Tensor + root_pos = root_states[:, 0:3] + root_rot = root_states[:, 3:7] + + heading_rot_inv = torch_utils.calc_heading_quat_inv(root_rot) + local_tar_pos = tar_pos - root_pos + + local_tar_pos = torch_utils.my_quat_rotate(heading_rot_inv, local_tar_pos) + + obs = local_tar_pos + return obs + +@torch.jit.script +def compute_reach_reward(reach_body_pos, root_rot, tar_pos, tar_speed, dt): + # type: (Tensor, Tensor, Tensor, float, float) -> Tensor + pos_err_scale = 4.0 + + pos_diff = tar_pos - reach_body_pos + pos_err = torch.sum(pos_diff * pos_diff, dim=-1) + pos_reward = torch.exp(-pos_err_scale * pos_err) + + reward = pos_reward + + return reward \ No newline at end of file diff --git a/phc/env/tasks/humanoid_speed.py b/phc/env/tasks/humanoid_speed.py new file mode 100644 index 0000000..129bcae --- /dev/null +++ b/phc/env/tasks/humanoid_speed.py @@ -0,0 +1,565 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import torch + +import env.tasks.humanoid as humanoid +import env.tasks.humanoid_amp as humanoid_amp +import env.tasks.humanoid_amp_task as humanoid_amp_task +from utils import torch_utils + +from isaacgym import gymapi +from isaacgym import gymtorch +from isaacgym.torch_utils import * +from scipy.spatial.transform import Rotation as sRot +from phc.utils.flags import flags + +TAR_ACTOR_ID = 1 + +class HumanoidSpeed(humanoid_amp_task.HumanoidAMPTask): + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + self._tar_speed_min = cfg["env"]["tarSpeedMin"] + self._tar_speed_max = cfg["env"]["tarSpeedMax"] + self._speed_change_steps_min = cfg["env"]["speedChangeStepsMin"] + self._speed_change_steps_max = cfg["env"]["speedChangeStepsMax"] + + self._add_input_noise = cfg["env"].get("addInputNoise", False) + + super().__init__(cfg=cfg, + sim_params=sim_params, + physics_engine=physics_engine, + device_type=device_type, + device_id=device_id, + headless=headless) + + self._speed_change_steps = torch.zeros([self.num_envs], device=self.device, dtype=torch.int64) + self._prev_root_pos = torch.zeros([self.num_envs, 3], device=self.device, dtype=torch.float) + self._tar_speed = torch.ones([self.num_envs], device=self.device, dtype=torch.float) + + self.power_usage_reward = cfg["env"].get("power_usage_reward", False) + reward_raw_num = 1 + if self.power_usage_reward: + reward_raw_num += 1 + if self.power_reward: + reward_raw_num += 1 + + self.reward_raw = torch.zeros((self.num_envs, reward_raw_num)).to(self.device) + self.power_coefficient = cfg["env"].get("power_coefficient", 0.0005) + self.power_usage_coefficient = cfg["env"].get("power_usage_coefficient", 0.0025) + self.power_acc = torch.zeros((self.num_envs, 2 )).to(self.device) + + if (not self.headless): + self._build_marker_state_tensors() + + return + + def get_task_obs_size(self): + obs_size = 0 + if (self._enable_task_obs): + obs_size = 3 + + if (self._add_input_noise): + obs_size += 16 + + if self.obs_v == 2: + obs_size *= self.past_track_steps + + return obs_size + + def pre_physics_step(self, actions): + super().pre_physics_step(actions) + self._prev_root_pos[:] = self._humanoid_root_states[..., 0:3] + return + + def post_physics_step(self): + super().post_physics_step() + + if (humanoid_amp.HACK_OUTPUT_MOTION): + self._hack_output_motion_target() + + return + + def _update_marker(self): + humanoid_root_pos = self._humanoid_root_states[..., 0:3] + self._marker_pos[..., 0:2] = humanoid_root_pos[..., 0:2] + self._marker_pos[..., 0] += 0.5 + 0.2 * self._tar_speed + self._marker_pos[..., 2] = 0.0 + + + self._marker_rot[:] = 0 + self._marker_rot[:, -1] = 1.0 + + + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), + gymtorch.unwrap_tensor(self._marker_actor_ids), + len(self._marker_actor_ids)) + return + + def _create_envs(self, num_envs, spacing, num_per_row): + if (not self.headless): + self._marker_handles = [] + self._load_marker_asset() + + super()._create_envs(num_envs, spacing, num_per_row) + return + + def _load_marker_asset(self): + asset_root = "pulse/data/assets/mjcf/" + asset_file = "heading_marker.urdf" + + asset_options = gymapi.AssetOptions() + asset_options.angular_damping = 0 + asset_options.linear_damping = 0 + asset_options.max_angular_velocity = 0 + asset_options.density = 0 + asset_options.fix_base_link = True + asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + + self._marker_asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options) + + return + + def _build_env(self, env_id, env_ptr, humanoid_asset): + super()._build_env(env_id, env_ptr, humanoid_asset) + + if (not self.headless): + self._build_marker(env_id, env_ptr) + + return + + def _build_marker(self, env_id, env_ptr): + default_pose = gymapi.Transform() + default_pose.p.x = 1.0 + default_pose.p.z = 0.0 + + marker_handle = self.gym.create_actor(env_ptr, self._marker_asset, default_pose, "marker", self.num_envs + 10, 1, 0) + self.gym.set_rigid_body_color(env_ptr, marker_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.8, 0.0, 0.0)) + self._marker_handles.append(marker_handle) + + + return + + def _build_marker_state_tensors(self): + num_actors = self._root_states.shape[0] // self.num_envs + + self._marker_states = self._root_states.view(self.num_envs, num_actors, self._root_states.shape[-1])[..., TAR_ACTOR_ID, :] + self._marker_pos = self._marker_states[..., :3] + self._marker_rot = self._marker_states[..., 3:7] + self._marker_actor_ids = self._humanoid_actor_ids + to_torch(self._marker_handles, device=self.device, dtype=torch.int32) + + + return + + def _update_task(self): + reset_task_mask = self.progress_buf >= self._speed_change_steps + rest_env_ids = reset_task_mask.nonzero(as_tuple=False).flatten() + + + if len(rest_env_ids) > 0: + self._reset_task(rest_env_ids) + return + + def _reset_task(self, env_ids): + n = len(env_ids) + + tar_speed = (self._tar_speed_max - self._tar_speed_min) * torch.rand(n, device=self.device) + self._tar_speed_min + change_steps = torch.randint(low=self._speed_change_steps_min, high=self._speed_change_steps_max, + size=(n,), device=self.device, dtype=torch.int64) + + self._tar_speed[env_ids] = tar_speed + self._speed_change_steps[env_ids] = self.progress_buf[env_ids] + change_steps + return + + def _compute_flip_task_obs(self, normal_task_obs, env_ids): + B, D = normal_task_obs.shape + flip_task_obs = normal_task_obs.clone() + flip_task_obs[:, 1] = -flip_task_obs[:, 1] + + return flip_task_obs + + def _compute_task_obs(self, env_ids=None): + if (env_ids is None): + root_states = self._humanoid_root_states + tar_speed = self._tar_speed + else: + root_states = self._humanoid_root_states[env_ids] + tar_speed = self._tar_speed[env_ids] + + obs = compute_speed_observations(root_states, tar_speed) + + if self._add_input_noise: + obs = torch.cat([obs, torch.randn((obs.shape[0], 16)).to(obs) * 0.1], dim=-1) + + return obs + + def _compute_reward(self, actions): + root_pos = self._humanoid_root_states[..., 0:3] + root_rot = self._humanoid_root_states[..., 3:7] + + # if False: + if flags.test: + root_pos = self._humanoid_root_states[..., 0:3] + delta_root_pos = root_pos - self._prev_root_pos + root_vel = delta_root_pos / self.dt + tar_dir_speed = root_vel[..., 0] + # print(self._tar_speed, tar_dir_speed) + + self.rew_buf[:] = self.reward_raw = compute_speed_reward(root_pos, self._prev_root_pos, root_rot, self._tar_speed, self.dt) + self.reward_raw = self.reward_raw[:, None] + + # if True: + if self.power_reward: + power_all = torch.abs(torch.multiply(self.dof_force_tensor, self._dof_vel)) + power = power_all.sum(dim=-1) + power_reward = -self.power_coefficient * power + power_reward[self.progress_buf <= 3] = 0 # First 3 frame power reward should not be counted. since they could be dropped. + + self.rew_buf[:] += power_reward + self.reward_raw = torch.cat([self.reward_raw, power_reward[:, None]], dim=-1) + + # if True: + if self.power_usage_reward: + power_all = torch.abs(torch.multiply(self.dof_force_tensor, self._dof_vel)) + power_all = power_all.reshape(-1, 23, 3) + left_power = power_all[:, self.left_indexes].reshape(self.num_envs, -1).sum(dim = -1) + right_power = power_all[:, self.right_indexes].reshape(self.num_envs, -1).sum(dim = -1) + self.power_acc[:, 0] += left_power + self.power_acc[:, 1] += right_power + power_usage_reward = self.power_acc/(self.progress_buf + 1)[:, None] + # print((power_usage_reward[:, 0] - power_usage_reward[:, 1]).abs()) + power_usage_reward = - self.power_usage_coefficient * (power_usage_reward[:, 0] - power_usage_reward[:, 1]).abs() + power_usage_reward[self.progress_buf <= 3] = 0 # First 3 frame power reward should not be counted. since they could be dropped. on the ground to balance. + + self.rew_buf[:] += power_usage_reward + self.reward_raw = torch.cat([self.reward_raw, power_usage_reward[:, None]], dim=-1) + + + return + + def _draw_task(self): + self._update_marker() + return + + def _reset_ref_state_init(self, env_ids): + super()._reset_ref_state_init(env_ids) + self.power_acc[env_ids] = 0 + + def _sample_ref_state(self, env_ids): + motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel = super()._sample_ref_state(env_ids) + + # ZL Hack: Forcing to always be facing the x-direction. + heading_rot_inv = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot_inv_repeat = heading_rot_inv[:, None].repeat(1, 24, 1) + root_rot = quat_mul(heading_rot_inv, root_rot).clone() + rb_pos = quat_apply(heading_rot_inv_repeat, rb_pos - root_pos[:, None, :]).clone() + root_pos[:, None, :] + rb_rot = quat_mul(heading_rot_inv_repeat, rb_rot).clone() + root_ang_vel = quat_apply(heading_rot_inv, root_ang_vel).clone() + root_vel = quat_apply(heading_rot_inv, root_vel).clone() + body_vel = quat_apply(heading_rot_inv_repeat, body_vel).clone() + + return motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel + + def _hack_output_motion_target(self): + if (not hasattr(self, '_output_motion_target_speed')): + self._output_motion_target_speed = [] + + tar_speed = self._tar_speed[0].cpu().numpy() + self._output_motion_target_speed.append(tar_speed) + + reset = self.reset_buf[0].cpu().numpy() == 1 + + if (reset and len(self._output_motion_target_speed) > 1): + output_data = np.array(self._output_motion_target_speed) + np.save('output/record_tar_speed.npy', output_data) + + self._output_motion_target_speed = [] + + return + +class HumanoidSpeedZ(humanoid_amp_task.HumanoidAMPZTask): + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + self._tar_speed_min = cfg["env"]["tarSpeedMin"] + self._tar_speed_max = cfg["env"]["tarSpeedMax"] + self._speed_change_steps_min = cfg["env"]["speedChangeStepsMin"] + self._speed_change_steps_max = cfg["env"]["speedChangeStepsMax"] + + super().__init__(cfg=cfg, + sim_params=sim_params, + physics_engine=physics_engine, + device_type=device_type, + device_id=device_id, + headless=headless) + + self._speed_change_steps = torch.zeros([self.num_envs], device=self.device, dtype=torch.int64) + self._prev_root_pos = torch.zeros([self.num_envs, 3], device=self.device, dtype=torch.float) + self._tar_speed = torch.ones([self.num_envs], device=self.device, dtype=torch.float) + + self.power_usage_reward = cfg["env"].get("power_usage_reward", False) + reward_raw_num = 1 + if self.power_usage_reward: + reward_raw_num += 1 + if self.power_reward: + reward_raw_num += 1 + + self.reward_raw = torch.zeros((self.num_envs, reward_raw_num)).to(self.device) + self.power_coefficient = cfg["env"].get("power_coefficient", 0.0005) + self.power_usage_coefficient = cfg["env"].get("power_usage_coefficient", 0.0025) + self.power_acc = torch.zeros((self.num_envs, 2 )).to(self.device) + + + if (not self.headless): + self._build_marker_state_tensors() + + return + + def _sample_ref_state(self, env_ids): + motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel = super()._sample_ref_state(env_ids) + + # ZL Hack: Forcing to always be facing the x-direction. + heading_rot_inv = torch_utils.calc_heading_quat_inv(root_rot) + heading_rot_inv_repeat = heading_rot_inv[:, None].repeat(1, 24, 1) + root_rot = quat_mul(heading_rot_inv, root_rot).clone() + rb_pos = quat_apply(heading_rot_inv_repeat, rb_pos - root_pos[:, None, :]).clone() + root_pos[:, None, :] + rb_rot = quat_mul(heading_rot_inv_repeat, rb_rot).clone() + root_ang_vel = quat_apply(heading_rot_inv, root_ang_vel).clone() + root_vel = quat_apply(heading_rot_inv, root_vel).clone() + body_vel = quat_apply(heading_rot_inv_repeat, body_vel).clone() + + return motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel + + + def get_task_obs_size(self): + obs_size = 0 + if (self._enable_task_obs): + obs_size = 3 + + if self.obs_v == 2: + obs_size *= self.past_track_steps + + return obs_size + + def pre_physics_step(self, actions): + super().pre_physics_step(actions) + self._prev_root_pos[:] = self._humanoid_root_states[..., 0:3] + return + + def post_physics_step(self): + super().post_physics_step() + + if (humanoid_amp.HACK_OUTPUT_MOTION): + self._hack_output_motion_target() + + return + + def _update_marker(self): + humanoid_root_pos = self._humanoid_root_states[..., 0:3] + self._marker_pos[..., 0:2] = humanoid_root_pos[..., 0:2] + self._marker_pos[..., 0] += 0.5 + 0.2 * self._tar_speed + self._marker_pos[..., 2] = 0.0 + + self._marker_rot[:] = 0 + self._marker_rot[:, -1] = 1.0 + + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), + gymtorch.unwrap_tensor(self._marker_actor_ids), + len(self._marker_actor_ids)) + return + + def _create_envs(self, num_envs, spacing, num_per_row): + if (not self.headless): + self._marker_handles = [] + self._load_marker_asset() + + super()._create_envs(num_envs, spacing, num_per_row) + return + + def _load_marker_asset(self): + asset_root = "pulse/data/assets/mjcf/" + asset_file = "heading_marker.urdf" + + asset_options = gymapi.AssetOptions() + asset_options.angular_damping = 0.01 + asset_options.linear_damping = 0.01 + asset_options.max_angular_velocity = 100.0 + asset_options.density = 1.0 + asset_options.fix_base_link = True + asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + + self._marker_asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options) + + return + + def _build_env(self, env_id, env_ptr, humanoid_asset): + super()._build_env(env_id, env_ptr, humanoid_asset) + + if (not self.headless): + self._build_marker(env_id, env_ptr) + + return + + def _build_marker(self, env_id, env_ptr): + default_pose = gymapi.Transform() + default_pose.p.x = 1.0 + default_pose.p.z = 0.0 + + marker_handle = self.gym.create_actor(env_ptr, self._marker_asset, default_pose, "marker", env_id, 2) + self.gym.set_rigid_body_color(env_ptr, marker_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.8, 0.0, 0.0)) + self._marker_handles.append(marker_handle) + + return + + def _build_marker_state_tensors(self): + num_actors = self._root_states.shape[0] // self.num_envs + self._marker_states = self._root_states.view(self.num_envs, num_actors, self._root_states.shape[-1])[..., TAR_ACTOR_ID, :] + self._marker_pos = self._marker_states[..., :3] + self._marker_rot = self._marker_states[..., 3:7] + self._marker_actor_ids = self._humanoid_actor_ids + to_torch(self._marker_handles, device=self.device, dtype=torch.int32) + + return + + def _update_task(self): + reset_task_mask = self.progress_buf >= self._speed_change_steps + rest_env_ids = reset_task_mask.nonzero(as_tuple=False).flatten() + if len(rest_env_ids) > 0: + self._reset_task(rest_env_ids) + return + + def _reset_task(self, env_ids): + n = len(env_ids) + + tar_speed = (self._tar_speed_max - self._tar_speed_min) * torch.rand(n, device=self.device) + self._tar_speed_min + change_steps = torch.randint(low=self._speed_change_steps_min, high=self._speed_change_steps_max, + size=(n,), device=self.device, dtype=torch.int64) + + self._tar_speed[env_ids] = tar_speed + if len(env_ids) > 0 and flags.test: + print(self._tar_speed) + self._speed_change_steps[env_ids] = self.progress_buf[env_ids] + change_steps + return + + def _compute_task_obs(self, env_ids=None): + if (env_ids is None): + root_states = self._humanoid_root_states + tar_speed = self._tar_speed + else: + root_states = self._humanoid_root_states[env_ids] + tar_speed = self._tar_speed[env_ids] + + obs = compute_speed_observations(root_states, tar_speed) + return obs + + def _compute_flip_task_obs(self, normal_task_obs, env_ids): + B, D = normal_task_obs.shape + flip_task_obs = normal_task_obs.clone() + flip_task_obs[:, 1] = -flip_task_obs[:, 1] + + return flip_task_obs + + def _reset_ref_state_init(self, env_ids): + super()._reset_ref_state_init(env_ids) + self.power_acc[env_ids] = 0 + + def _compute_reward(self, actions): + root_pos = self._humanoid_root_states[..., 0:3] + root_rot = self._humanoid_root_states[..., 3:7] + + if flags.test: + root_pos = self._humanoid_root_states[..., 0:3] + delta_root_pos = root_pos - self._prev_root_pos + root_vel = delta_root_pos / self.dt + tar_dir_speed = root_vel[..., 0] + # print(self._tar_speed, tar_dir_speed) + + self.rew_buf[:] = self.reward_raw = compute_speed_reward(root_pos, self._prev_root_pos, root_rot, self._tar_speed, self.dt) + self.reward_raw = self.reward_raw[:, None] + # if True: + if self.power_reward: + power_all = torch.abs(torch.multiply(self.dof_force_tensor, self._dof_vel)) + power = power_all.sum(dim=-1) + power_reward = -self.power_coefficient * power + power_reward[self.progress_buf <= 3] = 0 # First 3 frame power reward should not be counted. since they could be dropped. + + self.rew_buf[:] += power_reward + self.reward_raw = torch.cat([self.reward_raw, power_reward[:, None]], dim=-1) + + if self.power_usage_reward: + power_all = torch.abs(torch.multiply(self.dof_force_tensor, self._dof_vel)) + power_all = power_all.reshape(-1, 23, 3) + left_power = power_all[:, self.left_indexes].reshape(self.num_envs, -1).sum(dim = -1) + right_power = power_all[:, self.right_indexes].reshape(self.num_envs, -1).sum(dim = -1) + self.power_acc[:, 0] += left_power + self.power_acc[:, 1] += right_power + power_usage_reward = self.power_acc/(self.progress_buf + 1)[:, None] + # print((power_usage_reward[:, 0] - power_usage_reward[:, 1]).abs()) + power_usage_reward = - self.power_usage_coefficient * (power_usage_reward[:, 0] - power_usage_reward[:, 1]).abs() + power_usage_reward[self.progress_buf <= 3] = 0 # First 3 frame power reward should not be counted. since they could be dropped. on the ground to balance. + self.rew_buf[:] += power_usage_reward + self.reward_raw = torch.cat([self.reward_raw, power_usage_reward[:, None]], dim=-1) + + + return + + def _draw_task(self): + self._update_marker() + return + + def _hack_output_motion_target(self): + if (not hasattr(self, '_output_motion_target_speed')): + self._output_motion_target_speed = [] + + tar_speed = self._tar_speed[0].cpu().numpy() + self._output_motion_target_speed.append(tar_speed) + + reset = self.reset_buf[0].cpu().numpy() == 1 + + if (reset and len(self._output_motion_target_speed) > 1): + output_data = np.array(self._output_motion_target_speed) + np.save('output/record_tar_speed.npy', output_data) + + self._output_motion_target_speed = [] + + return + +##################################################################### +###=========================jit functions=========================### +##################################################################### + +@torch.jit.script +def compute_speed_observations(root_states, tar_speed): + # type: (Tensor, Tensor) -> Tensor + root_rot = root_states[:, 3:7] + + tar_dir3d = torch.zeros_like(root_states[..., 0:3]) + tar_dir3d[..., 0] = 1 + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + + local_tar_dir = torch_utils.my_quat_rotate(heading_rot, tar_dir3d) + local_tar_dir = local_tar_dir[..., 0:2] + tar_speed = tar_speed.unsqueeze(-1) + + obs = torch.cat([local_tar_dir, tar_speed], dim=-1) + + return obs + +@torch.jit.script +def compute_speed_reward(root_pos, prev_root_pos, root_rot, tar_speed, dt): + # type: (Tensor, Tensor, Tensor, Tensor, float) -> Tensor + vel_err_scale = 0.25 + tangent_err_w = 0.1 + + delta_root_pos = root_pos - prev_root_pos + root_vel = delta_root_pos / dt + tar_dir_speed = root_vel[..., 0] + tangent_speed = root_vel[..., 1] + + tar_vel_err = tar_speed - tar_dir_speed + tangent_vel_err = tangent_speed + dir_reward = torch.exp(-vel_err_scale * (tar_vel_err * tar_vel_err + tangent_err_w * tangent_vel_err * tangent_vel_err)) + + reward = dir_reward + + return reward \ No newline at end of file diff --git a/phc/env/tasks/humanoid_strike.py b/phc/env/tasks/humanoid_strike.py new file mode 100644 index 0000000..bd39b4d --- /dev/null +++ b/phc/env/tasks/humanoid_strike.py @@ -0,0 +1,591 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import torch + +from isaacgym import gymapi, gymtorch +from isaacgym.torch_utils import * + +import env.tasks.humanoid_amp as humanoid_amp +import env.tasks.humanoid_amp_task as humanoid_amp_task +from utils import torch_utils + +class HumanoidStrike(humanoid_amp_task.HumanoidAMPTask): + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + super().__init__(cfg=cfg, + sim_params=sim_params, + physics_engine=physics_engine, + device_type=device_type, + device_id=device_id, + headless=headless) + + self._tar_dist_min = 0.5 + self._tar_dist_max = 10.0 + self._near_dist = 1.5 + self._near_prob = 0.5 + + self._prev_root_pos = torch.zeros([self.num_envs, 3], device=self.device, dtype=torch.float) + + strike_body_names = cfg["env"]["strikeBodyNames"] + self._strike_body_ids = self._build_strike_body_ids_tensor(self.envs[0], self.humanoid_handles[0], strike_body_names) + self._build_target_tensors() + + self.power_usage_reward = cfg["env"].get("power_usage_reward", False) + self.power_coefficient = cfg["env"].get("power_coefficient", 0.0005) + self.power_usage_coefficient = cfg["env"].get("power_usage_coefficient", 0.0025) + self.power_acc = torch.zeros((self.num_envs, 2 )).to(self.device) + + + return + + def get_task_obs_size(self): + obs_size = 0 + if (self._enable_task_obs): + obs_size = 15 + return obs_size + + def post_physics_step(self): + super().post_physics_step() + + if (humanoid_amp.HACK_OUTPUT_MOTION): + self._hack_output_motion_target() + + return + + def _create_envs(self, num_envs, spacing, num_per_row): + self._target_handles = [] + self._load_target_asset() + + super()._create_envs(num_envs, spacing, num_per_row) + return + + def _build_env(self, env_id, env_ptr, humanoid_asset): + super()._build_env(env_id, env_ptr, humanoid_asset) + self._build_target(env_id, env_ptr) + return + + def _load_target_asset(self): + asset_root = "pulse/data/assets/mjcf/" + asset_file = "strike_target.urdf" + + asset_options = gymapi.AssetOptions() + asset_options.angular_damping = 0.01 + asset_options.linear_damping = 0.01 + asset_options.max_angular_velocity = 100.0 + asset_options.density = 30.0 + asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + + self._target_asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options) + return + + def _build_target(self, env_id, env_ptr): + default_pose = gymapi.Transform() + default_pose.p.x = 1.0 + + target_handle = self.gym.create_actor(env_ptr, self._target_asset, default_pose, "target", env_id, 2) + self._target_handles.append(target_handle) + + return + + def _build_strike_body_ids_tensor(self, env_ptr, actor_handle, body_names): + env_ptr = self.envs[0] + actor_handle = self.humanoid_handles[0] + body_ids = [] + + for body_name in body_names: + body_id = self.gym.find_actor_rigid_body_handle(env_ptr, actor_handle, body_name) + assert(body_id != -1) + body_ids.append(body_id) + + body_ids = to_torch(body_ids, device=self.device, dtype=torch.long) + return body_ids + + def _build_target_tensors(self): + num_actors = self.get_num_actors_per_env() + self._target_states = self._root_states.view(self.num_envs, num_actors, self._root_states.shape[-1])[..., 1, :] + + self._tar_actor_ids = to_torch(num_actors * np.arange(self.num_envs), device=self.device, dtype=torch.int32) + 1 + + bodies_per_env = self._rigid_body_state.shape[0] // self.num_envs + contact_force_tensor = self.gym.acquire_net_contact_force_tensor(self.sim) + contact_force_tensor = gymtorch.wrap_tensor(contact_force_tensor) + self._tar_contact_forces = contact_force_tensor.view(self.num_envs, bodies_per_env, 3)[..., self.num_bodies, :] + + return + + def _reset_actors(self, env_ids): + super()._reset_actors(env_ids) + self._reset_target(env_ids) + return + + def _reset_target(self, env_ids): + n = len(env_ids) + + init_near = torch.rand([n], dtype=self._target_states.dtype, device=self._target_states.device) < self._near_prob + dist_max = self._tar_dist_max * torch.ones([n], dtype=self._target_states.dtype, device=self._target_states.device) + dist_max[init_near] = self._near_dist + rand_dist = (dist_max - self._tar_dist_min) * torch.rand([n], dtype=self._target_states.dtype, device=self._target_states.device) + self._tar_dist_min + + rand_theta = 2 * np.pi * torch.rand([n], dtype=self._target_states.dtype, device=self._target_states.device) + self._target_states[env_ids, 0] = rand_dist * torch.cos(rand_theta) + self._humanoid_root_states[env_ids, 0] + self._target_states[env_ids, 1] = rand_dist * torch.sin(rand_theta) + self._humanoid_root_states[env_ids, 1] + self._target_states[env_ids, 2] = 0.9 + + rand_rot_theta = 2 * np.pi * torch.rand([n], dtype=self._target_states.dtype, device=self._target_states.device) + axis = torch.tensor([0.0, 0.0, 1.0], dtype=self._target_states.dtype, device=self._target_states.device) + rand_rot = quat_from_angle_axis(rand_rot_theta, axis) + + self._target_states[env_ids, 3:7] = rand_rot + self._target_states[env_ids, 7:10] = 0.0 + self._target_states[env_ids, 10:13] = 0.0 + return + + def _sample_ref_state(self, env_ids): + motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel = super()._sample_ref_state(env_ids) + root_pos[..., :2] = 0.0 # Set the root position to be zero + return motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel + + def _reset_env_tensors(self, env_ids): + super()._reset_env_tensors(env_ids) + + env_ids_int32 = self._tar_actor_ids[env_ids] + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), + gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + return + + def pre_physics_step(self, actions): + super().pre_physics_step(actions) + self._prev_root_pos[:] = self._humanoid_root_states[..., 0:3] + return + + def _compute_task_obs(self, env_ids=None): + if (env_ids is None): + root_states = self._humanoid_root_states + tar_states = self._target_states + else: + root_states = self._humanoid_root_states[env_ids] + tar_states = self._target_states[env_ids] + + obs = compute_strike_observations(root_states, tar_states) + return obs + + def _compute_reward(self, actions): + tar_pos = self._target_states[..., 0:3] + tar_rot = self._target_states[..., 3:7] + char_root_state = self._humanoid_root_states + strike_body_vel = self._rigid_body_vel[..., self._strike_body_ids[0], :] + + self.rew_buf[:] = compute_strike_reward(tar_pos, tar_rot, char_root_state, + self._prev_root_pos, strike_body_vel, + self.dt, self._near_dist) + + if self.power_usage_reward: + power_all = torch.abs(torch.multiply(self.dof_force_tensor, self._dof_vel)) + power_all = power_all.reshape(-1, 23, 3) + left_power = power_all[:, self.left_lower_indexes].reshape(self.num_envs, -1).sum(dim = -1) + right_power = power_all[:, self.right_lower_indexes].reshape(self.num_envs, -1).sum(dim = -1) + self.power_acc[:, 0] += left_power + self.power_acc[:, 1] += right_power + power_usage_reward = self.power_acc/(self.progress_buf + 1)[:, None] + # print((power_usage_reward[:, 0] - power_usage_reward[:, 1]).abs()) + power_usage_reward = - self.power_usage_coefficient * (power_usage_reward[:, 0] - power_usage_reward[:, 1]).abs() + power_usage_reward[self.progress_buf <= 3] = 0 # First 3 frame power reward should not be counted. since they could be dropped. on the ground to balance. + + self.rew_buf[:] += power_usage_reward + return + + def _compute_reset(self): + self.reset_buf[:], self._terminate_buf[:] = compute_humanoid_reset(self.reset_buf, self.progress_buf, + self._contact_forces, self._contact_body_ids, + self._rigid_body_pos, self._tar_contact_forces, + self._strike_body_ids, self.max_episode_length, + self._enable_early_termination, self._termination_heights) + return + + def _draw_task(self): + cols = np.array([[0.0, 1.0, 0.0]], dtype=np.float32) + + self.gym.clear_lines(self.viewer) + + starts = self._humanoid_root_states[..., 0:3] + ends = self._target_states[..., 0:3] + verts = torch.cat([starts, ends], dim=-1).cpu().numpy() + + for i, env_ptr in enumerate(self.envs): + curr_verts = verts[i] + curr_verts = curr_verts.reshape([1, 6]) + self.gym.add_lines(self.viewer, env_ptr, curr_verts.shape[0], curr_verts, cols) + + return + + def _hack_output_motion_target(self): + if (not hasattr(self, '_output_motion_target_pos')): + self._output_motion_target_pos = [] + self._output_motion_target_rot = [] + + tar_pos = self._target_states[0, 0:3].cpu().numpy() + self._output_motion_target_pos.append(tar_pos) + + tar_rot = self._target_states[0, 3:7].cpu().numpy() + self._output_motion_target_rot.append(tar_rot) + + reset = self.reset_buf[0].cpu().numpy() == 1 + + if (reset and len(self._output_motion_target_pos) > 1): + output_tar_pos = np.array(self._output_motion_target_pos) + output_tar_rot = np.array(self._output_motion_target_rot) + output_data = np.concatenate([output_tar_pos, output_tar_rot], axis=-1) + np.save('output/record_tar_motion.npy', output_data) + + self._output_motion_target_pos = [] + self._output_motion_target_rot = [] + + return + +class HumanoidStrikeZ(humanoid_amp_task.HumanoidAMPZTask): + def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless): + super().__init__(cfg=cfg, + sim_params=sim_params, + physics_engine=physics_engine, + device_type=device_type, + device_id=device_id, + headless=headless) + + self._tar_dist_min = 0.5 + self._tar_dist_max = 10.0 + self._near_dist = 1.5 + self._near_prob = 0.5 + + self._prev_root_pos = torch.zeros([self.num_envs, 3], device=self.device, dtype=torch.float) + + strike_body_names = cfg["env"]["strikeBodyNames"] + self._strike_body_ids = self._build_strike_body_ids_tensor(self.envs[0], self.humanoid_handles[0], strike_body_names) + self._build_target_tensors() + + self.power_usage_reward = cfg["env"].get("power_usage_reward", False) + self.power_coefficient = cfg["env"].get("power_coefficient", 0.0005) + self.power_usage_coefficient = cfg["env"].get("power_usage_coefficient", 0.0025) + self.power_acc = torch.zeros((self.num_envs, 2 )).to(self.device) + + return + + def _sample_ref_state(self, env_ids): + motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel = super()._sample_ref_state(env_ids) + root_pos[..., :2] = 0.0 # Set the root position to be zero + return motion_ids, motion_times, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, rb_pos, rb_rot, body_vel, body_ang_vel + + def get_task_obs_size(self): + obs_size = 0 + if (self._enable_task_obs): + obs_size = 15 + return obs_size + + def post_physics_step(self): + super().post_physics_step() + + if (humanoid_amp.HACK_OUTPUT_MOTION): + self._hack_output_motion_target() + + return + + def _create_envs(self, num_envs, spacing, num_per_row): + self._target_handles = [] + self._load_target_asset() + + super()._create_envs(num_envs, spacing, num_per_row) + return + + def _build_env(self, env_id, env_ptr, humanoid_asset): + super()._build_env(env_id, env_ptr, humanoid_asset) + self._build_target(env_id, env_ptr) + return + + def _load_target_asset(self): + asset_root = "pulse/data/assets/mjcf/" + asset_file = "strike_target.urdf" + + asset_options = gymapi.AssetOptions() + asset_options.angular_damping = 0.01 + asset_options.linear_damping = 0.01 + asset_options.max_angular_velocity = 100.0 + asset_options.density = 30.0 + asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE + + self._target_asset = self.gym.load_asset(self.sim, asset_root, asset_file, asset_options) + return + + def _build_target(self, env_id, env_ptr): + default_pose = gymapi.Transform() + default_pose.p.x = 1.0 + + target_handle = self.gym.create_actor(env_ptr, self._target_asset, default_pose, "target", env_id, 2) + self._target_handles.append(target_handle) + + return + + def _build_strike_body_ids_tensor(self, env_ptr, actor_handle, body_names): + env_ptr = self.envs[0] + actor_handle = self.humanoid_handles[0] + body_ids = [] + + for body_name in body_names: + body_id = self.gym.find_actor_rigid_body_handle(env_ptr, actor_handle, body_name) + assert(body_id != -1) + body_ids.append(body_id) + + body_ids = to_torch(body_ids, device=self.device, dtype=torch.long) + return body_ids + + def _build_target_tensors(self): + num_actors = self.get_num_actors_per_env() + self._target_states = self._root_states.view(self.num_envs, num_actors, self._root_states.shape[-1])[..., 1, :] + + self._tar_actor_ids = to_torch(num_actors * np.arange(self.num_envs), device=self.device, dtype=torch.int32) + 1 + + bodies_per_env = self._rigid_body_state.shape[0] // self.num_envs + contact_force_tensor = self.gym.acquire_net_contact_force_tensor(self.sim) + contact_force_tensor = gymtorch.wrap_tensor(contact_force_tensor) + self._tar_contact_forces = contact_force_tensor.view(self.num_envs, bodies_per_env, 3)[..., self.num_bodies, :] + + return + + def _reset_actors(self, env_ids): + super()._reset_actors(env_ids) + self._reset_target(env_ids) + return + + def _reset_target(self, env_ids): + n = len(env_ids) + + init_near = torch.rand([n], dtype=self._target_states.dtype, device=self._target_states.device) < self._near_prob + dist_max = self._tar_dist_max * torch.ones([n], dtype=self._target_states.dtype, device=self._target_states.device) + dist_max[init_near] = self._near_dist + rand_dist = (dist_max - self._tar_dist_min) * torch.rand([n], dtype=self._target_states.dtype, device=self._target_states.device) + self._tar_dist_min + + rand_theta = 2 * np.pi * torch.rand([n], dtype=self._target_states.dtype, device=self._target_states.device) + self._target_states[env_ids, 0] = rand_dist * torch.cos(rand_theta) + self._humanoid_root_states[env_ids, 0] + self._target_states[env_ids, 1] = rand_dist * torch.sin(rand_theta) + self._humanoid_root_states[env_ids, 1] + self._target_states[env_ids, 2] = 0.9 + + rand_rot_theta = 2 * np.pi * torch.rand([n], dtype=self._target_states.dtype, device=self._target_states.device) + axis = torch.tensor([0.0, 0.0, 1.0], dtype=self._target_states.dtype, device=self._target_states.device) + rand_rot = quat_from_angle_axis(rand_rot_theta, axis) + + self._target_states[env_ids, 3:7] = rand_rot + self._target_states[env_ids, 7:10] = 0.0 + self._target_states[env_ids, 10:13] = 0.0 + return + + def _reset_env_tensors(self, env_ids): + super()._reset_env_tensors(env_ids) + + env_ids_int32 = self._tar_actor_ids[env_ids] + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), + gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + return + + def pre_physics_step(self, actions): + super().pre_physics_step(actions) + self._prev_root_pos[:] = self._humanoid_root_states[..., 0:3] + return + + def _compute_task_obs(self, env_ids=None): + if (env_ids is None): + root_states = self._humanoid_root_states + tar_states = self._target_states + else: + root_states = self._humanoid_root_states[env_ids] + tar_states = self._target_states[env_ids] + + obs = compute_strike_observations(root_states, tar_states) + return obs + + def _compute_reward(self, actions): + tar_pos = self._target_states[..., 0:3] + tar_rot = self._target_states[..., 3:7] + char_root_state = self._humanoid_root_states + strike_body_vel = self._rigid_body_vel[..., self._strike_body_ids[0], :] + + self.rew_buf[:] = compute_strike_reward(tar_pos, tar_rot, char_root_state, + self._prev_root_pos, strike_body_vel, + self.dt, self._near_dist) + + + if self.power_usage_reward: + power_all = torch.abs(torch.multiply(self.dof_force_tensor, self._dof_vel)) + power_all = power_all.reshape(-1, 23, 3) + left_power = power_all[:, self.left_lower_indexes].reshape(self.num_envs, -1).sum(dim = -1) + right_power = power_all[:, self.right_lower_indexes].reshape(self.num_envs, -1).sum(dim = -1) + self.power_acc[:, 0] += left_power + self.power_acc[:, 1] += right_power + power_usage_reward = self.power_acc/(self.progress_buf + 1)[:, None] + # print((power_usage_reward[:, 0] - power_usage_reward[:, 1]).abs()) + power_usage_reward = - self.power_usage_coefficient * (power_usage_reward[:, 0] - power_usage_reward[:, 1]).abs() + power_usage_reward[self.progress_buf <= 3] = 0 # First 3 frame power reward should not be counted. since they could be dropped. on the ground to balance. + + self.rew_buf[:] += power_usage_reward + return + + def _compute_reset(self): + self.reset_buf[:], self._terminate_buf[:] = compute_humanoid_reset(self.reset_buf, self.progress_buf, + self._contact_forces, self._contact_body_ids, + self._rigid_body_pos, self._tar_contact_forces, + self._strike_body_ids, self.max_episode_length, + self._enable_early_termination, self._termination_heights) + return + + def _draw_task(self): + cols = np.array([[0.0, 1.0, 0.0]], dtype=np.float32) + + self.gym.clear_lines(self.viewer) + + starts = self._humanoid_root_states[..., 0:3] + ends = self._target_states[..., 0:3] + verts = torch.cat([starts, ends], dim=-1).cpu().numpy() + + for i, env_ptr in enumerate(self.envs): + curr_verts = verts[i] + curr_verts = curr_verts.reshape([1, 6]) + self.gym.add_lines(self.viewer, env_ptr, curr_verts.shape[0], curr_verts, cols) + + return + + def _hack_output_motion_target(self): + if (not hasattr(self, '_output_motion_target_pos')): + self._output_motion_target_pos = [] + self._output_motion_target_rot = [] + + tar_pos = self._target_states[0, 0:3].cpu().numpy() + self._output_motion_target_pos.append(tar_pos) + + tar_rot = self._target_states[0, 3:7].cpu().numpy() + self._output_motion_target_rot.append(tar_rot) + + reset = self.reset_buf[0].cpu().numpy() == 1 + + if (reset and len(self._output_motion_target_pos) > 1): + output_tar_pos = np.array(self._output_motion_target_pos) + output_tar_rot = np.array(self._output_motion_target_rot) + output_data = np.concatenate([output_tar_pos, output_tar_rot], axis=-1) + np.save('output/record_tar_motion.npy', output_data) + + self._output_motion_target_pos = [] + self._output_motion_target_rot = [] + + return + +##################################################################### +###=========================jit functions=========================### +##################################################################### + +@torch.jit.script +def compute_strike_observations(root_states, tar_states): + # type: (Tensor, Tensor) -> Tensor + root_pos = root_states[:, 0:3] + root_rot = root_states[:, 3:7] + + tar_pos = tar_states[:, 0:3] + tar_rot = tar_states[:, 3:7] + tar_vel = tar_states[:, 7:10] + tar_ang_vel = tar_states[:, 10:13] + + heading_rot = torch_utils.calc_heading_quat_inv(root_rot) + + local_tar_pos = tar_pos - root_pos + local_tar_pos[..., -1] = tar_pos[..., -1] + local_tar_pos = torch_utils.my_quat_rotate(heading_rot, local_tar_pos) + local_tar_vel = torch_utils.my_quat_rotate(heading_rot, tar_vel) + local_tar_ang_vel = torch_utils.my_quat_rotate(heading_rot, tar_ang_vel) + + local_tar_rot = quat_mul(heading_rot, tar_rot) + local_tar_rot_obs = torch_utils.quat_to_tan_norm(local_tar_rot) + + obs = torch.cat([local_tar_pos, local_tar_rot_obs, local_tar_vel, local_tar_ang_vel], dim=-1) + return obs + +@torch.jit.script +def compute_strike_reward(tar_pos, tar_rot, root_state, prev_root_pos, strike_body_vel, dt, near_dist): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, float, float) -> Tensor + tar_speed = 1.0 + vel_err_scale = 4.0 + + tar_rot_w = 0.6 + vel_reward_w = 0.4 + + up = torch.zeros_like(tar_pos) + up[..., -1] = 1 + tar_up = quat_rotate(tar_rot, up) + tar_rot_err = torch.sum(up * tar_up, dim=-1) + tar_rot_r = torch.clamp_min(1.0 - tar_rot_err, 0.0) + + root_pos = root_state[..., 0:3] + tar_dir = tar_pos[..., 0:2] - root_pos[..., 0:2] + tar_dir = torch.nn.functional.normalize(tar_dir, dim=-1) + delta_root_pos = root_pos - prev_root_pos + root_vel = delta_root_pos / dt + tar_dir_speed = torch.sum(tar_dir * root_vel[..., :2], dim=-1) + tar_vel_err = tar_speed - tar_dir_speed + tar_vel_err = torch.clamp_min(tar_vel_err, 0.0) + vel_reward = torch.exp(-vel_err_scale * (tar_vel_err * tar_vel_err)) + speed_mask = tar_dir_speed <= 0 + vel_reward[speed_mask] = 0 + + + reward = tar_rot_w * tar_rot_r + vel_reward_w * vel_reward + + succ = tar_rot_err < 0.2 + reward = torch.where(succ, torch.ones_like(reward), reward) + return reward + + +@torch.jit.script +def compute_humanoid_reset(reset_buf, progress_buf, contact_buf, contact_body_ids, rigid_body_pos, + tar_contact_forces, strike_body_ids, max_episode_length, + enable_early_termination, termination_heights): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, float, bool, Tensor) -> Tuple[Tensor, Tensor] + contact_force_threshold = 50.0 + + terminated = torch.zeros_like(reset_buf) + + if (enable_early_termination): + masked_contact_buf = contact_buf.clone() + masked_contact_buf[:, contact_body_ids, :] = 0 + fall_contact = torch.any(torch.abs(masked_contact_buf) > 0.1, dim=-1) + fall_contact = torch.any(fall_contact, dim=-1) + + body_height = rigid_body_pos[..., 2] + fall_height = body_height < termination_heights + fall_height[:, contact_body_ids] = False + fall_height = torch.any(fall_height, dim=-1) + + has_fallen = torch.logical_and(fall_contact, fall_height) + + tar_has_contact = torch.any(torch.abs(tar_contact_forces[..., 0:2]) > contact_force_threshold, dim=-1) + #strike_body_force = contact_buf[:, strike_body_id, :] + #strike_body_has_contact = torch.any(torch.abs(strike_body_force) > contact_force_threshold, dim=-1) + nonstrike_body_force = masked_contact_buf + nonstrike_body_force[:, strike_body_ids, :] = 0 + + # nonstrike_body_has_contact = torch.any(torch.sqrt(torch.square(torch.abs(nonstrike_body_force.sum(dim=-2))).sum(dim=-1)) > contact_force_threshold, dim=-1) + # nonstrike_body_has_contact = torch.any(nonstrike_body_has_contact, dim=-1) + + nonstrike_body_has_contact = torch.any(torch.abs(nonstrike_body_force) > contact_force_threshold, dim=-1) + nonstrike_body_has_contact = torch.any(nonstrike_body_has_contact, dim=-1) + + tar_fail = torch.logical_and(tar_has_contact, nonstrike_body_has_contact) + + has_failed = torch.logical_or(has_fallen, tar_fail) + + + # first timestep can sometimes still have nonzero contact forces + # so only check after first couple of steps + has_failed *= (progress_buf > 1) + terminated = torch.where(has_failed, torch.ones_like(reset_buf), terminated) + + reset = torch.where(progress_buf >= max_episode_length - 1, torch.ones_like(reset_buf), terminated) + + return reset, terminated \ No newline at end of file diff --git a/phc/env/tasks/vec_task.py b/phc/env/tasks/vec_task.py new file mode 100644 index 0000000..2e4d1fd --- /dev/null +++ b/phc/env/tasks/vec_task.py @@ -0,0 +1,162 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from gym import spaces + +from isaacgym import gymtorch +from isaacgym.torch_utils import to_torch +import torch +import numpy as np + + +# VecEnv Wrapper for RL training +class VecTask(): + + def __init__(self, task, rl_device, clip_observations=5.0): + self.task = task + + self.num_environments = task.num_envs + self.num_agents = 1 # used for multi-agent environments + self.num_observations = task.num_obs + self.num_states = task.num_states + self.num_actions = task.num_actions + + self.obs_space = spaces.Box(np.ones(self.num_obs) * -np.Inf, np.ones(self.num_obs) * np.Inf) + self.state_space = spaces.Box(np.ones(self.num_states) * -np.Inf, np.ones(self.num_states) * np.Inf) + if isinstance(self.num_actions, int): + self.act_space = spaces.Box(np.ones(self.num_actions) * -1., np.ones(self.num_actions) * 1.) + elif isinstance(self.num_actions, list): + self.act_space = spaces.Tuple([spaces.Discrete(num_actions) for num_actions in self.num_actions]) + + + self.clip_obs = clip_observations + self.rl_device = rl_device + + print("RL device: ", rl_device) + + def step(self, actions): + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def get_number_of_agents(self): + return self.num_agents + + @property + def observation_space(self): + return self.obs_space + + @property + def action_space(self): + return self.act_space + + @property + def num_envs(self): + return self.num_environments + + @property + def num_acts(self): + return self.num_actions + + @property + def num_obs(self): + return self.num_observations + + +# C++ CPU Class +class VecTaskCPU(VecTask): + + def __init__(self, task, rl_device, sync_frame_time=False, clip_observations=5.0): + super().__init__(task, rl_device, clip_observations=clip_observations) + self.sync_frame_time = sync_frame_time + + def step(self, actions): + actions = actions.cpu().numpy() + self.task.render(self.sync_frame_time) + + obs, rewards, resets, extras = self.task.step(actions) + + return (to_torch(np.clip(obs, -self.clip_obs, self.clip_obs), dtype=torch.float, device=self.rl_device), to_torch(rewards, dtype=torch.float, device=self.rl_device), to_torch(resets, dtype=torch.uint8, device=self.rl_device), []) + + def reset(self): + actions = 0.01 * (1 - 2 * np.random.rand(self.num_envs, self.num_actions)).astype('f') + + # step the simulator + obs, rewards, resets, extras = self.task.step(actions) + + return to_torch(np.clip(obs, -self.clip_obs, self.clip_obs), dtype=torch.float, device=self.rl_device) + + +# C++ GPU Class +class VecTaskGPU(VecTask): + + def __init__(self, task, rl_device, clip_observations=5.0): + super().__init__(task, rl_device, clip_observations=clip_observations) + + self.obs_tensor = gymtorch.wrap_tensor(self.task.obs_tensor, counts=(self.task.num_envs, self.task.num_obs)) + self.rewards_tensor = gymtorch.wrap_tensor(self.task.rewards_tensor, counts=(self.task.num_envs,)) + self.resets_tensor = gymtorch.wrap_tensor(self.task.resets_tensor, counts=(self.task.num_envs,)) + + def step(self, actions): + self.task.render(False) + actions_tensor = gymtorch.unwrap_tensor(actions) + + self.task.step(actions_tensor) + + return torch.clamp(self.obs_tensor, -self.clip_obs, self.clip_obs), self.rewards_tensor, self.resets_tensor, [] + + def reset(self): + actions = 0.01 * (1 - 2 * torch.rand([self.task.num_envs, self.task.num_actions], dtype=torch.float32, device=self.rl_device)) + actions_tensor = gymtorch.unwrap_tensor(actions) + + # step the simulator + self.task.step(actions_tensor) + + return torch.clamp(self.obs_tensor, -self.clip_obs, self.clip_obs) + + +# Python CPU/GPU Class +class VecTaskPython(VecTask): + + def get_state(self): + return torch.clamp(self.task.states_buf, -self.clip_obs, self.clip_obs).to(self.rl_device) + + def step(self, actions): + + self.task.step(actions) + + return torch.clamp(self.task.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device), self.task.rew_buf.to(self.rl_device), self.task.reset_buf.to(self.rl_device), self.task.extras + + def reset(self): + actions = 0.01 * (1 - 2 * torch.rand([self.task.num_envs, self.task.num_actions], dtype=torch.float32, device=self.rl_device)) + + # step the simulator + self.task.step(actions) + + return torch.clamp(self.task.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device) diff --git a/phc/env/tasks/vec_task_wrappers.py b/phc/env/tasks/vec_task_wrappers.py new file mode 100644 index 0000000..664289f --- /dev/null +++ b/phc/env/tasks/vec_task_wrappers.py @@ -0,0 +1,81 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from gym import spaces +import numpy as np +import torch +from phc.env.tasks.vec_task import VecTaskCPU, VecTaskGPU, VecTaskPython + +class VecTaskCPUWrapper(VecTaskCPU): + def __init__(self, task, rl_device, sync_frame_time=False, clip_observations=5.0): + super().__init__(task, rl_device, sync_frame_time, clip_observations) + return + +class VecTaskGPUWrapper(VecTaskGPU): + def __init__(self, task, rl_device, clip_observations=5.0): + super().__init__(task, rl_device, clip_observations) + return + + +class VecTaskPythonWrapper(VecTaskPython): + def __init__(self, task, rl_device, clip_observations=5.0): + super().__init__(task, rl_device, clip_observations) + + self._amp_obs_space = spaces.Box(np.ones(task.get_num_amp_obs()) * -np.Inf, np.ones(task.get_num_amp_obs()) * np.Inf) + + self._enc_amp_obs_space = spaces.Box(np.ones(task.get_num_enc_amp_obs()) * -np.Inf, np.ones(task.get_num_enc_amp_obs()) * np.Inf) + return + + def reset(self, env_ids=None): + self.task.reset(env_ids) + return torch.clamp(self.task.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device) + + @property + def amp_observation_space(self): + return self._amp_obs_space + + @property + def enc_amp_observation_space(self): + return self._enc_amp_obs_space + + def fetch_amp_obs_demo(self, num_samples): + return self.task.fetch_amp_obs_demo(num_samples) + + @property + def enc_amp_observation_space(self): + return self._enc_amp_obs_space + + ################ Calm ################ + def fetch_amp_obs_demo_pair(self, num_samples): + return self.task.fetch_amp_obs_demo_pair(num_samples) + + def fetch_amp_obs_demo_enc_pair(self, num_samples): + return self.task.fetch_amp_obs_demo_enc_pair(num_samples) + + def fetch_amp_obs_demo_per_id(self, num_samples, motion_ids): + return self.task.fetch_amp_obs_demo_per_id(num_samples, motion_ids) diff --git a/phc/env/util/gym_util.py b/phc/env/util/gym_util.py new file mode 100644 index 0000000..5f6e1a2 --- /dev/null +++ b/phc/env/util/gym_util.py @@ -0,0 +1,212 @@ +from phc.utils import logger +from isaacgym import gymapi +import numpy as np +import torch +from isaacgym.torch_utils import * +from isaacgym import gymtorch + +def setup_gym_viewer(config): + gym = initialize_gym(config) + sim, viewer = configure_gym(gym, config) + return gym, sim, viewer + + +def initialize_gym(config): + gym = gymapi.acquire_gym() + if not gym.initialize(): + logger.warn("*** Failed to initialize gym") + quit() + + return gym + + +def configure_gym(gym, config): + engine, render = config['engine'], config['render'] + + # physics engine settings + if(engine == 'FLEX'): + sim_engine = gymapi.SIM_FLEX + elif(engine == 'PHYSX'): + sim_engine = gymapi.SIM_PHYSX + else: + logger.warn("Uknown physics engine. defaulting to FLEX") + sim_engine = gymapi.SIM_FLEX + + # gym viewer + if render: + # create viewer + sim = gym.create_sim(0, 0, sim_type=sim_engine) + viewer = gym.create_viewer( + sim, int(gymapi.DEFAULT_VIEWER_WIDTH / 1.25), + int(gymapi.DEFAULT_VIEWER_HEIGHT / 1.25) + ) + + if viewer is None: + logger.warn("*** Failed to create viewer") + quit() + + # enable left mouse click or space bar for throwing projectiles + if config['add_projectiles']: + gym.subscribe_viewer_mouse_event(viewer, gymapi.MOUSE_LEFT_BUTTON, "shoot") + # gym.subscribe_viewer_keyboard_event(viewer, gymapi.KEY_SPACE, "shoot") + + else: + sim = gym.create_sim(0, -1) + viewer = None + + # simulation params + scene_config = config['env']['scene'] + sim_params = gymapi.SimParams() + sim_params.solver_type = scene_config['SolverType'] + sim_params.num_outer_iterations = scene_config['NumIterations'] + sim_params.num_inner_iterations = scene_config['NumInnerIterations'] + sim_params.relaxation = scene_config.get('Relaxation', 0.75) + sim_params.warm_start = scene_config.get('WarmStart', 0.25) + sim_params.geometric_stiffness = scene_config.get('GeometricStiffness', 1.0) + sim_params.shape_collision_margin = 0.01 + + sim_params.gravity = gymapi.Vec3(0.0, -9.8, 0.0) + gym.set_sim_params(sim, sim_params) + + return sim, viewer + + +def parse_states_from_reference_states(reference_states, progress): + # parse reference states from DeepMimicState + global_quats_ref = torch.tensor( + reference_states._global_rotation[(progress,)].numpy(), + dtype=torch.double + ).cuda() + ts_ref = torch.tensor( + reference_states._translation[(progress,)].numpy(), + dtype=torch.double + ).cuda() + vels_ref = torch.tensor( + reference_states._velocity[(progress,)].numpy(), + dtype=torch.double + ).cuda() + avels_ref = torch.tensor( + reference_states._angular_velocity[(progress,)].numpy(), + dtype=torch.double + ).cuda() + return global_quats_ref, ts_ref, vels_ref, avels_ref + + +def parse_states_from_reference_states_with_motion_id(precomputed_state, + progress, motion_id): + assert len(progress) == len(motion_id) + # get the global id + global_id = precomputed_state['motion_offset'][motion_id] + progress + global_id = np.minimum(global_id, + precomputed_state['global_quats_ref'].shape[0] - 1) + + # parse reference states from DeepMimicState + global_quats_ref = precomputed_state['global_quats_ref'][global_id] + ts_ref = precomputed_state['ts_ref'][global_id] + vels_ref = precomputed_state['vels_ref'][global_id] + avels_ref = precomputed_state['avels_ref'][global_id] + return global_quats_ref, ts_ref, vels_ref, avels_ref + + +def parse_dof_state_with_motion_id(precomputed_state, dof_state, + progress, motion_id): + assert len(progress) == len(motion_id) + # get the global id + global_id = precomputed_state['motion_offset'][motion_id] + progress + # NOTE: it should never reach the dof_state.shape, cause the episode is + # terminated 2 steps before + global_id = np.minimum(global_id, dof_state.shape[0] - 1) + + # parse reference states from DeepMimicState + return dof_state[global_id] + + +def get_flatten_ids(precomputed_state): + motion_offsets = precomputed_state['motion_offset'] + init_state_id, init_motion_id, global_id = [], [], [] + for i_motion in range(len(motion_offsets) - 1): + i_length = motion_offsets[i_motion + 1] - motion_offsets[i_motion] + init_state_id.extend(range(i_length)) + init_motion_id.extend([i_motion] * i_length) + if len(global_id) == 0: + global_id.extend(range(0, i_length)) + else: + global_id.extend(range(global_id[-1] + 1, + global_id[-1] + i_length + 1)) + return np.array(init_state_id), np.array(init_motion_id), \ + np.array(global_id) + + +def parse_states_from_reference_states_with_global_id(precomputed_state, + global_id): + # get the global id + global_id = global_id % precomputed_state['global_quats_ref'].shape[0] + + # parse reference states from DeepMimicState + global_quats_ref = precomputed_state['global_quats_ref'][global_id] + ts_ref = precomputed_state['ts_ref'][global_id] + vels_ref = precomputed_state['vels_ref'][global_id] + avels_ref = precomputed_state['avels_ref'][global_id] + return global_quats_ref, ts_ref, vels_ref, avels_ref + + +def get_robot_states_from_torch_tensor(config, ts, global_quats, vels, avels, + init_rot, progress, motion_length=-1, + actions=None, relative_rot=None, + motion_id=None, num_motion=None, + motion_onehot_matrix=None): + info = {} + # the observation with quaternion-based representation + torso_height = ts[..., 0, 1].cpu().numpy() + gttrny, gqny, vny, avny, info['root_yaw_inv'] = \ + quaternion_math.compute_observation_return_info(global_quats, ts, + vels, avels) + joint_obs = np.concatenate([gttrny.cpu().numpy(), gqny.cpu().numpy(), + vny.cpu().numpy(), avny.cpu().numpy()], axis=-1) + joint_obs = joint_obs.reshape(joint_obs.shape[0], -1) + num_envs = joint_obs.shape[0] + obs = np.concatenate([torso_height[:, np.newaxis], joint_obs], -1) + + # the previous action + if config['env_action_ob']: + obs = np.concatenate([obs, actions], axis=-1) + + # the orientation + if config['env_orientation_ob']: + if relative_rot is not None: + obs = np.concatenate([obs, relative_rot], axis=-1) + else: + curr_rot = global_quats[np.arange(num_envs)][:, 0] + curr_rot = curr_rot.reshape(num_envs, -1, 4) + relative_rot = quaternion_math.compute_orientation_drift( + init_rot, curr_rot + ).cpu().numpy() + obs = np.concatenate([obs, relative_rot], axis=-1) + + if config['env_frame_ob']: + if type(motion_length) == np.ndarray: + motion_length = motion_length.astype(np.float) + progress_ob = np.expand_dims(progress.astype(np.float) / + motion_length, axis=-1) + else: + progress_ob = np.expand_dims(progress.astype(np.float) / + float(motion_length), axis=-1) + obs = np.concatenate([obs, progress_ob], axis=-1) + + if config['env_motion_ob'] and not config['env_motion_ob_onehot']: + motion_id_ob = np.expand_dims(motion_id.astype(np.float) / + float(num_motion), axis=-1) + obs = np.concatenate([obs, motion_id_ob], axis=-1) + elif config['env_motion_ob'] and config['env_motion_ob_onehot']: + motion_id_ob = motion_onehot_matrix[motion_id] + obs = np.concatenate([obs, motion_id_ob], axis=-1) + + return obs, info + + +def get_xyzoffset(start_ts, end_ts, root_yaw_inv): + xyoffset = (end_ts - start_ts)[:, [0], :].reshape(1, -1, 1, 3) + ryinv = root_yaw_inv.reshape(1, -1, 1, 4) + + calibrated_xyz_offset = quaternion_math.quat_apply(ryinv, xyoffset)[0, :, 0, :] + return calibrated_xyz_offset diff --git a/phc/env/util/traj_generator.py b/phc/env/util/traj_generator.py new file mode 100644 index 0000000..c85d444 --- /dev/null +++ b/phc/env/util/traj_generator.py @@ -0,0 +1,208 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np +import torch +import joblib +import random +from phc.utils.flags import flags +# from phc.env.tasks.base_task import PORT, SERVER + +class TrajGenerator(): + def __init__(self, num_envs, episode_dur, num_verts, device, dtheta_max, + speed_min, speed_max, accel_max, sharp_turn_prob): + + + self._device = device + self._dt = episode_dur / (num_verts - 1) + self._dtheta_max = dtheta_max + self._speed_min = speed_min + self._speed_max = speed_max + self._accel_max = accel_max + self._sharp_turn_prob = sharp_turn_prob + + self._verts_flat = torch.zeros((num_envs * num_verts, 3), dtype=torch.float32, device=self._device) + self._verts = self._verts_flat.view((num_envs, num_verts, 3)) + + env_ids = torch.arange(self.get_num_envs(), dtype=np.int) + + # self.traj_data = joblib.load("data/traj/traj_data.pkl") + self.heading = torch.zeros(num_envs, 1) + return + + def reset(self, env_ids, init_pos): + n = len(env_ids) + if (n > 0): + num_verts = self.get_num_verts() + dtheta = 2 * torch.rand([n, num_verts - 1], device=self._device) - 1.0 # Sample the angles at each waypoint + dtheta *= self._dtheta_max * self._dt + + dtheta_sharp = np.pi * (2 * torch.rand([n, num_verts - 1], device=self._device) - 1.0) # Sharp Angles Angle + sharp_probs = self._sharp_turn_prob * torch.ones_like(dtheta) + sharp_mask = torch.bernoulli(sharp_probs) == 1.0 + dtheta[sharp_mask] = dtheta_sharp[sharp_mask] + + dtheta[:, 0] = np.pi * (2 * torch.rand([n], device=self._device) - 1.0) # Heading + + + dspeed = 2 * torch.rand([n, num_verts - 1], device=self._device) - 1.0 + dspeed *= self._accel_max * self._dt + dspeed[:, 0] = (self._speed_max - self._speed_min) * torch.rand([n], device=self._device) + self._speed_min # Speed + + speed = torch.zeros_like(dspeed) + speed[:, 0] = dspeed[:, 0] + for i in range(1, dspeed.shape[-1]): + speed[:, i] = torch.clip(speed[:, i - 1] + dspeed[:, i], self._speed_min, self._speed_max) + + ################################################ + # if flags.fixed_path: + # dtheta[:, :] = 0 # ZL: Hacking to make everything 0 + # dtheta[0, 0] = 0 # ZL: Hacking to create collision + # if len(dtheta) > 1: + # dtheta[1, 0] = -np.pi # ZL: Hacking to create collision + # speed[:] = (self._speed_min + self._speed_max)/2 + # ################################################ + + # if flags.slow: + # speed[:] = speed/4 + + dtheta = torch.cumsum(dtheta, dim=-1) + + # speed[:] = 6 + seg_len = speed * self._dt + + dpos = torch.stack([torch.cos(dtheta), -torch.sin(dtheta), torch.zeros_like(dtheta)], dim=-1) + dpos *= seg_len.unsqueeze(-1) + dpos[..., 0, 0:2] += init_pos[..., 0:2] + vert_pos = torch.cumsum(dpos, dim=-2) + + self._verts[env_ids, 0, 0:2] = init_pos[..., 0:2] + self._verts[env_ids, 1:] = vert_pos + + ####### ZL: Loading random real-world trajectories ####### + if flags.real_path: + rids = random.sample(self.traj_data.keys(), n) + traj = torch.stack([ + torch.from_numpy( + self.traj_data[id]['coord_dense'])[:num_verts] + for id in rids + ], + dim=0).to(self._device).float() + + traj[..., 0:2] = traj[..., 0:2] - (traj[..., 0, 0:2] - init_pos[..., 0:2])[:, None] + self._verts[env_ids] = traj + + return + + def input_new_trajs(self, env_ids): + import json + import requests + from scipy.interpolate import interp1d + x = requests.get( + f'http://{SERVER}:{PORT}/path?num_envs={len(env_ids)}') + + data_lists = [value for idx, value in x.json().items()] + coord = np.array(data_lists) + x = np.linspace(0, coord.shape[1] - 1, num = coord.shape[1]) + fx = interp1d(x, coord[..., 0], kind='linear') + fy = interp1d(x, coord[..., 1], kind='linear') + x4 = np.linspace(0, coord.shape[1] - 1, num = coord.shape[1] * 10) + coord_dense = np.stack([fx(x4), fy(x4), np.zeros([len(env_ids), x4.shape[0]])], axis = -1) + coord_dense = np.concatenate([coord_dense, coord_dense[..., -1:, :]], axis = -2) + self._verts[env_ids] = torch.from_numpy(coord_dense).float().to(env_ids.device) + return self._verts[env_ids] + + + def get_num_verts(self): + return self._verts.shape[1] + + def get_num_segs(self): + return self.get_num_verts() - 1 + + def get_num_envs(self): + return self._verts.shape[0] + + def get_traj_duration(self): + num_verts = self.get_num_verts() + dur = num_verts * self._dt + return dur + + def get_traj_verts(self, traj_id): + return self._verts[traj_id] + + def calc_pos(self, traj_ids, times): + traj_dur = self.get_traj_duration() + num_verts = self.get_num_verts() + num_segs = self.get_num_segs() + + traj_phase = torch.clip(times / traj_dur, 0.0, 1.0) + seg_idx = traj_phase * num_segs + seg_id0 = torch.floor(seg_idx).long() + seg_id1 = torch.ceil(seg_idx).long() + lerp = seg_idx - seg_id0 + pos0 = self._verts_flat[traj_ids * num_verts + seg_id0] + pos1 = self._verts_flat[traj_ids * num_verts + seg_id1] + + lerp = lerp.unsqueeze(-1) + pos = (1.0 - lerp) * pos0 + lerp * pos1 + + return pos + + def mock_calc_pos(self, env_ids, traj_ids, times, query_value_gradient): + traj_dur = self.get_traj_duration() + num_verts = self.get_num_verts() + num_segs = self.get_num_segs() + + traj_phase = torch.clip(times / traj_dur, 0.0, 1.0) + seg_idx = traj_phase * num_segs + seg_id0 = torch.floor(seg_idx).long() + seg_id1 = torch.ceil(seg_idx).long() + lerp = seg_idx - seg_id0 + + pos0 = self._verts_flat[traj_ids * num_verts + seg_id0] + pos1 = self._verts_flat[traj_ids * num_verts + seg_id1] + + lerp = lerp.unsqueeze(-1) + pos = (1.0 - lerp) * pos0 + lerp * pos1 + + new_obs, func = query_value_gradient(env_ids, pos) + if not new_obs is None: + # ZL: computes grad + with torch.enable_grad(): + new_obs.requires_grad_(True) + new_val = func(new_obs) + + disc_grad = torch.autograd.grad( + new_val, + new_obs, + grad_outputs=torch.ones_like(new_val), + create_graph=False, + retain_graph=True, + only_inputs=True) + + return pos diff --git a/phc/learning/__init__.py b/phc/learning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phc/learning/amp_agent.py b/phc/learning/amp_agent.py new file mode 100644 index 0000000..467d700 --- /dev/null +++ b/phc/learning/amp_agent.py @@ -0,0 +1,1105 @@ +from phc.utils.running_mean_std import RunningMeanStd +from rl_games.algos_torch import torch_ext +from rl_games.common import a2c_common +from rl_games.common import schedulers +from rl_games.common import vecenv + +from isaacgym.torch_utils import * + +import time +from datetime import datetime +import numpy as np +from torch import optim +import torch +from torch import nn +from phc.env.tasks.humanoid_amp_task import HumanoidAMPTask + +import learning.replay_buffer as replay_buffer +import learning.common_agent as common_agent + +from tensorboardX import SummaryWriter +import copy +from phc.utils.torch_utils import project_to_norm +import learning.amp_datasets as amp_datasets +from phc.learning.loss_functions import kl_multi +from smpl_sim.utils.math_utils import LinearAnneal + +def load_my_state_dict(target, saved_dict): + for name, param in saved_dict.items(): + if name not in target: + continue + + if target[name].shape == param.shape: + target[name].copy_(param) + + +class AMPAgent(common_agent.CommonAgent): + + def __init__(self, base_name, config): + super().__init__(base_name, config) + if self.config.get('use_seq_rl', False): + # Use the is_rnn to force the dataset to have sequencal format. + self.dataset = amp_datasets.AMPDataset(self.batch_size, self.minibatch_size, self.is_discrete, True, self.ppo_device, self.seq_len) + else: + self.dataset = amp_datasets.AMPDataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len) + + + if self.normalize_value: + self.value_mean_std = RunningMeanStd((1,)).to(self.ppo_device) # Override and get new value + + if self._normalize_amp_input: + self._amp_input_mean_std = RunningMeanStd(self._amp_observation_space.shape).to(self.ppo_device) + + norm_disc_reward = config.get('norm_disc_reward', False) + if (norm_disc_reward): + self._disc_reward_mean_std = RunningMeanStd((1,)).to(self.ppo_device) + else: + self._disc_reward_mean_std = None + + self.save_kin_info = self.vec_env.env.task.cfg.env.get("save_kin_info", False) + self.only_kin_loss = self.vec_env.env.task.cfg.env.get("only_kin_loss", False) + self.temp_running_mean = self.vec_env.env.task.temp_running_mean # use temp running mean to make sure the obs used for training is the same as calc gradient. + + kin_lr = float(self.vec_env.env.task.kin_lr) + + if self.save_kin_info: + self.kin_dict_info = None + self.kin_optimizer = torch.optim.Adam(self.model.a2c_network.parameters(), kin_lr) + + # ZL Hack + if self.vec_env.env.task.fitting: + print("#################### Fitting and freezing!! ####################") + checkpoint = torch_ext.load_checkpoint(self.vec_env.env.task.models_path[0]) + + self.set_stats_weights(checkpoint) # loads mean std. essential for distilling knowledge. will not load if has a shape mismatch. + self.freeze_state_weights() # freeze the mean stds. + load_my_state_dict(self.model.state_dict(), checkpoint['model']) # loads everything (model, std, ect.). that can be load from the last model. + # self.value_mean_std # not freezing value function though. + + return + + def set_stats_weights(self, weights): + if self.normalize_input: + if weights['running_mean_std']['running_mean'].shape == self.running_mean_std.state_dict()['running_mean'].shape: + self.running_mean_std.load_state_dict(weights['running_mean_std']) + else: + print("shape mismatch, can not load input mean std") + + if self.normalize_value: + self.value_mean_std.load_state_dict(weights['reward_mean_std']) + + if self.has_central_value: + self.central_value_net.set_stats_weights(weights['assymetric_vf_mean_std']) + + if self.mixed_precision and 'scaler' in weights: + self.scaler.load_state_dict(weights['scaler']) + + if self._normalize_amp_input: + if weights['amp_input_mean_std']['running_mean'].shape == self._amp_input_mean_std.state_dict()['running_mean'].shape: + self._amp_input_mean_std.load_state_dict(weights['amp_input_mean_std']) + else: + print("shape mismatch, can not load AMP mean std") + + + if (self._norm_disc_reward()): + self._disc_reward_mean_std.load_state_dict(weights['disc_reward_mean_std']) + + def get_full_state_weights(self): + state = super().get_full_state_weights() + + if "kin_optimizer" in self.__dict__: + print("!!!saving kin_optimizer!!! Remove this message asa p!!") + state['kin_optimizer'] = self.kin_optimizer.state_dict() + + return state + + def set_full_state_weights(self, weights): + super().set_full_state_weights(weights) + if "kin_optimizer" in weights: + print("!!!loading kin_optimizer!!! Remove this message asa p!!") + self.kin_optimizer.load_state_dict(weights['kin_optimizer']) + + + def freeze_state_weights(self): + if self.normalize_input: + self.running_mean_std.freeze() + if self.normalize_value: + self.value_mean_std.freeze() + if self.has_central_value: + raise NotImplementedError() + if self.mixed_precision: + raise NotImplementedError() + + def unfreeze_state_weights(self): + if self.normalize_input: + self.running_mean_std.unfreeze() + if self.normalize_value: + self.value_mean_std.unfreeze() + if self.has_central_value: + raise NotImplementedError() + if self.mixed_precision: + raise NotImplementedError() + + def init_tensors(self): + super().init_tensors() + self._build_amp_buffers() + + if self.save_kin_info: + B, S, _ = self.experience_buffer.tensor_dict['obses'].shape + kin_dict = self.vec_env.env.task.kin_dict + kin_dict_size = np.sum([v.reshape(v.shape[0], -1).shape[-1] for k, v in kin_dict.items()]) + self.experience_buffer.tensor_dict['kin_dict'] = torch.zeros((B, S, kin_dict_size)).to(self.experience_buffer.tensor_dict['obses']) + self.tensor_list += ['kin_dict'] + + if self.vec_env.env.task.z_type == "vae": + B, S, _ = self.experience_buffer.tensor_dict['obses'].shape + self.experience_buffer.tensor_dict['z_noise'] = torch.zeros(B, S, self.model.a2c_network.embedding_size).to(self.experience_buffer.tensor_dict['obses']) + self.tensor_list += ['z_noise'] + + return + + def set_eval(self): + super().set_eval() + if self._normalize_amp_input: + self._amp_input_mean_std.eval() + + if (self._norm_disc_reward()): + self._disc_reward_mean_std.eval() + + return + + def set_train(self): + super().set_train() + if self._normalize_amp_input: + self._amp_input_mean_std.train() + + if (self._norm_disc_reward()): + self._disc_reward_mean_std.train() + + return + + def get_stats_weights(self): + state = super().get_stats_weights() + if self._normalize_amp_input: + state['amp_input_mean_std'] = self._amp_input_mean_std.state_dict() + + if (self._norm_disc_reward()): + state['disc_reward_mean_std'] = self._disc_reward_mean_std.state_dict() + + return state + + + def play_steps_rnn(self): + self.set_eval() + mb_rnn_states = [] + epinfos = [] + self.experience_buffer.tensor_dict['values'].fill_(0) + self.experience_buffer.tensor_dict['rewards'].fill_(0) + self.experience_buffer.tensor_dict['dones'].fill_(1) + step_time = 0.0 + + update_list = self.update_list + + batch_size = self.num_agents * self.num_actors + mb_rnn_masks = None + + mb_rnn_masks, indices, steps_mask, steps_state, play_mask, mb_rnn_states = self.init_rnn_step(batch_size, mb_rnn_states) # mb_rnn_states means "memory bank" rnn states + + ### ZL + done_indices = [] + terminated_flags = torch.zeros(self.num_actors, device=self.device) + reward_raw = torch.zeros(1, device=self.device) + + for n in range(self.horizon_length): + + + + self.obs = self.env_reset(done_indices) + + # self.rnn_states[0][:, :, -1] = n; print('debugg!!!!') + # self.rnn_states[0][:, :, -2] = torch.arange(self.num_actors) + + seq_indices, full_tensor = self.process_rnn_indices(mb_rnn_masks, indices, steps_mask, steps_state, mb_rnn_states) # this should upate mb_rnn_states + if full_tensor: + break + + if self.has_central_value: + self.central_value_net.pre_step_rnn(self.last_rnn_indices, self.last_state_indices) + + if self.use_action_masks: + masks = self.vec_env.get_action_masks() + res_dict = self.get_masked_action_values(self.obs, masks) + else: + res_dict = self.get_action_values(self.obs) + + self.rnn_states = res_dict['rnn_states'] + self.experience_buffer.update_data_rnn('obses', indices, play_mask, self.obs['obs']) + + for k in update_list: + self.experience_buffer.update_data_rnn(k, indices, play_mask, res_dict[k]) + + if self.has_central_value: + self.experience_buffer.update_data_rnn('states', indices[::self.num_agents], play_mask[::self.num_agents] // self.num_agents, self.obs['states']) + + if self.only_kin_loss: + # pure behavior cloning, kinemaitc loss. + self.obs, rewards, self.dones, infos = self.env_step(res_dict['mus']) + else: + self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions']) + + + shaped_rewards = self.rewards_shaper(rewards) + + if self.value_bootstrap and 'time_outs' in infos: + shaped_rewards += self.gamma * res_dict['values'] * self.cast_obs(infos['time_outs']).unsqueeze(1).float() + self.experience_buffer.update_data_rnn('rewards', indices, play_mask, shaped_rewards) + self.experience_buffer.update_data_rnn('next_obses', indices, play_mask, self.obs['obs']) + self.experience_buffer.update_data_rnn('dones', indices, play_mask, self.dones.byte()) + self.experience_buffer.update_data_rnn('amp_obs', indices, play_mask, infos['amp_obs']) + + ### ZL + terminated = infos['terminate'].float() + terminated_flags += terminated + reward_raw_mean = infos['reward_raw'].mean(dim=0) + + if reward_raw.shape != reward_raw_mean.shape: + reward_raw = reward_raw_mean + else: + reward_raw += reward_raw_mean + + terminated = terminated.unsqueeze(-1) + input_dict = {"obs": self.obs['obs'], "rnn_states": self.rnn_states} + next_vals = self._eval_critic(input_dict) # ZL this has issues? (maybe not, since we are passing the states in.) + next_vals *= (1.0 - terminated) + self.experience_buffer.update_data_rnn('next_values', indices, play_mask, next_vals) + + self.current_rewards += rewards + self.current_lengths += 1 + all_done_indices = self.dones.nonzero(as_tuple=False) + done_indices = all_done_indices[::self.num_agents] + + self.process_rnn_dones(all_done_indices, indices, seq_indices) + + if self.has_central_value: + self.central_value_net.post_step_rnn(all_done_indices) + + self.algo_observer.process_infos(infos, done_indices) + + fdones = self.dones.float() + not_dones = 1.0 - self.dones.float() + + self.game_rewards.update(self.current_rewards[done_indices]) + self.game_lengths.update(self.current_lengths[done_indices]) + + self.current_rewards = self.current_rewards * not_dones.unsqueeze(1) + self.current_lengths = self.current_lengths * not_dones + + if self.only_kin_loss: + self.experience_buffer.update_data_rnn('kin_dict', indices, play_mask, torch.cat([v.reshape(v.shape[0], -1) for k, v in infos['kin_dict'].items()], dim = -1)) + if self.kin_dict_info is None: + self.kin_dict_info = {k: (v.shape, v.reshape(v.shape[0], -1).shape) for k, v in infos['kin_dict'].items()} + + if (self.vec_env.env.task.viewer): + self._amp_debug(infos) + + done_indices = done_indices[:, 0] + + + mb_fdones = self.experience_buffer.tensor_dict['dones'].float() + mb_values = self.experience_buffer.tensor_dict['values'] + mb_next_values = self.experience_buffer.tensor_dict['next_values'] + + mb_rewards = self.experience_buffer.tensor_dict['rewards'] + mb_amp_obs = self.experience_buffer.tensor_dict['amp_obs'] + amp_rewards = self._calc_amp_rewards(mb_amp_obs) + mb_rewards = self._combine_rewards(mb_rewards, amp_rewards) + + + mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values) + mb_returns = mb_advs + mb_values + + # self.experience_buffer.tensor_dict['actions']: is num_env, Batch, feat. That's why we swap and flatten, mb_rnn_states is already in that format. + batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list) # swap to step, num_envs, feat + batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns) + batch_dict['rnn_states'] = mb_rnn_states + + batch_dict['rnn_masks'] = mb_rnn_masks # ZL: this should be swap and flattened, but it's all ones for now + batch_dict['terminated_flags'] = terminated_flags + batch_dict['reward_raw'] =reward_raw / self.horizon_length + + batch_dict['played_frames'] = n * self.num_actors * self.num_agents + batch_dict['step_time'] = step_time + + + for k, v in amp_rewards.items(): + batch_dict[k] = a2c_common.swap_and_flatten01(v) + + batch_dict['mb_rewards'] = a2c_common.swap_and_flatten01(mb_rewards) + + return batch_dict + + def play_steps(self): + self.set_eval() + humanoid_env = self.vec_env.env.task + + epinfos = [] + done_indices = [] + update_list = self.update_list + terminated_flags = torch.zeros(self.num_actors, device=self.device) + reward_raw = torch.zeros(1, device=self.device) + for n in range(self.horizon_length): + + self.obs = self.env_reset(done_indices) + self.experience_buffer.update_data('obses', n, self.obs['obs']) + + if self.use_action_masks: + masks = self.vec_env.get_action_masks() + res_dict = self.get_masked_action_values(self.obs, masks) + else: + res_dict = self.get_action_values(self.obs) + + for k in update_list: + self.experience_buffer.update_data(k, n, res_dict[k]) + + if self.has_central_value: + self.experience_buffer.update_data('states', n, self.obs['states']) + + if self.only_kin_loss and self.save_kin_info: + # pure behavior cloning, kinemaitc loss. + self.obs, rewards, self.dones, infos = self.env_step(res_dict['mus']) + else: + self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions']) + + shaped_rewards = self.rewards_shaper(rewards) + self.experience_buffer.update_data('rewards', n, shaped_rewards) + self.experience_buffer.update_data('next_obses', n, self.obs['obs']) + self.experience_buffer.update_data('dones', n, self.dones) + self.experience_buffer.update_data('amp_obs', n, infos['amp_obs']) + + if self.save_kin_info: + self.experience_buffer.update_data('kin_dict', n, torch.cat([v.reshape(v.shape[0], -1) for k, v in infos['kin_dict'].items()], dim = -1)) + + if self.kin_dict_info is None: + self.kin_dict_info = {k: (v.shape, v.reshape(v.shape[0], -1).shape) for k, v in infos['kin_dict'].items()} + + + terminated = infos['terminate'].float() + terminated_flags += terminated + + reward_raw_mean = infos['reward_raw'].mean(dim=0) + if reward_raw.shape != reward_raw_mean.shape: + reward_raw = reward_raw_mean + else: + reward_raw += reward_raw_mean + terminated = terminated.unsqueeze(-1) + + next_vals = self._eval_critic(self.obs) + next_vals *= (1.0 - terminated) + self.experience_buffer.update_data('next_values', n, next_vals) + + self.current_rewards += rewards + self.current_lengths += 1 + all_done_indices = self.dones.nonzero(as_tuple=False) + done_indices = all_done_indices[::self.num_agents] + self.game_rewards.update(self.current_rewards[done_indices]) + self.game_lengths.update(self.current_lengths[done_indices]) + self.algo_observer.process_infos(infos, done_indices) + + not_dones = 1.0 - self.dones.float() + + self.current_rewards = self.current_rewards * not_dones.unsqueeze(1) + self.current_lengths = self.current_lengths * not_dones + + if (self.vec_env.env.task.viewer): + self._amp_debug(infos) + + done_indices = done_indices[:, 0] + + mb_fdones = self.experience_buffer.tensor_dict['dones'].float() + mb_values = self.experience_buffer.tensor_dict['values'] + mb_next_values = self.experience_buffer.tensor_dict['next_values'] + + mb_rewards = self.experience_buffer.tensor_dict['rewards'] + mb_amp_obs = self.experience_buffer.tensor_dict['amp_obs'] + amp_rewards = self._calc_amp_rewards(mb_amp_obs) + mb_rewards = self._combine_rewards(mb_rewards, amp_rewards) + mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values) + mb_returns = mb_advs + mb_values + + batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list) + batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns) + batch_dict['terminated_flags'] = terminated_flags + batch_dict['reward_raw'] =reward_raw / self.horizon_length + batch_dict['played_frames'] = self.batch_size + + for k, v in amp_rewards.items(): + batch_dict[k] = a2c_common.swap_and_flatten01(v) + batch_dict['mb_rewards'] = a2c_common.swap_and_flatten01(mb_rewards) + + return batch_dict + + def prepare_dataset(self, batch_dict): + + + dataset_dict = super().prepare_dataset(batch_dict) + dataset_dict['amp_obs'] = batch_dict['amp_obs'] + dataset_dict['amp_obs_demo'] = batch_dict['amp_obs_demo'] + dataset_dict['amp_obs_replay'] = batch_dict['amp_obs_replay'] + + if self.save_kin_info: + dataset_dict['kin_dict'] = batch_dict['kin_dict'] + + if self.vec_env.env.task.z_type == "vae": + dataset_dict['z_noise'] = batch_dict['z_noise'] + + self.dataset.update_values_dict(dataset_dict, rnn_format = True, horizon_length = self.horizon_length, num_envs = self.num_actors) + # self.dataset.update_values_dict(dataset_dict) + + return + + def train_epoch(self): + self.pre_epoch(self.epoch_num) + play_time_start = time.time() + + ### ZL: do not update state weights during play + + with torch.no_grad(): + if self.is_rnn: + batch_dict = self.play_steps_rnn() + else: + batch_dict = self.play_steps() + + play_time_end = time.time() + update_time_start = time.time() + rnn_masks = batch_dict.get('rnn_masks', None) + + self._update_amp_demos() + num_obs_samples = batch_dict['amp_obs'].shape[0] + amp_obs_demo = self._amp_obs_demo_buffer.sample(num_obs_samples)['amp_obs'] + batch_dict['amp_obs_demo'] = amp_obs_demo + + if (self._amp_replay_buffer.get_total_count() == 0): + batch_dict['amp_obs_replay'] = batch_dict['amp_obs'] + else: + batch_dict['amp_obs_replay'] = self._amp_replay_buffer.sample(num_obs_samples)['amp_obs'] + + self.set_train() + + self.curr_frames = batch_dict.pop('played_frames') + + self.prepare_dataset(batch_dict) + self.algo_observer.after_steps() + + if self.has_central_value: + self.train_central_value() + + train_info = None + + # if self.is_rnn: + # frames_mask_ratio = rnn_masks.sum().item() / (rnn_masks.nelement()) + + for _ in range(0, self.mini_epochs_num): + ep_kls = [] + for i in range(len(self.dataset)): + curr_train_info = self.train_actor_critic(self.dataset[i]) + + if self.schedule_type == 'legacy': + if self.multi_gpu: + curr_train_info['kl'] = self.hvd.average_value(curr_train_info['kl'], 'ep_kls') + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, curr_train_info['kl'].item()) + self.update_lr(self.last_lr) + + if (train_info is None): + train_info = dict() + for k, v in curr_train_info.items(): + train_info[k] = [v] + else: + for k, v in curr_train_info.items(): + train_info[k].append(v) + + av_kls = torch_ext.mean_list(train_info['kl']) + + if self.schedule_type == 'standard': + if self.multi_gpu: + av_kls = self.hvd.average_value(av_kls, 'ep_kls') + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.update_lr(self.last_lr) + + if self.schedule_type == 'standard_epoch': + if self.multi_gpu: + av_kls = self.hvd.average_value(torch_ext.mean_list(kls), 'ep_kls') + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.update_lr(self.last_lr) + + update_time_end = time.time() + play_time = play_time_end - play_time_start + update_time = update_time_end - update_time_start + total_time = update_time_end - play_time_start + + self._store_replay_amp_obs(batch_dict['amp_obs']) + + train_info['play_time'] = play_time + train_info['update_time'] = update_time + train_info['total_time'] = total_time + train_info['terminated_flags'] = batch_dict['terminated_flags'] + train_info['reward_raw'] = batch_dict['reward_raw'] + train_info['mb_rewards'] = batch_dict['mb_rewards'] + train_info['returns'] = batch_dict['returns'] + self._record_train_batch_info(batch_dict, train_info) + self.post_epoch(self.epoch_num) + + if self.save_kin_info: + print_str = "Kin: " + " \t".join([f"{k}: {torch.mean(torch.tensor(train_info[k])):.4f}" for k, v in train_info.items() if k.startswith("kin")]) + print(print_str) + + return train_info + + def pre_epoch(self, epoch_num): + # print("freeze running mean/std") + + if self.vec_env.env.task.humanoid_type in ["smpl", "smplh", "smplx"]: + humanoid_env = self.vec_env.env.task + if (epoch_num > 1) and epoch_num % humanoid_env.shape_resampling_interval == 1: # + 1 to evade the evaluations. + # if (epoch_num > 0) and epoch_num % humanoid_env.shape_resampling_interval == 0 and not (epoch_num % (self.save_freq)): # Remove the resampling for this. + # Different from AMP, always resample motion no matter the motion type. + print("Resampling Shape") + humanoid_env.resample_motions() + # self.current_rewards # Fixing these values such that they do not get whacked by the + # self.current_lengths + if humanoid_env.getup_schedule: + humanoid_env.update_getup_schedule(epoch_num, getup_udpate_epoch=humanoid_env.getup_udpate_epoch) + if epoch_num > humanoid_env.getup_udpate_epoch: # ZL fix janky hack + self._task_reward_w = 0.5 + self._disc_reward_w = 0.5 + else: + self._task_reward_w = 0 + self._disc_reward_w = 1 + + self.running_mean_std_temp = copy.deepcopy(self.running_mean_std) # Freeze running mean/std, so that the actor does not use the updated mean/std + self.running_mean_std_temp.freeze() + + def post_epoch(self, epoch_num): + self.running_mean_std_temp = copy.deepcopy(self.running_mean_std) # Unfreeze running mean/std + self.running_mean_std_temp.freeze() + + + def _preproc_obs(self, obs_batch, use_temp=False): + if type(obs_batch) is dict: + for k, v in obs_batch.items(): + obs_batch[k] = self._preproc_obs(v, use_temp = use_temp) + else: + if obs_batch.dtype == torch.uint8: + obs_batch = obs_batch.float() / 255.0 + + if self.normalize_input: + obs_batch_proc = obs_batch[:, :self.running_mean_std.mean_size] + if use_temp: + obs_batch_out = self.running_mean_std_temp(obs_batch_proc) + obs_batch_orig = self.running_mean_std(obs_batch_proc) # running through mean std, but do not use its value. use temp + else: + obs_batch_out = self.running_mean_std(obs_batch_proc) # running through mean std, but do not use its value. use temp + obs_batch_out = torch.cat([obs_batch_out, obs_batch[:, self.running_mean_std.mean_size:]], dim=-1) + + return obs_batch_out + + def calc_gradients(self, input_dict): + + self.set_train() + humanoid_env = self.vec_env.env.task + + value_preds_batch = input_dict['old_values'] + old_action_log_probs_batch = input_dict['old_logp_actions'] + advantage = input_dict['advantages'] + old_mu_batch = input_dict['mu'] + old_sigma_batch = input_dict['sigma'] + return_batch = input_dict['returns'] + actions_batch = input_dict['actions'] + obs_batch = input_dict['obs'] + obs_batch_processed = self._preproc_obs(obs_batch, use_temp=self.temp_running_mean) + input_dict['obs_processed'] = obs_batch_processed + + amp_obs = input_dict['amp_obs'][0:self._amp_minibatch_size] + amp_obs = self._preproc_amp_obs(amp_obs) + + amp_obs_replay = input_dict['amp_obs_replay'][0:self._amp_minibatch_size] + amp_obs_replay = self._preproc_amp_obs(amp_obs_replay) + + amp_obs_demo = input_dict['amp_obs_demo'][0:self._amp_minibatch_size] + amp_obs_demo = self._preproc_amp_obs(amp_obs_demo) + amp_obs_demo.requires_grad_(True) + + lr = self.last_lr + kl = 1.0 + lr_mul = 1.0 + curr_e_clip = lr_mul * self.e_clip + + self.train_result = {} + if self.only_kin_loss: + # pure behavior cloning, kinemaitc loss. + batch_dict = {} + batch_dict['obs_orig'] = obs_batch + batch_dict['obs'] = input_dict['obs_processed'] + batch_dict['kin_dict'] = input_dict['kin_dict'] + + # if humanoid_env.z_type == "vae": + # batch_dict['z_noise'] = input_dict['z_noise'] + + rnn_len = self.horizon_length + rnn_len = 1 + if self.is_rnn: + batch_dict['rnn_states'] = input_dict['rnn_states'] + batch_dict['seq_length'] = rnn_len + + kin_loss_info = self._optimize_kin(batch_dict) + self.train_result.update( {'entropy': torch.tensor(0).float(), 'kl': torch.tensor(0).float(), 'last_lr': self.last_lr, 'lr_mul': torch.tensor(0).float()}) + + else: + batch_dict = {'is_train': True, 'amp_steps': self.vec_env.env.task._num_amp_obs_steps, \ + 'prev_actions': actions_batch, 'obs': obs_batch_processed, 'amp_obs': amp_obs, 'amp_obs_replay': amp_obs_replay, 'amp_obs_demo': amp_obs_demo, \ + "obs_orig": obs_batch + } + + rnn_masks = None + rnn_len = self.horizon_length + rnn_len = 1 + if self.is_rnn: + rnn_masks = input_dict['rnn_masks'] + batch_dict['rnn_states'] = input_dict['rnn_states'] + batch_dict['seq_length'] = rnn_len + + + with torch.cuda.amp.autocast(enabled=self.mixed_precision): + res_dict = self.model(batch_dict) # current model if RNN, has BPTT enabled. + + action_log_probs = res_dict['prev_neglogp'] + values = res_dict['values'] + entropy = res_dict['entropy'] + mu = res_dict['mus'] + sigma = res_dict['sigmas'] + disc_agent_logit = res_dict['disc_agent_logit'] + disc_agent_replay_logit = res_dict['disc_agent_replay_logit'] + disc_demo_logit = res_dict['disc_demo_logit'] + + if not rnn_masks is None: + rnn_mask_bool = rnn_masks.squeeze().bool() + old_action_log_probs_batch, action_log_probs, advantage, values, entropy, mu, sigma, return_batch, old_mu_batch, old_sigma_batch = \ + old_action_log_probs_batch[rnn_mask_bool], action_log_probs[rnn_mask_bool], advantage[rnn_mask_bool], values[rnn_mask_bool], \ + entropy[rnn_mask_bool], mu[rnn_mask_bool], sigma[rnn_mask_bool], return_batch[rnn_mask_bool], old_mu_batch[rnn_mask_bool], old_sigma_batch[rnn_mask_bool] + + # flatten values for computing loss + + a_info = self._actor_loss(old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip) + a_loss = a_info['actor_loss'] + + c_info = self._critic_loss(value_preds_batch, values, curr_e_clip, return_batch, self.clip_value) + c_loss = c_info['critic_loss'] + + b_loss = self.bound_loss(mu) + + a_loss = torch.mean(a_loss) + c_loss = torch.mean(c_loss) + b_loss = torch.mean(b_loss) + entropy = torch.mean(entropy) + + disc_agent_cat_logit = torch.cat([disc_agent_logit, disc_agent_replay_logit], dim=0) + + disc_info = self._disc_loss(disc_agent_cat_logit, disc_demo_logit, amp_obs_demo) + disc_loss = disc_info['disc_loss'] + + loss = a_loss + self.critic_coef * c_loss - self.entropy_coef * entropy + self.bounds_loss_coef * b_loss \ + + self._disc_coef * disc_loss + + + a_clip_frac = torch.mean(a_info['actor_clipped'].float()) + + a_info['actor_loss'] = a_loss + a_info['actor_clip_frac'] = a_clip_frac + c_info['critic_loss'] = c_loss + + if self.multi_gpu: + self.optimizer.zero_grad() + else: + for param in self.model.parameters(): + param.grad = None + + self.scaler.scale(loss).backward() + + with torch.no_grad(): + reduce_kl = not self.is_rnn + kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl) + if self.is_rnn: + kl_dist = kl_dist.mean() + + + #TODO: Refactor this ugliest code of the year + if self.truncate_grads: + if self.multi_gpu: + self.optimizer.synchronize() + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm) + with self.optimizer.skip_synchronize(): + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm) + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.scaler.step(self.optimizer) + self.scaler.update() + + self.train_result.update( {'entropy': entropy, 'kl': kl_dist, 'last_lr': self.last_lr, 'lr_mul': lr_mul, 'b_loss': b_loss}) + self.train_result.update(a_info) + self.train_result.update(c_info) + self.train_result.update(disc_info) + + if self.save_kin_info: + self.train_result.update(kin_loss_info) + + return + + def _assamble_kin_dict(self, kin_dict_flat): + B = kin_dict_flat.shape[0] + len_acc = 0 + kin_dict = {} + for k, v in self.kin_dict_info.items(): + kin_dict[k] = kin_dict_flat[:, len_acc:(len_acc + v[1][-1])].view(B, *v[0][1:]) + len_acc += v[1][-1] + return kin_dict + + def _optimize_kin(self, batch_dict): + info_dict = {} + humanoid_env = self.vec_env.env.task + if humanoid_env.distill: + kin_dict = self._assamble_kin_dict(batch_dict['kin_dict']) + gt_action = kin_dict['gt_action'] + + kin_body_rot_geo_loss, kin_vel_loss_l2 = 0.0, 0.0 + if humanoid_env.z_type == "vae": + pred_action, pred_action_sigma, extra_dict = self.model.a2c_network.eval_actor(batch_dict, return_extra = True) + # kin_body_loss = (pred_action - gt_action).pow(2).mean() * 10 ## MSE + kin_action_loss = torch.norm(pred_action - gt_action, dim=-1).mean() ## RMSE + + vae_mu, vae_log_var = extra_dict['vae_mu'], extra_dict['vae_log_var'] + if humanoid_env.use_vae_prior or humanoid_env.use_vae_fixed_prior: + prior_mu, prior_log_var = self.model.a2c_network.compute_prior(batch_dict) + KLD = kl_multi(vae_mu, vae_log_var, prior_mu, prior_log_var).mean() + else: + KLD = -0.5 * torch.sum(1 + vae_log_var - vae_mu.pow(2) - vae_log_var.exp()) / vae_mu.shape[0] + + ar1_prior, regu_prior = 0, 0 + if humanoid_env.use_ar1_prior: + time_zs = vae_mu.view(self.minibatch_size // self.horizon_length, self.horizon_length, -1) + phi = 0.99 + + error = time_zs[:, 1:] - time_zs[:, :-1] * phi + + idxes = kin_dict['progress_buf'].view(self.minibatch_size // self.horizon_length, self.horizon_length, -1) + + not_consecs = ((idxes[:, 1:] - idxes[:, :-1]) != 1).view(-1) + error = error.view(-1, error.shape[-1]) + error[not_consecs] = 0 + + starteres = ((idxes <= 2)[:, 1:] + (idxes <= 2)[:, :-1]).view(-1) # make sure the "drop" is not affected. + error[starteres] = 0 + + ar1_prior = torch.norm(error, dim=-1).mean() + info_dict["kin_ar1"] = ar1_prior + + if humanoid_env.use_vae_prior_regu: + prior_mean_regu = ((prior_mu ** 2).mean() + (vae_mu ** 2).mean()) * 0.001 # penalize large prior values + prior_var_regu = ((prior_log_var ** 2).mean() + (vae_log_var ** 2).mean()) * 0.001 # penalize large variance values + regu_prior = prior_mean_regu + prior_var_regu + info_dict["kin_prior_regu"] = regu_prior + + kin_loss = kin_action_loss + KLD * humanoid_env.kld_coefficient + ar1_prior * humanoid_env.ar1_coefficient + regu_prior * 0.005 + + + info_dict["kin_action_loss"] = kin_action_loss + info_dict["kin_KLD"] = KLD + + if KLD > 100: + import ipdb; ipdb.set_trace() + print("KLD is too large, clipping to 10") + + ######### KLD annealing ####### + if humanoid_env.kld_anneal: + anneal_start_epoch = 2500 + anneal_end_epoch = 5000 + min_val = humanoid_env.kld_coefficient_min + if self.epoch_num > anneal_start_epoch: + humanoid_env.kld_coefficient = (0.01 - min_val) * max((anneal_end_epoch -self.epoch_num) / (anneal_end_epoch - anneal_start_epoch), 0) + min_val + info_dict["kin_kld_w"] = humanoid_env.kld_coefficient + ######### KLD annealing ####### + else: + raise NotImplementedError() + + self.kin_optimizer.zero_grad() + kin_loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm) + self.kin_optimizer.step() + + info_dict.update({"kin_loss": kin_loss}) + + return info_dict + + + + def _load_config_params(self, config): + super()._load_config_params(config) + + self._task_reward_w = config['task_reward_w'] + self._disc_reward_w = config['disc_reward_w'] + + self._amp_observation_space = self.env_info['amp_observation_space'] + self._amp_batch_size = int(config['amp_batch_size']) + self._amp_minibatch_size = int(config['amp_minibatch_size']) + assert (self._amp_minibatch_size <= self.minibatch_size) + + self._disc_coef = config['disc_coef'] + self._disc_logit_reg = config['disc_logit_reg'] + self._disc_grad_penalty = config['disc_grad_penalty'] + self._disc_weight_decay = config['disc_weight_decay'] + self._disc_reward_scale = config['disc_reward_scale'] + self._normalize_amp_input = config.get('normalize_amp_input', True) + return + + def _build_net_config(self): + config = super()._build_net_config() + config['amp_input_shape'] = self._amp_observation_space.shape + + config['task_obs_size_detail'] = self.vec_env.env.task.get_task_obs_size_detail() + if self.vec_env.env.task.has_task: + config['self_obs_size'] = self.vec_env.env.task.get_self_obs_size() + config['task_obs_size'] = self.vec_env.env.task.get_task_obs_size() + + return config + + def _init_train(self): + super()._init_train() + self._init_amp_demo_buf() + return + + + def _oracle_loss(self, obs): + oracle_a, _ = self.oracle_model.a2c_network.eval_actor({"obs": obs}) + model_a, _ = self.model.a2c_network.eval_actor({"obs": obs}) + oracle_loss = (oracle_a - model_a).pow(2).mean(dim=-1) * 50 + return {'oracle_loss': oracle_loss} + + def _disc_loss(self, disc_agent_logit, disc_demo_logit, obs_demo): + ''' + disc_agent_logit: replay and current episode logit (fake examples) + disc_demo_logit: disc_demo_logit logit + obs_demo: gradient penalty demo obs (real examples) + ''' + # prediction loss + disc_loss_agent = self._disc_loss_neg(disc_agent_logit) + disc_loss_demo = self._disc_loss_pos(disc_demo_logit) + + disc_loss = 0.5 * (disc_loss_agent + disc_loss_demo) + + # logit reg + logit_weights = self.model.a2c_network.get_disc_logit_weights() + disc_logit_loss = torch.sum(torch.square(logit_weights)) # make weight small?? + disc_loss += self._disc_logit_reg * disc_logit_loss + + # grad penalty + disc_demo_grad = torch.autograd.grad(disc_demo_logit, obs_demo, grad_outputs=torch.ones_like(disc_demo_logit), create_graph=True, retain_graph=True, only_inputs=True) + disc_demo_grad = disc_demo_grad[0] + + ### ZL Hack for zeroing out gradient penalty on the shape (406,) + # if self.vec_env.env.task.__dict__.get("smpl_humanoid", False): + # humanoid_env = self.vec_env.env.task + # B, feat_dim = disc_demo_grad.shape + # shape_obs_dim = 17 + # if humanoid_env.has_shape_obs: + # amp_obs_dim = int(feat_dim / humanoid_env._num_amp_obs_steps) + # for i in range(humanoid_env._num_amp_obs_steps): + # disc_demo_grad[:, + # ((i + 1) * amp_obs_dim - + # shape_obs_dim):((i + 1) * amp_obs_dim)] = 0 + + disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1) + + disc_grad_penalty = torch.mean(disc_demo_grad) + disc_loss += self._disc_grad_penalty * disc_grad_penalty + + # weight decay + if (self._disc_weight_decay != 0): + disc_weights = self.model.a2c_network.get_disc_weights() + disc_weights = torch.cat(disc_weights, dim=-1) + disc_weight_decay = torch.sum(torch.square(disc_weights)) + disc_loss += self._disc_weight_decay * disc_weight_decay + + disc_agent_acc, disc_demo_acc = self._compute_disc_acc(disc_agent_logit, disc_demo_logit) + + # print(f"agent_loss: {disc_loss_agent.item():.3f} | disc_loss_demo {disc_loss_demo.item():.3f}") + disc_info = { + 'disc_loss': disc_loss, + 'disc_grad_penalty': disc_grad_penalty.detach(), + 'disc_logit_loss': disc_logit_loss.detach(), + 'disc_agent_acc': disc_agent_acc.detach(), + 'disc_demo_acc': disc_demo_acc.detach(), + 'disc_agent_logit': disc_agent_logit.detach(), + 'disc_demo_logit': disc_demo_logit.detach() + } + return disc_info + + def _disc_loss_neg(self, disc_logits): + bce = torch.nn.BCEWithLogitsLoss() + loss = bce(disc_logits, torch.zeros_like(disc_logits)) + return loss + + def _disc_loss_pos(self, disc_logits): + bce = torch.nn.BCEWithLogitsLoss() + loss = bce(disc_logits, torch.ones_like(disc_logits)) + return loss + + def _compute_disc_acc(self, disc_agent_logit, disc_demo_logit): + agent_acc = disc_agent_logit < 0 + agent_acc = torch.mean(agent_acc.float()) + demo_acc = disc_demo_logit > 0 + demo_acc = torch.mean(demo_acc.float()) + return agent_acc, demo_acc + + def _fetch_amp_obs_demo(self, num_samples): + amp_obs_demo = self.vec_env.env.fetch_amp_obs_demo(num_samples) + return amp_obs_demo + + def _build_amp_buffers(self): + batch_shape = self.experience_buffer.obs_base_shape + self.experience_buffer.tensor_dict['amp_obs'] = torch.zeros(batch_shape + self._amp_observation_space.shape, device=self.ppo_device) + amp_obs_demo_buffer_size = int(self.config['amp_obs_demo_buffer_size']) + self._amp_obs_demo_buffer = replay_buffer.ReplayBuffer(amp_obs_demo_buffer_size, self.ppo_device) # Demo is the data from the dataset. Real samples + + self._amp_replay_keep_prob = self.config['amp_replay_keep_prob'] + replay_buffer_size = int(self.config['amp_replay_buffer_size']) + self._amp_replay_buffer = replay_buffer.ReplayBuffer(replay_buffer_size, self.ppo_device) + + self.tensor_list += ['amp_obs'] + return + + def _init_amp_demo_buf(self): + buffer_size = self._amp_obs_demo_buffer.get_buffer_size() + num_batches = int(np.ceil(buffer_size / self._amp_batch_size)) + + for i in range(num_batches): + curr_samples = self._fetch_amp_obs_demo(self._amp_batch_size) + self._amp_obs_demo_buffer.store({'amp_obs': curr_samples}) + + return + + def _update_amp_demos(self): + new_amp_obs_demo = self._fetch_amp_obs_demo(self._amp_batch_size) + self._amp_obs_demo_buffer.store({'amp_obs': new_amp_obs_demo}) + return + + def _norm_disc_reward(self): + return self._disc_reward_mean_std is not None + + def _preproc_amp_obs(self, amp_obs): + if self._normalize_amp_input: + amp_obs = self._amp_input_mean_std(amp_obs) + return amp_obs + + def _combine_rewards(self, task_rewards, amp_rewards): + disc_r = amp_rewards['disc_rewards'] + + combined_rewards = self._task_reward_w * task_rewards + \ + + self._disc_reward_w * disc_r + return combined_rewards + + def _eval_disc(self, amp_obs): + proc_amp_obs = self._preproc_amp_obs(amp_obs) + return self.model.a2c_network.eval_disc(proc_amp_obs) + + def _calc_amp_rewards(self, amp_obs): + disc_r = self._calc_disc_rewards(amp_obs) + output = {'disc_rewards': disc_r} + return output + + def _calc_disc_rewards(self, amp_obs): + with torch.no_grad(): + disc_logits = self._eval_disc(amp_obs) + prob = 1 / (1 + torch.exp(-disc_logits)) + disc_r = -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.ppo_device))) + + if (self._norm_disc_reward()): + self._disc_reward_mean_std.train() + norm_disc_r = self._disc_reward_mean_std(disc_r.flatten()) + disc_r = norm_disc_r.reshape(disc_r.shape) + disc_r = 0.5 * disc_r + 0.25 + + disc_r *= self._disc_reward_scale + + return disc_r + + def _store_replay_amp_obs(self, amp_obs): + buf_size = self._amp_replay_buffer.get_buffer_size() + buf_total_count = self._amp_replay_buffer.get_total_count() + if (buf_total_count > buf_size): + keep_probs = to_torch(np.array([self._amp_replay_keep_prob] * amp_obs.shape[0]), device=self.ppo_device) + keep_mask = torch.bernoulli(keep_probs) == 1.0 + amp_obs = amp_obs[keep_mask] + + if (amp_obs.shape[0] > buf_size): + rand_idx = torch.randperm(amp_obs.shape[0]) + rand_idx = rand_idx[:buf_size] + amp_obs = amp_obs[rand_idx] + + self._amp_replay_buffer.store({'amp_obs': amp_obs}) + return + + def _record_train_batch_info(self, batch_dict, train_info): + super()._record_train_batch_info(batch_dict, train_info) + train_info['disc_rewards'] = batch_dict['disc_rewards'] + return + + def _assemble_train_info(self, train_info, frame): + train_info_dict = super()._assemble_train_info(train_info, frame) + + if "disc_loss" in train_info: + disc_reward_std, disc_reward_mean = torch.std_mean(train_info['disc_rewards']) + train_info_dict.update({ + "disc_loss": torch_ext.mean_list(train_info['disc_loss']).item(), + "disc_agent_acc": torch_ext.mean_list(train_info['disc_agent_acc']).item(), + "disc_demo_acc": torch_ext.mean_list(train_info['disc_demo_acc']).item(), + "disc_agent_logit": torch_ext.mean_list(train_info['disc_agent_logit']).item(), + "disc_demo_logit": torch_ext.mean_list(train_info['disc_demo_logit']).item(), + "disc_grad_penalty": torch_ext.mean_list(train_info['disc_grad_penalty']).item(), + "disc_logit_loss": torch_ext.mean_list(train_info['disc_logit_loss']).item(), + "disc_reward_mean": disc_reward_mean.item(), + "disc_reward_std": disc_reward_std.item(), + }) + + if "returns" in train_info: + train_info_dict['returns'] = train_info['returns'].mean().item() + + if "mb_rewards" in train_info: + train_info_dict['mb_rewards'] = train_info['mb_rewards'].mean().item() + + # if 'terminated_flags' in train_info: + # train_info_dict["success_rate"] = 1 - torch.mean((train_info['terminated_flags'] > 0).float()).item() + + if "reward_raw" in train_info: + for idx, v in enumerate(train_info['reward_raw'].cpu().numpy().tolist()): + train_info_dict[f"ind_reward.{idx}"] = v + + if "sym_loss" in train_info: + train_info_dict['sym_loss'] = torch_ext.mean_list(train_info['sym_loss']).item() + return train_info_dict + + def _amp_debug(self, info): + with torch.no_grad(): + amp_obs = info['amp_obs'] + amp_obs = amp_obs[0:1] + disc_pred = self._eval_disc(amp_obs) + amp_rewards = self._calc_amp_rewards(amp_obs) + disc_reward = amp_rewards['disc_rewards'] + + disc_pred = disc_pred.detach().cpu().numpy()[0, 0] + disc_reward = disc_reward.cpu().numpy()[0, 0] + # print("disc_pred: ", disc_pred, disc_reward) + return diff --git a/phc/learning/amp_datasets.py b/phc/learning/amp_datasets.py new file mode 100644 index 0000000..7c6ea43 --- /dev/null +++ b/phc/learning/amp_datasets.py @@ -0,0 +1,101 @@ +import torch +from rl_games.common import datasets + +class AMPDataset(datasets.PPODataset): + def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_len): + super().__init__(batch_size, minibatch_size, is_discrete, is_rnn, device, seq_len) + self._idx_buf = torch.randperm(self.batch_size) + + + + return + + def update_mu_sigma(self, mu, sigma): + raise NotImplementedError() + return + + # def _get_item_rnn(self, idx): + # gstart = idx * self.num_games_batch + # gend = (idx + 1) * self.num_games_batch + # start = gstart * self.seq_len + # end = gend * self.seq_len + # self.last_range = (start, end) + # input_dict = {} + # for k,v in self.values_dict.items(): + # if k not in self.special_names: + # if v is dict: + # v_dict = { kd:vd[start:end] for kd, vd in v.items() } + # input_dict[k] = v_dict + # else: + # input_dict[k] = v[start:end] + + # rnn_states = self.values_dict['rnn_states'] + # input_dict['rnn_states'] = [s[:,gstart:gend,:] for s in rnn_states] + # return input_dict + + def update_values_dict(self, values_dict, rnn_format = False, horizon_length = 1, num_envs = 1): + self.values_dict = values_dict + self.horizon_length = horizon_length + self.num_envs = num_envs + + if rnn_format and self.is_rnn: + for k,v in self.values_dict.items(): + if k not in self.special_names and v is not None: + self.values_dict[k] = self.values_dict[k].view(self.num_envs, self.horizon_length, -1).squeeze() # Actions are already swapped to the correct format. + if not self.values_dict['rnn_states'] is None: + self.values_dict['rnn_states'] = [s.reshape(self.num_envs, self.horizon_length, -1) for s in self.values_dict['rnn_states']] # rnn_states are not swapped in AMP, so do not swap it here. + self._idx_buf = torch.randperm(self.num_envs) # Update to only shuffle the envs. + + # def _get_item_rnn(self, idx): + # data = super()._get_item_rnn(idx) + # import ipdb; ipdb.set_trace() + # return data + + def _get_item_rnn(self, idx): + # ZL: I am doubling the get_item_rnn function to in a way also get the sequential data. Pretty hacky. + # BPTT, input dict is [batch, seqlen, features]. This function return the sequences that are from the same episide and enviornment in sequentila mannar. Not used at the moment since seq_len is set to 1 for RNN right now. + step_size = int(self.minibatch_size/self.horizon_length) + + start = idx * step_size + end = (idx + 1) * step_size + sample_idx = self._idx_buf[start:end] + + input_dict = {} + + for k,v in self.values_dict.items(): + if k not in self.special_names and v is not None: + input_dict[k] = v[sample_idx, :].view(step_size * self.horizon_length, -1).squeeze() # flatten to batch size + + input_dict['old_values'] = input_dict['old_values'][:, None] # ZL Hack: following compute assumes that the old_values is [batch, 1], so has to change this back. Otherwise, the loss will be wrong. + input_dict['returns'] = input_dict['returns'][:, None] # ZL Hack: following compute assumes that the old_values is [batch, 1], so has to change this back. Otherwise, the loss will be wrong. + + if not self.values_dict['rnn_states'] is None: + input_dict['rnn_states'] = [s[sample_idx, :].view(step_size * self.horizon_length, -1) for s in self.values_dict["rnn_states"]] + + if (end >= self.batch_size): + self._shuffle_idx_buf() + + + return input_dict + + def _get_item(self, idx): + start = idx * self.minibatch_size + end = (idx + 1) * self.minibatch_size + sample_idx = self._idx_buf[start:end] + + input_dict = {} + for k,v in self.values_dict.items(): + if k not in self.special_names and v is not None: + input_dict[k] = v[sample_idx] + + if (end >= self.batch_size): + self._shuffle_idx_buf() + + return input_dict + + def _shuffle_idx_buf(self): + if self.is_rnn: + self._idx_buf = torch.randperm(self.num_envs) + else: + self._idx_buf[:] = torch.randperm(self.batch_size) + return \ No newline at end of file diff --git a/phc/learning/amp_models.py b/phc/learning/amp_models.py new file mode 100644 index 0000000..c3df3fb --- /dev/null +++ b/phc/learning/amp_models.py @@ -0,0 +1,109 @@ +# This is the overall forward pass of the model. + +import torch.nn as nn +from rl_games.algos_torch.models import ModelA2CContinuousLogStd +import torch +class ModelAMPContinuous(ModelA2CContinuousLogStd): + def __init__(self, network): + super().__init__(network) + return + + def build(self, config): + net = self.network_builder.build('amp', **config) + for name, _ in net.named_parameters(): + print(name) + return ModelAMPContinuous.Network(net) + + class Network(ModelA2CContinuousLogStd.Network): + def __init__(self, a2c_network): + super().__init__(a2c_network) + + return + + def forward(self, input_dict): + is_train = input_dict.get('is_train', True) + amp_steps = input_dict.get("amp_steps", 2) + + + result = super().forward(input_dict) + + if (is_train): + amp_obs, amp_obs_replay, amp_demo_obs = input_dict['amp_obs'], input_dict['amp_obs_replay'], input_dict['amp_obs_demo'] + + disc_agent_logit = self.a2c_network.eval_disc(amp_obs) + result["disc_agent_logit"] = disc_agent_logit + + disc_agent_replay_logit = self.a2c_network.eval_disc(amp_obs_replay) + result["disc_agent_replay_logit"] = disc_agent_replay_logit + + disc_demo_logit = self.a2c_network.eval_disc(amp_demo_obs) + result["disc_demo_logit"] = disc_demo_logit + + # # HACK.... + # if input_dict.get("compute_direct_logit", False): + # from phc.utils.torch_utils import project_to_norm + # import ipdb; ipdb.set_trace() + # mus = project_to_norm(result['mus'], input_dict.get("embedding_norm", 1.0)) + # mus = mus.view(-1, 32, 64) + # mus = mus.reshape(-1, 2048) + # result['disc_direct_logit'] = self.a2c_network.eval_disc(mus) + + + # amp_obs.requires_grad_(True) + # disc_agent_logit = self.a2c_network.eval_disc(amp_obs) + # import ipdb; ipdb.set_trace() + # torch.autograd.grad(disc_agent_logit, amp_obs, grad_outputs=torch.ones_like(disc_agent_logit), create_graph=False, retain_graph=True, only_inputs=True) + # torch.autograd.grad(disc_agent_replay_logit, amp_obs_replay, grad_outputs=torch.ones_like(disc_agent_replay_logit), create_graph=False, retain_graph=True, only_inputs=True) + # torch.autograd.grad(disc_demo_logit, amp_demo_obs, grad_outputs=torch.ones_like(disc_demo_logit), create_graph=False, retain_graph=True, only_inputs=True) + # (1 / (1 + torch.exp(-disc_demo_logit)))[:50] + + return result + + def dropout_amp_obs(self, amp_obs, dropout_mask): + return amp_obs * dropout_mask + + def get_dropout_mask(self, + amp_obs, + steps, + num_masks=3, + dropout_rate=0.3): + # ZL Hack: amp_obs_dims, should drop out whole joints + # [root_rot 6, root_vel 3, root_ang_vel 3, dof_pos 23 * 6 - 4 * 6, dof_vel 69 - 12, key_body_pos 3 * 4, shape_obs_disc 11] + # [root_rot 6, root_vel 3, root_ang_vel 3, dof_pos 23 * 6 - 4 * 6, dof_vel 69 - 12, key_body_pos 3 * 4, shape_obs_disc 47] + # 6 + 3 + 3 + 19 * 6 + 19 * 3 + 3 * 4 + 11 = 206 + # 6 + 3 + 3 + 19 * 6 + 19 * 3 + 3 * 4 = 195 # mean body + # 6 + 3 + 3 + 19 * 6 + 19 * 3 + 3 * 4 = 196 # mean body + height + # 1 + 6 + 3 + 3 + 19 * 6 + 19 * 3 + 3 * 4 + 11 = 207 # shape body + height + # 6 + 3 + 3 + 19 * 6 + 19 * 3 + 3 * 4 + 10 = 205 # concise limb weight + # 6 + 3 + 3 + 19 * 6 + 19 * 3 + 3 * 4 + 47 = 242 # full limb weight + # 6 + 3 + 3 + 19 * 6 + 19 * 3 + 3 * 4 + 59 = 254 - masterfoot + B, F = amp_obs.shape + B, _, amp_f = amp_obs.view(B, steps, -1).shape + try: + assert (F / steps == 205 or F / steps == 254 or F / steps == 242 or F / steps == 206 or F / steps == 197 or F / steps == 188 or F / steps == 195 or F / steps == 196 or F / steps == 207) + except: + print(F/steps) + import ipdb; ipdb.set_trace() + print(F/steps) + + dof_joints_offset = 12 # 6 + 3 + 3 + num_joints = 19 + + if F / steps == 197: # Remove neck + num_joints = 18 + elif F / steps == 188: # Remove hands + num_joints = 17 + elif F / steps == 196 or F / steps == 207: + dof_joints_offset = 13 # 1 + 6 + 3 + 3 + + dof_vel_offsets = dof_joints_offset + num_joints * 6 # 12 + 19 * 6 + + dropout_mask = torch.ones([B, amp_f, num_masks]) + + for idx_joint in range(num_joints): + has_drop_out = torch.rand(B, num_masks) > dropout_rate + dropout_mask[:, dof_joints_offset + idx_joint * 6 : dof_joints_offset + idx_joint * 6 + 6, :] = has_drop_out[:, None] + dropout_mask[:, dof_vel_offsets + idx_joint * 3 : dof_vel_offsets + idx_joint * 3 + 3, :] = has_drop_out[:, None] + return dropout_mask.repeat(1, steps, 1).to(amp_obs) + + diff --git a/phc/learning/amp_network_builder.py b/phc/learning/amp_network_builder.py new file mode 100644 index 0000000..08ba5d6 --- /dev/null +++ b/phc/learning/amp_network_builder.py @@ -0,0 +1,253 @@ +from rl_games.algos_torch import torch_ext +from rl_games.algos_torch import layers +import phc.learning.network_builder as network_builder +import torch +import torch.nn as nn +import numpy as np + +DISC_LOGIT_INIT_SCALE = 1.0 + + +class AMPBuilder(network_builder.A2CBuilder): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + return + + class Network(network_builder.A2CBuilder.Network): + + def __init__(self, params, **kwargs): + super().__init__(params, **kwargs) + + if self.is_continuous: + if (not self.space_config['learn_sigma']): + actions_num = kwargs.get('actions_num') + sigma_init = self.init_factory.create(**self.space_config['sigma_init']) + self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=False, dtype=torch.float32), requires_grad=False) + sigma_init(self.sigma) + + amp_input_shape = kwargs.get('amp_input_shape') + self._build_disc(amp_input_shape) + + return + + def load(self, params): + super().load(params) + + self._disc_units = params['disc']['units'] + self._disc_activation = params['disc']['activation'] + self._disc_initializer = params['disc']['initializer'] + return + + def forward(self, obs_dict): + states = obs_dict.get('rnn_states', None) + + actor_outputs = self.eval_actor(obs_dict) + value_outputs = self.eval_critic(obs_dict) + + if self.has_rnn: + mu, sigma, a_states = actor_outputs + value, c_states = value_outputs + states = a_states + c_states + output = mu, sigma, value, states + else: + output = actor_outputs + (value_outputs, states) + + return output + + def eval_actor(self, obs_dict): + # RNN is built with Batch-first enabled. + obs = obs_dict['obs'] + states = obs_dict.get('rnn_states', None) + seq_length = obs_dict.get('seq_length', 1) + a_out = self.actor_cnn(obs) + a_out = a_out.contiguous().view(-1, a_out.size(-1)) + + if self.has_rnn: + if not self.is_rnn_before_mlp: + a_out_in = a_out + a_out = self.actor_mlp(a_out_in) + + if self.rnn_concat_input: + a_out = torch.cat([a_out, a_out_in], dim=1) + + batch_size = a_out.size()[0] + num_seqs = batch_size // seq_length + a_out = a_out.reshape(num_seqs, seq_length, -1) + + if self.rnn_name == 'sru': + a_out = a_out.transpose(0, 1) + + ################# New RNN + if len(states) == 2: + a_states = states[0].reshape(num_seqs, seq_length, -1) + else: + a_states = states[:2].reshape(num_seqs, seq_length, -1) + a_out, a_states = self.a_rnn(a_out, a_states[:, 0:1].transpose(0, 1).contiguous()) + + ################ Old RNN + # if len(states) == 2: + # a_states = states[0] + # else: + # a_states = states[:2] + # a_out, a_states = self.a_rnn(a_out, a_states) + + if self.rnn_name == 'sru': + a_out = a_out.transpose(0, 1) + else: + if self.rnn_ln: + a_out = self.a_layer_norm(a_out) + + a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1) + + if type(a_states) is not tuple: + a_states = (a_states,) + + if self.is_rnn_before_mlp: + a_out = self.actor_mlp(a_out) + + if self.is_discrete: + logits = self.logits(a_out) + return logits, a_states + + if self.is_multi_discrete: + logits = [logit(a_out) for logit in self.logits] + return logits, a_states + + if self.is_continuous: + mu = self.mu_act(self.mu(a_out)) + if self.space_config['fixed_sigma']: + sigma = mu * 0.0 + self.sigma_act(self.sigma) + else: + sigma = self.sigma_act(self.sigma(a_out)) + + return mu, sigma, a_states + + else: + a_out = self.actor_mlp(a_out) + + # mlp_out = self.actor_mlp(a_out[:1]) + # (self.actor_mlp(a_out[:5])[0] - self.actor_mlp(a_out[:2])[0]).abs() + + if self.is_discrete: + logits = self.logits(a_out) + return logits, + + if self.is_multi_discrete: + logits = [logit(a_out) for logit in self.logits] + return logits, + + if self.is_continuous: + + mu = self.mu_act(self.mu(a_out)) + if self.space_config['fixed_sigma']: + sigma = mu * 0.0 + self.sigma_act(self.sigma) + else: + sigma = self.sigma_act(self.sigma(a_out)) + + return mu, sigma + # return torch.round(mu, decimals=3), sigma + + return + + def get_actor_paramters(self): + return list(self.actor_mlp.parameters()) + list(self.actor_cnn.parameters()) + list(self.mu.parameters()) + + def eval_critic(self, obs_dict): + obs = obs_dict['obs'] + c_out = self.critic_cnn(obs) + c_out = c_out.contiguous().view(-1, c_out.size(-1)) + seq_length = obs_dict.get('seq_length', 1) + states = obs_dict.get('rnn_states', None) + + if self.has_rnn: + if not self.is_rnn_before_mlp: + c_out_in = c_out + c_out = self.critic_mlp(c_out_in) + + if self.rnn_concat_input: + c_out = torch.cat([c_out, c_out_in], dim=1) + + batch_size = c_out.size()[0] + num_seqs = batch_size // seq_length + c_out = c_out.reshape(num_seqs, seq_length, -1) + + if self.rnn_name == 'sru': + c_out = c_out.transpose(0, 1) + ################# New RNN + if len(states) == 2: + c_states = states[1].reshape(num_seqs, seq_length, -1) + else: + c_states = states[2:].reshape(num_seqs, seq_length, -1) + c_out, c_states = self.c_rnn(c_out, c_states[:, 0:1].transpose(0, 1).contiguous()) # ZL: only pass the first state, others are ignored. ??? + + ################# Old RNN + # if len(states) == 2: + # c_states = states[1] + # else: + # c_states = states[2:] + # c_out, c_states = self.c_rnn(c_out, c_states) + + + if self.rnn_name == 'sru': + c_out = c_out.transpose(0, 1) + else: + if self.rnn_ln: + c_out = self.c_layer_norm(c_out) + c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1) + + if type(c_states) is not tuple: + c_states = (c_states,) + + if self.is_rnn_before_mlp: + c_out = self.critic_mlp(c_out) + value = self.value_act(self.value(c_out)) + return value, c_states + + else: + c_out = self.critic_mlp(c_out) + + value = self.value_act(self.value(c_out)) + return value + + def eval_disc(self, amp_obs): + disc_mlp_out = self._disc_mlp(amp_obs) + disc_logits = self._disc_logits(disc_mlp_out) + return disc_logits + + def get_disc_logit_weights(self): + return torch.flatten(self._disc_logits.weight) + + def get_disc_weights(self): + weights = [] + for m in self._disc_mlp.modules(): + if isinstance(m, nn.Linear): + weights.append(torch.flatten(m.weight)) + + weights.append(torch.flatten(self._disc_logits.weight)) + return weights + + def _build_disc(self, input_shape): + self._disc_mlp = nn.Sequential() + + mlp_args = {'input_size': input_shape[0], 'units': self._disc_units, 'activation': self._disc_activation, 'dense_func': torch.nn.Linear} + self._disc_mlp = self._build_mlp(**mlp_args) + + mlp_out_size = self._disc_units[-1] + self._disc_logits = torch.nn.Linear(mlp_out_size, 1) + + mlp_init = self.init_factory.create(**self._disc_initializer) + for m in self._disc_mlp.modules(): + if isinstance(m, nn.Linear): + mlp_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + + torch.nn.init.uniform_(self._disc_logits.weight, -DISC_LOGIT_INIT_SCALE, DISC_LOGIT_INIT_SCALE) + torch.nn.init.zeros_(self._disc_logits.bias) + + return + + def build(self, name, **kwargs): + net = AMPBuilder.Network(self.params, **kwargs) + return net \ No newline at end of file diff --git a/phc/learning/amp_network_mcp_builder.py b/phc/learning/amp_network_mcp_builder.py new file mode 100644 index 0000000..97bc88b --- /dev/null +++ b/phc/learning/amp_network_mcp_builder.py @@ -0,0 +1,87 @@ + +from rl_games.algos_torch import torch_ext +from rl_games.algos_torch import layers +from learning.amp_network_builder import AMPBuilder +import torch +import torch.nn as nn +import numpy as np +import copy + +DISC_LOGIT_INIT_SCALE = 1.0 + + +class AMPMCPBuilder(AMPBuilder): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + return + + def build(self, name, **kwargs): + net = AMPMCPBuilder.Network(self.params, **kwargs) + return net + + class Network(AMPBuilder.Network): + + def __init__(self, params, **kwargs): + self.self_obs_size = kwargs['self_obs_size'] + self.task_obs_size = kwargs['task_obs_size'] + self.task_obs_size_detail = kwargs['task_obs_size_detail'] + self.fut_tracks = self.task_obs_size_detail['fut_tracks'] + self.obs_v = self.task_obs_size_detail['obs_v'] + self.num_traj_samples = self.task_obs_size_detail['num_traj_samples'] + self.track_bodies = self.task_obs_size_detail['track_bodies'] + self.has_softmax = params.get("has_softmax", True) + + kwargs['input_shape'] = (self.self_obs_size + self.task_obs_size,) # + + super().__init__(params, **kwargs) + + self.num_primitive = self.task_obs_size_detail.get("num_prim", 4) + + composer_mlp_args = { + 'input_size': self._calc_input_size((self.self_obs_size + self.task_obs_size,), self.actor_cnn), + 'units': self.units + [self.num_primitive], + 'activation': self.activation, + 'norm_func_name': self.normalization, + 'dense_func': torch.nn.Linear, + 'd2rl': self.is_d2rl, + 'norm_only_first_layer': self.norm_only_first_layer + } + + self.composer = self._build_mlp(**composer_mlp_args) + + if self.has_softmax: + print("!!!Has softmax!!!") + self.composer.append(nn.Softmax(dim=1)) + + self.running_mean = kwargs['mean_std'].running_mean + self.running_var = kwargs['mean_std'].running_var + + def load(self, params): + super().load(params) + return + + def eval_actor(self, obs_dict): + obs = obs_dict['obs'] + a_out = self.actor_cnn(obs) # This is empty + a_out = a_out.contiguous().view(a_out.size(0), -1) + + a_out = self.composer(a_out) + + if self.is_discrete: + logits = self.logits(a_out) + return logits + + if self.is_multi_discrete: + logits = [logit(a_out) for logit in self.logits] + return logits + + if self.is_continuous: + # mu = self.mu_act(self.mu(a_out)) + mu = a_out + if self.space_config['fixed_sigma']: + sigma = mu * 0.0 + self.sigma_act(self.sigma) + else: + sigma = self.sigma_act(self.sigma(a_out)) + return mu, sigma + return diff --git a/phc/learning/amp_network_pnn_builder.py b/phc/learning/amp_network_pnn_builder.py new file mode 100644 index 0000000..9851eae --- /dev/null +++ b/phc/learning/amp_network_pnn_builder.py @@ -0,0 +1,89 @@ + +from rl_games.algos_torch import torch_ext +from rl_games.algos_torch import layers +from learning.amp_network_builder import AMPBuilder +import torch +import torch.nn as nn +import numpy as np +import copy +from phc.learning.pnn import PNN +from rl_games.algos_torch import torch_ext + +DISC_LOGIT_INIT_SCALE = 1.0 + + +class AMPPNNBuilder(AMPBuilder): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + return + + def build(self, name, **kwargs): + net = AMPPNNBuilder.Network(self.params, **kwargs) + return net + + class Network(AMPBuilder.Network): + + def __init__(self, params, **kwargs): + self.self_obs_size = kwargs['self_obs_size'] + self.task_obs_size = kwargs['task_obs_size'] + self.task_obs_size_detail = kwargs['task_obs_size_detail'] + self.fut_tracks = self.task_obs_size_detail['fut_tracks'] + self.obs_v = self.task_obs_size_detail['obs_v'] + self.num_traj_samples = self.task_obs_size_detail['num_traj_samples'] + self.track_bodies = self.task_obs_size_detail['track_bodies'] + self.num_prim = self.task_obs_size_detail['num_prim'] + self.training_prim = self.task_obs_size_detail['training_prim'] + self.model_base = self.task_obs_size_detail['models_path'][0] + self.actors_to_load = self.task_obs_size_detail['actors_to_load'] + self.has_lateral = self.task_obs_size_detail['has_lateral'] + + kwargs['input_shape'] = (self.self_obs_size + self.task_obs_size,) # + + super().__init__(params, **kwargs) + actor_mlp_args = { + 'input_size': self._calc_input_size((self.self_obs_size + self.task_obs_size,), self.actor_cnn), + 'units': self.units, + 'activation': self.activation, + 'norm_func_name': self.normalization, + 'dense_func': torch.nn.Linear, + } + + del self.actor_mlp + self.discrete = params.get("discrete", False) + + self.pnn = PNN(actor_mlp_args, output_size=kwargs['actions_num'], numCols=self.num_prim, has_lateral=self.has_lateral) + # self.pnn.load_base_net(self.model_base, self.actors_to_load) + self.pnn.freeze_pnn(self.training_prim) + + self.running_mean = kwargs['mean_std'].running_mean + self.running_var = kwargs['mean_std'].running_var + + def eval_actor(self, obs_dict): + obs = obs_dict['obs'] + + a_out = self.actor_cnn(obs) # This is empty + a_out = a_out.contiguous().view(a_out.size(0), -1) + a_out, a_outs = self.pnn(a_out, idx=self.training_prim) + + # a_out = a_outs[0] + # print("debugging") # Dubgging!!! + + if self.is_discrete: + logits = self.logits(a_out) + return logits + + if self.is_multi_discrete: + logits = [logit(a_out) for logit in self.logits] + return logits + + if self.is_continuous: + # mu = self.mu_act(self.mu(a_out)) + mu = a_out + if self.space_config['fixed_sigma']: + sigma = mu * 0.0 + self.sigma_act(self.sigma) + else: + sigma = self.sigma_act(self.sigma(a_out)) + + return mu, sigma + return diff --git a/phc/learning/amp_network_z_builder.py b/phc/learning/amp_network_z_builder.py new file mode 100644 index 0000000..2bff154 --- /dev/null +++ b/phc/learning/amp_network_z_builder.py @@ -0,0 +1,598 @@ +from rl_games.algos_torch import torch_ext +from rl_games.algos_torch import layers +from learning.amp_network_builder import AMPBuilder +from phc.learning.network_builder import init_mlp +import torch +import torch.nn as nn +import numpy as np +from phc.utils.torch_utils import project_to_norm +from phc.learning.vq_quantizer import EMAVectorQuantizer, Quantizer +from phc.utils.flags import flags +DISC_LOGIT_INIT_SCALE = 1.0 + + +class AMPZBuilder(AMPBuilder): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + return + + def build(self, name, **kwargs): + net = AMPZBuilder.Network(self.params, **kwargs) + return net + + class Network(AMPBuilder.Network): + + def __init__(self, params, **kwargs): + self.self_obs_size = kwargs['self_obs_size'] + self.task_obs_size = kwargs['task_obs_size'] + self.task_obs_size_detail = kwargs['task_obs_size_detail'] + + self.proj_norm = self.task_obs_size_detail["proj_norm"] + self.embedding_size = self.task_obs_size_detail['embedding_size'] + self.embedding_norm = self.task_obs_size_detail['embedding_norm'] + self.z_readout = self.task_obs_size_detail.get("z_readout", False) + self.z_type = self.task_obs_size_detail.get("z_type", "sphere") + self.dict_size = self.task_obs_size_detail.get("dict_size", 512) + self.z_all = self.task_obs_size_detail.get("z_all", False) + self.embedding_partion = self.task_obs_size_detail.get("embedding_partion", 1) + + self.use_vae_prior = self.task_obs_size_detail.get("use_vae_prior", False) + self.use_vae_fixed_prior = self.task_obs_size_detail.get("use_vae_fixed_prior", False) + self.use_vae_clamped_prior = self.task_obs_size_detail.get("use_vae_clamped_prior", False) + self.use_vae_sphere_prior = self.task_obs_size_detail.get("use_vae_sphere_prior", False) + self.use_vae_sphere_posterior = self.task_obs_size_detail.get("use_vae_sphere_posterior", False) + self.vae_prior_fixed_logvar = self.task_obs_size_detail.get("vae_prior_fixed_logvar", 0) + self.vae_var_clamp_max = self.task_obs_size_detail.get("vae_var_clamp_max", 0) + + ##### Debug utils + flags.idx = 0 + self.debug_idxes = [0] * self.embedding_partion + + if self.z_all: + kwargs['input_shape'] = (self.embedding_size,) # Task embedding size + self_obs + else: + kwargs['input_shape'] = (kwargs['self_obs_size'] + self.embedding_size,) # Task embedding size + self_obs + + + super().__init__(params, **kwargs) + self.running_mean = kwargs['mean_std'].running_mean + self.running_var = kwargs['mean_std'].running_var + + self._build_z_mlp() + if self.z_readout: + self._build_z_reader() + if self.separate: + self._build_critic_z_mlp() + + + self.actor_mlp + + def load(self, params): + super().load(params) + self._task_units = params['task_mlp']['units'] + + self._task_activation = params['task_mlp']['activation'] + self._task_initializer = params['task_mlp']['initializer'] + return + + def form_embedding(self, task_out_z, obs_dict = None): + extra_dict = {} + B, N = task_out_z.shape + if self.z_type == 'vae': + self.vae_mu = vae_mu = self.z_mu(task_out_z) + self.vae_log_var = vae_log_var = self.z_logvar(task_out_z) + + if self.use_vae_clamped_prior: + self.vae_log_var = vae_log_var = torch.clamp(vae_log_var, min = -5, max = self.vae_var_clamp_max) + + if "z_noise" in obs_dict and self.training: # bypass reparatzation and use the noise sampled during training. + task_out_proj = vae_mu + torch.exp(0.5*vae_log_var) * obs_dict['z_noise'] + else: + task_out_proj, self.z_noise = self.reparameterize(vae_mu, vae_log_var) + + if flags.test: + task_out_proj = vae_mu + + if flags.trigger_input: + flags.trigger_input = False + flags.debug = not flags.debug + + if flags.debug: + if self.use_vae_prior or self.use_vae_fixed_prior: + prior_mu, prior_logvar = self.compute_prior(obs_dict) + # if flags.trigger_input: + # ### Trigger input + # task_out_proj[:], noise = self.reparameterize(prior_mu, prior_logvar) ; print("\n debugging", end='') + # flags.trigger_input = False + # else: + # task_out_proj[:] = prior_mu + # task_out_proj[:], noise = self.reparameterize(prior_mu, torch.ones_like(prior_logvar) * -2.3 ) ; print("\r debugging with prior using -2.3 std.", end='') + # task_out_proj[:], noise = self.reparameterize(prior_mu, torch.ones_like(prior_logvar) * -1.5 ) ; print("\r debugging with prior using -1.5 std.", end='') + task_out_proj[:], noise = self.reparameterize(prior_mu, prior_logvar ) ; print(f"\r prior_mu {prior_mu.abs().max():.3f} {prior_logvar.exp().max():.3f}", end='') + # task_out_proj[:] = torch.randn_like(vae_mu) ; print("\r debugging randn", end='') + enhance = 0 + else: + task_out_proj[:] = torch.randn_like(vae_mu) ; print("\r debugging", end='') + + if self.use_vae_sphere_posterior: + task_out_proj = project_to_norm(task_out_proj, norm=self.embedding_norm, z_type="sphere") + + extra_dict = {"vae_mu": vae_mu, "vae_log_var": vae_log_var, "noise": self.z_noise} + + + # prior_mu, prior_logvar = self.compute_prior(obs_dict) + # print(prior_logvar.exp().max()) + # np.set_printoptions(precision=4, suppress=1) + # if "prev_task_out_proj" in self.__dict__: + # diff = self.prev_task_out_proj.cpu().numpy() - task_out_proj.cpu().numpy() + # print(f"{np.abs(diff).max():.4f}", diff) + # if np.abs(diff).max() > 0.5: + # import ipdb; ipdb.set_trace() + # print('...') + # self.prev_task_out_proj = task_out_proj + + # prior_mu, prior_logvar = self.compute_prior(obs_dict) + # import ipdb; ipdb.set_trace() + # print(prior_mu.abs().argmax(), prior_mu.abs().max(), task_out_proj.abs().argmax(), task_out_proj.abs().max()) + + # print(task_out_proj.abs().max(), task_out_proj.abs().argmax(), task_out_proj.cpu().numpy()) + # torch.exp(prior_logvar * 0.5), torch.exp(0.5 * vae_log_var) + # print(torch.exp(prior_logvar * 0.5).mean(), torch.exp(0.5 * vae_log_var).mean()) + # print(task_out_proj.abs().max(), (prior_mu - task_out_proj).abs().cpu().numpy().max(), (prior_mu - task_out_proj).cpu().numpy()) + # import ipdb; ipdb.set_trace() + + elif self.z_type == 'vq_vae': + z_before_quant = task_out_z + # loss, task_out_proj, indexes = self.quantizer(project_to_norm(z_before_quant, norm=self.embedding_norm, z_type="sphere")) + + loss, task_out_proj, indexes = self.quantizer(z_before_quant.view(B, -1, self.embedding_size//self.embedding_partion)) + task_out_proj = task_out_proj.view(B, self.embedding_size) + + # if flags.trigger_input: + # flags.trigger_input = False + # flags.debug = not flags.debug + # enhance = 0.5 + + if flags.debug: + if flags.trigger_input: + indexes_input = input("Enter word indexes:") + try: + self.debug_idxes = [int(i) for i in indexes_input.split()] + except: + import ipdb; ipdb.set_trace() + pass + flags.trigger_input = False + # import ipdb; ipdb.set_trace() + # self.debug_idxes = self.embedding_size//self.embedding_partion, self.embedding_partion + indexes = torch.tensor(self.debug_idxes) + embedding = self.quantizer.embedding.weight.data + fixed_task_out_proj = torch.cat([embedding[self.debug_idxes[idx]] for idx in range(len(self.debug_idxes))])[None, ]; print(" debugging", end='') + + if self.z_all: ## pass thorugh + fixed_task_out_proj = torch.cat([fixed_task_out_proj[:, :int(self.embedding_size * 3/4 )], task_out_proj[:, int(self.embedding_size * 3/4):]], dim=-1) + + task_out_proj = fixed_task_out_proj + + if flags.test: + # print(f'\r {indexes[:self.embedding_partion].numpy()[12:16]} ') + # print(f'\r { "".join([str(i) for i in indexes[:int(self.embedding_partion * 3/4)].numpy()]) } { "".join([str(i) for i in indexes[int(self.embedding_partion * 3/4):].numpy()]) } ') + print(f'\r {indexes[:self.embedding_partion].numpy()} {self.quantizer.embedding.weight.norm(dim = -1 ).data.numpy()} ') + # print(f'\r {indexes[:self.embedding_partion].numpy()} {indexes.unique().numpy()} {self.quantizer.embedding.weight.norm(dim = -1 ).data.numpy()} ') + + else: + if flags.trigger_input: + import ipdb; ipdb.set_trace() + flags.trigger_input = False + print('...') + + extra_dict = {"loss": loss, "indexes": indexes, "z_before_quant": z_before_quant, "quantized_z_out": task_out_proj} + elif self.z_type == 'vq_vae_hybrid': + z_before_quant = self.z_quant(task_out_z) + z_var = self.z_var(task_out_z) + loss, task_out_proj, indexes= self.quantizer(z_before_quant) + z_var = project_to_norm(z_var, norm=0.1, z_type="uniform") + # loss += torch.norm(z_var, dim = -1).mean() + + # task_out_proj = self.quantizer.embedding.weight.data[flags.idx % self.dict_size][None, ]; z_var[:] = 0; print(" debugging", end='') + # print(z_var) + # print(f'\r {indexes[:3].numpy()} {indexes.unique().numpy()} {self.quantizer.embedding.weight.norm(dim = -1 ).data.numpy()} ', end='') + + task_out_proj = torch.cat([task_out_proj, z_var], dim=-1) + extra_dict = {"loss": loss, "indexes": indexes, "z_before_quant": z_before_quant, "quantized_z_out": task_out_proj} + + elif self.z_type == 'vq_vae_res': + + z_before_quant = self.z_quant(task_out_z) + z_var = self.z_var(task_out_z) + + loss, task_out_proj, indexes = self.quantizer(project_to_norm(z_before_quant, norm=self.embedding_norm, z_type="sphere")) + task_out_proj = project_to_norm(task_out_proj, norm= self.embedding_norm, z_type = "sphere") + z_var = torch.sin(z_var) + 1 # bias the number towards 1 + # loss += torch.norm(z_var , dim = -1).mean() + # task_out_proj = self.quantizer.embedding.weight.data[flags.idx % self.dict_size][None, ]; z_var[:] = 1; print(" debugging", end='') + + task_out_proj = task_out_proj * z_var + print(f'\r {indexes[:3].numpy()} {indexes.unique().numpy()} {self.quantizer.embedding.weight.norm(dim = -1 ).data.numpy()} ', end='') + + extra_dict = {"loss": loss, "indexes": indexes, "z_before_quant": z_before_quant, "quantized_z_out": task_out_proj} + elif self.z_type == "sphere": + task_out_proj = project_to_norm(task_out_z, norm=self.embedding_norm, z_type=self.z_type) + + # print(task_out_proj.max(), task_out_proj.min()) + return task_out_proj, extra_dict + + + def compute_prior(self, obs_dict): + obs = obs_dict['obs'] + self_obs = obs[:, :self.self_obs_size] + + prior_latent = self.z_prior(self_obs) + prior_mu = self.z_prior_mu(prior_latent) + if self.use_vae_prior: + prior_logvar = self.z_prior_logvar(prior_latent) + if self.use_vae_clamped_prior: + prior_logvar = torch.clamp(prior_logvar, min = -5, max = self.vae_var_clamp_max) + return prior_mu, prior_logvar + elif self.use_vae_fixed_prior: + if self.use_vae_sphere_prior: + return project_to_norm(prior_mu, z_type="sphere", norm = self.embedding_norm), torch.ones_like(prior_mu) * self.vae_prior_fixed_logvar + else: + return prior_mu, torch.ones_like(prior_mu ) * self.vae_prior_fixed_logvar + + + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5*logvar) + eps = torch.randn_like(std) + return mu + eps * std, eps + + def eval_z(self, obs_dict): + obs = obs_dict['obs'] + + a_out = self.actor_cnn(obs) # This is empty + a_out = a_out.contiguous().view(a_out.size(0), -1) + + z_out = self.z_mlp(obs) + if self.proj_norm: + z_out, extra_dict = self.form_embedding(z_out, obs_dict) + return z_out + + def read_z(self, z): + z_readout = self.z_reader_mlp(z) + return z_readout + + def eval_critic(self, obs_dict): + + obs = obs_dict['obs'] + c_out = self.critic_cnn(obs) + c_out = c_out.contiguous().view(c_out.size(0), -1) + seq_length = obs_dict.get('seq_length', 1) + states = obs_dict.get('rnn_states', None) + + self_obs = obs[:, :self.self_obs_size] + assert (obs.shape[-1] == self.self_obs_size + self.task_obs_size) + #### ZL: add CNN here + + if self.has_rnn: + c_out_in = c_out + c_out = self.critic_z_mlp(c_out_in) + + if self.rnn_concat_input: + c_out = torch.cat([c_out, c_out_in], dim=1) + + batch_size = c_out.size()[0] + num_seqs = batch_size // seq_length + c_out = c_out.reshape(num_seqs, seq_length, -1) + + if self.rnn_name == 'sru': + c_out = c_out.transpose(0, 1) + ################# New RNN + if len(states) == 2: + c_states = states[1].reshape(num_seqs, seq_length, -1) + else: + c_states = states[2:].reshape(num_seqs, seq_length, -1) + c_out, c_states = self.c_rnn(c_out, c_states[:, 0:1].transpose(0, 1).contiguous()) # ZL: only pass the first state, others are ignored. ??? + + ################# Old RNN + # if len(states) == 2: + # c_states = states[1] + # else: + # c_states = states[2:] + # c_out, c_states = self.c_rnn(c_out, c_states) + + + if self.rnn_name == 'sru': + c_out = c_out.transpose(0, 1) + else: + if self.rnn_ln: + c_out = self.c_layer_norm(c_out) + c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1) + + if type(c_states) is not tuple: + c_states = (c_states,) + + c_out = self.critic_z_proj_linear(c_out) + # c_out, extra_dict = self.form_embedding(c_out) # do not form VAE embedding for cirtic. + if self.z_type == "sphere": + c_out = project_to_norm(c_out, norm=self.embedding_norm, z_type=self.z_type) + + c_out = torch.cat([self_obs, c_out], dim=-1) + + c_out = self.critic_mlp(c_out) + value = self.value_act(self.value(c_out)) + return value, c_states + + else: + task_out = self.critic_z_mlp(obs) + + # c_out, extra_dict = self.form_embedding(c_out) # do not form VAE embedding for cirtic. + if self.z_type == "sphere": # but we do project for z sphere.... + task_out = project_to_norm(task_out, norm=self.embedding_norm, z_type=self.z_type) + + if self.z_all: + c_input = task_out + else: + c_input = torch.cat([self_obs, task_out], dim=-1) + c_out = self.critic_mlp(c_input) + value = self.value_act(self.value(c_out)) + return value + + def eval_actor(self, obs_dict, return_extra = False): + obs = obs_dict['obs'] + states = obs_dict.get('rnn_states', None) + seq_length = obs_dict.get('seq_length', 1) + + a_out = self.actor_cnn(obs) # This is empty + a_out = a_out.contiguous().view(a_out.size(0), -1) + + self_obs = obs[:, :self.self_obs_size] + task_obs = obs[:, self.self_obs_size:] + assert (obs.shape[-1] == self.self_obs_size + self.task_obs_size) + + if self.has_rnn: + + a_out_in = a_out + + a_out = self.z_mlp(obs) + + if self.rnn_concat_input: + a_out = torch.cat([a_out, a_out_in], dim=1) + + batch_size = a_out.size()[0] + num_seqs = batch_size // seq_length + a_out = a_out.reshape(num_seqs, seq_length, -1) + + if self.rnn_name == 'sru': + a_out = a_out.transpose(0, 1) + + ################# New RNN + if len(states) == 2: + a_states = states[0].reshape(num_seqs, seq_length, -1) + else: + a_states = states[:2].reshape(num_seqs, seq_length, -1) + a_out, a_states = self.a_rnn(a_out, a_states[:, 0:1].transpose(0, 1).contiguous()) + + ################ Old RNN + # if len(states) == 2: + # a_states = states[0] + # else: + # a_states = states[:2] + # a_out, a_states = self.a_rnn(a_out, a_states) + + + if self.rnn_name == 'sru': + a_out = a_out.transpose(0, 1) + else: + if self.rnn_ln: + a_out = self.a_layer_norm(a_out) + + a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1) + + z_out = self.z_proj_linear(a_out) + + if self.proj_norm: + z_out, extra_dict = self.form_embedding(z_out, obs_dict) + + if type(a_states) is not tuple: + a_states = (a_states,) + + actor_input = torch.cat([self_obs, z_out], dim=-1) + a_out = self.actor_mlp(actor_input) + + if self.is_discrete: + logits = self.logits(a_out) + return logits, a_states + + if self.is_multi_discrete: + logits = [logit(a_out) for logit in self.logits] + return logits, a_states + + if self.is_continuous: + mu = self.mu_act(self.mu(a_out)) + if self.space_config['fixed_sigma']: + sigma = mu * 0.0 + self.sigma_act(self.sigma) + else: + sigma = self.sigma_act(self.sigma(a_out)) + + if return_extra: + return mu, sigma, a_states, extra_dict + else: + return mu, sigma, a_states + else: + # if self.z_all: + # task_out_z = self.z_mlp(task_obs) + # self_out_z = self.z_self_mlp(self_obs) + # # self_out_z[:] = 0 + # task_out_z = torch.cat([task_out_z, self_out_z], dim=-1) + # else: + # task_out_z = self.z_mlp(obs) + + task_out_z = self.z_mlp(obs) + + if self.proj_norm: + z_out, extra_dict = self.form_embedding(task_out_z, obs_dict) + + # if "z_acc" not in self.__dict__.keys(): + # self.z_acc = [] + # self.z_acc.append(z_out) + # if len(self.z_acc) > 500: + # import ipdb; ipdb.set_trace() + # import joblib;joblib.dump(self.z_acc, "z_acc_compare_3.pkl") + if self.z_all: + actor_input = z_out + else: + actor_input = torch.cat([self_obs, z_out], dim=-1) + + a_out = self.actor_mlp(actor_input) + + if self.is_discrete: + logits = self.logits(a_out) + return logits + + if self.is_multi_discrete: + logits = [logit(a_out) for logit in self.logits] + return logits + + if self.is_continuous: + mu = self.mu_act(self.mu(a_out)) + if self.space_config['fixed_sigma']: + sigma = mu * 0.0 + self.sigma_act(self.sigma) + else: + sigma = self.sigma_act(self.sigma(a_out)) + + if return_extra: + return mu, sigma, extra_dict + else: + return mu, sigma + + def _build_z_mlp(self): + self_obs_size, task_obs_size, task_obs_size_detail = self.self_obs_size, self.task_obs_size, self.task_obs_size_detail + + if self.z_type == "vae" or self.z_type == "vq_vae_hybrid" or self.z_type == "vq_vae_res": + out_size = self.embedding_size * 5 + else: + # if self.z_all: + # out_size = int(self.embedding_size * 3/4 ) + # else: + # out_size = self.embedding_size + out_size = self.embedding_size + + # if self.z_all: + # mlp_input_shape = task_obs_size + # else: + # mlp_input_shape = self_obs_size + task_obs_size # target + + mlp_input_shape = self_obs_size + task_obs_size # target + + mlp_args = {'input_size': mlp_input_shape, 'units': self._task_units, 'activation': self._task_activation, 'dense_func': torch.nn.Linear} + self.z_mlp = self._build_mlp(**mlp_args) + + if not self.has_rnn: + self.z_mlp.append(nn.Linear(in_features=self._task_units[-1], out_features=out_size)) + else: + self.z_proj_linear = nn.Linear(in_features=self.rnn_units, out_features=out_size) + + mlp_init = self.init_factory.create(**self._task_initializer) + init_mlp(self.z_mlp, mlp_init) + + # if self.z_all: + # mlp_args = {'input_size': self_obs_size, 'units': self._self_units, 'activation': self._task_activation, 'dense_func': torch.nn.Linear} + # self.z_self_mlp = self._build_mlp(**mlp_args) + # if not self.has_rnn: + # self.z_self_mlp.append(nn.Linear(in_features=self._self_units[-1], out_features=int(self.embedding_size * 1/4 ))) + # else: + # self.z_self_proj_linear = nn.Linear(in_features=self.rnn_units, out_features=int(self.embedding_size * 1/4 )) + + + if self.z_type == "vae": + self.z_mu = nn.Linear(in_features=self.embedding_size * 5, out_features=self.embedding_size) + self.z_logvar = nn.Linear(in_features=self.embedding_size * 5, out_features=self.embedding_size) + + init_mlp(self.z_mu, mlp_init); init_mlp(self.z_logvar, mlp_init) + + if self.use_vae_prior: + mlp_args = {'input_size': self_obs_size, 'units': self._task_units, 'activation': self._task_activation, 'dense_func': torch.nn.Linear} + self.z_prior = self._build_mlp(**mlp_args) + self.z_prior_mu = nn.Linear(in_features=self._task_units[-1], out_features=self.embedding_size) + self.z_prior_logvar = nn.Linear(in_features=self._task_units[-1], out_features=self.embedding_size) + init_mlp(self.z_prior, mlp_init); init_mlp(self.z_prior_mu, mlp_init); init_mlp(self.z_prior_logvar, mlp_init) + + # import ipdb; ipdb.set_trace() + # print('..... Disabling prior training ......') + # print('..... Disabling prior training ......') + # print('..... Disabling prior training ......') + # self.z_prior.requires_grad_(False) + # self.z_prior_mu.requires_grad_(False) + # self.z_prior_logvar.requires_grad_(False) + + elif self.use_vae_fixed_prior: + mlp_args = {'input_size': self_obs_size, 'units': self._task_units, 'activation': self._task_activation, 'dense_func': torch.nn.Linear} + self.z_prior = self._build_mlp(**mlp_args) + self.z_prior_mu = nn.Linear(in_features=self._task_units[-1], out_features=self.embedding_size) + init_mlp(self.z_prior, mlp_init); init_mlp(self.z_prior_mu, mlp_init) + + elif self.z_type == 'vq_vae': + self.quantizer = Quantizer(self.dict_size, self.embedding_size//self.embedding_partion, 0.25) + # self.quantizer = EMAVectorQuantizer(self.dict_size, self.embedding_size//4, 0.25, decay = 0.99) + + elif self.z_type == 'vq_vae_hybrid': + self.z_quant = nn.Linear(in_features=self.embedding_size * 5, out_features=int(self.embedding_size - 1)) + self.z_var = nn.Linear(in_features=self.embedding_size * 5, out_features=int(1)) + + + # mlp_args = {'input_size': mlp_input_shape, 'units': self._task_units, 'activation': self._task_activation, 'dense_func': torch.nn.Linear} + # self.z_var = self._build_mlp(**mlp_args) + # self.z_var.append(nn.Linear(in_features=self._task_units[-1], out_features=self.embedding_size)) + + init_mlp(self.z_quant, mlp_init); init_mlp(self.z_var, mlp_init) + self.quantizer = Quantizer(self.dict_size, int(self.embedding_size - 1), 0.25) + + elif self.z_type == 'vq_vae_res': + self.z_quant = nn.Linear(in_features=self.embedding_size * 5, out_features=self.embedding_size) + self.z_var = nn.Linear(in_features=self.embedding_size * 5, out_features=1) + + self.quantizer = Quantizer(self.dict_size, self.embedding_size, 0.25) + init_mlp(self.z_quant, mlp_init); init_mlp(self.z_var, mlp_init) + return + + def _build_critic_z_mlp(self): + self_obs_size, task_obs_size, task_obs_size_detail = self.self_obs_size, self.task_obs_size, self.task_obs_size_detail + mlp_input_shape = self_obs_size + task_obs_size # target + + self.critic_z_mlp = nn.Sequential() + mlp_args = {'input_size': mlp_input_shape, 'units': self._task_units, 'activation': self._task_activation, 'dense_func': torch.nn.Linear} + self.critic_z_mlp = self._build_mlp(**mlp_args) + + if not self.has_rnn: + self.critic_z_mlp.append(nn.Linear(in_features=self._task_units[-1], out_features=self.embedding_size)) + else: + self.critic_z_proj_linear = nn.Linear(in_features=self._task_units[-1], out_features=self.embedding_size) + + + mlp_init = self.init_factory.create(**self._task_initializer) + for m in self.critic_z_mlp.modules(): + if isinstance(m, nn.Linear): + mlp_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + + return + + def _build_z_reader(self): + self_obs_size, task_obs_size, task_obs_size_detail = self.self_obs_size, self.task_obs_size, self.task_obs_size_detail + mlp_input_shape = self.embedding_size # target + + self.z_reader_mlp = nn.Sequential() + mlp_args = {'input_size': mlp_input_shape, 'units': self._task_units, 'activation': self._task_activation, 'dense_func': torch.nn.Linear} + self.z_reader_mlp = self._build_mlp(**mlp_args) + self.z_reader_mlp.append(nn.Linear(in_features=self._task_units[-1], out_features=72)) + + mlp_init = self.init_factory.create(**self._task_initializer) + for m in self.z_reader_mlp.modules(): + if isinstance(m, nn.Linear): + mlp_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + + return diff --git a/phc/learning/amp_players.py b/phc/learning/amp_players.py new file mode 100644 index 0000000..912a5f7 --- /dev/null +++ b/phc/learning/amp_players.py @@ -0,0 +1,312 @@ +import torch + + +from rl_games.algos_torch import torch_ext +from phc.utils.running_mean_std import RunningMeanStd +from rl_games.common.player import BasePlayer +import learning.common_player as common_player + +from rl_games.common.tr_helpers import unsqueeze_obs + +def rescale_actions(low, high, action): + d = (high - low) / 2.0 + m = (high + low) / 2.0 + scaled_action = action * d + m + return scaled_action + +class AMPPlayerContinuous(common_player.CommonPlayer): + def __init__(self, config): + self._normalize_amp_input = config.get('normalize_amp_input', True) + self._normalize_input = config['normalize_input'] + self._disc_reward_scale = config['disc_reward_scale'] + + super().__init__(config) + + # self.env.task.update_value_func(self._eval_critic, self._eval_actor) + # import copy + # self.orcale_model = copy.deepcopy(self.model) + # checkpoint = torch_ext.load_checkpoint("output/dgx/smpl_im_master_singles_6_3/Humanoid_00031250.pth") + # self.orcale_model.load_state_dict(checkpoint['model']) + return + + # #### Oracle debug + # def get_action(self, obs, is_determenistic=False): + # obs = obs['obs'] + # if self.has_batch_dimension == False: + # obs = unsqueeze_obs(obs) + # obs = self._preproc_obs(obs) + # input_dict = { + # 'is_train': False, + # 'prev_actions': None, + # 'obs': obs, + # 'rnn_states': self.states + # } + # with torch.no_grad(): + # res_dict = self.orcale_model(input_dict) + # print("orcale_model") + + # mu = res_dict['mus'] + # action = res_dict['actions'] + # self.states = res_dict['rnn_states'] + # if is_determenistic: + # current_action = mu + # else: + # current_action = action + # if self.has_batch_dimension == False: + # current_action = torch.squeeze(current_action.detach()) + + # if self.clip_actions: + # return rescale_actions(self.actions_low, self.actions_high, + # torch.clamp(current_action, -1.0, 1.0)) + # else: + # return current_action + + def restore(self, fn): + super().restore(fn) + if self._normalize_amp_input: + checkpoint = torch_ext.load_checkpoint(fn) + self._amp_input_mean_std.load_state_dict(checkpoint['amp_input_mean_std']) + + if self._normalize_input: + self.running_mean_std.load_state_dict(checkpoint['running_mean_std']) + + return + + def _build_net(self, config): + super()._build_net(config) + + if self._normalize_amp_input: + self._amp_input_mean_std = RunningMeanStd(config['amp_input_shape']).to(self.device) + self._amp_input_mean_std.eval() + + return + + def _eval_critic(self, input): + input = self._preproc_obs(input) + return self.model.a2c_network.eval_critic(input) + + def _post_step(self, info): + super()._post_step(info) + if (self.env.task.viewer): + self._amp_debug(info) + + return + + def _eval_task_value(self, input): + input = self._preproc_obs(input) + return self.model.a2c_network.eval_task_value(input) + + + def _build_net_config(self): + config = super()._build_net_config() + if (hasattr(self, 'env')): + config['amp_input_shape'] = self.env.amp_observation_space.shape + config['task_obs_size_detail'] = self.env.task.get_task_obs_size_detail() + if self.env.task.has_task: + config['self_obs_size'] = self.env.task.get_self_obs_size() + config['task_obs_size'] = self.env.task.get_task_obs_size() + + else: + config['amp_input_shape'] = self.env_info['amp_observation_space'] + + # if self.env.task.has_task: + # config['task_obs_size_detail'] = self.vec_env.env.task.get_task_obs_size_detail() + # config['self_obs_size'] = self.vec_env.env.task.get_self_obs_size() + # config['task_obs_size'] = self.vec_env.env.task.get_task_obs_size() + + return config + + def _amp_debug(self, info): + return + + def _preproc_amp_obs(self, amp_obs): + if self._normalize_amp_input: + amp_obs = self._amp_input_mean_std(amp_obs) + return amp_obs + + def _eval_disc(self, amp_obs): + proc_amp_obs = self._preproc_amp_obs(amp_obs) + return self.model.a2c_network.eval_disc(proc_amp_obs) + + def _eval_actor(self, input): + input = self._preproc_obs(input) + return self.model.a2c_network.eval_actor(input) + + def _preproc_obs(self, obs_batch): + + if type(obs_batch) is dict: + for k, v in obs_batch.items(): + obs_batch[k] = self._preproc_obs(v) + else: + if obs_batch.dtype == torch.uint8: + obs_batch = obs_batch.float() / 255.0 + if self.normalize_input: + obs_batch_proc = obs_batch[:, :self.running_mean_std.mean_size] + obs_batch_out = self.running_mean_std(obs_batch_proc) + obs_batch = torch.cat([obs_batch_out, obs_batch[:, self.running_mean_std.mean_size:]], dim=-1) + + return obs_batch + + + def _calc_amp_rewards(self, amp_obs): + disc_r = self._calc_disc_rewards(amp_obs) + output = { + 'disc_rewards': disc_r + } + return output + + def _calc_disc_rewards(self, amp_obs): + with torch.no_grad(): + disc_logits = self._eval_disc(amp_obs) + prob = 1 / (1 + torch.exp(-disc_logits)) + disc_r = -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.device))) + disc_r *= self._disc_reward_scale + return disc_r + + +class AMPPlayerDiscrete(common_player.CommonPlayerDiscrete): + def __init__(self, config): + self._normalize_amp_input = config.get('normalize_amp_input', True) + self._normalize_input = config['normalize_input'] + self._disc_reward_scale = config['disc_reward_scale'] + + super().__init__(config) + + # self.env.task.update_value_func(self._eval_critic, self._eval_actor) + # import copy + # self.orcale_model = copy.deepcopy(self.model) + # checkpoint = torch_ext.load_checkpoint("output/dgx/smpl_im_master_singles_6_3/Humanoid_00031250.pth") + # self.orcale_model.load_state_dict(checkpoint['model']) + return + + # #### Oracle debug + # def get_action(self, obs, is_determenistic=False): + # obs = obs['obs'] + # if self.has_batch_dimension == False: + # obs = unsqueeze_obs(obs) + # obs = self._preproc_obs(obs) + # input_dict = { + # 'is_train': False, + # 'prev_actions': None, + # 'obs': obs, + # 'rnn_states': self.states + # } + # with torch.no_grad(): + # res_dict = self.orcale_model(input_dict) + # print("orcale_model") + + # mu = res_dict['mus'] + # action = res_dict['actions'] + # self.states = res_dict['rnn_states'] + # if is_determenistic: + # current_action = mu + # else: + # current_action = action + # if self.has_batch_dimension == False: + # current_action = torch.squeeze(current_action.detach()) + + # if self.clip_actions: + # return rescale_actions(self.actions_low, self.actions_high, + # torch.clamp(current_action, -1.0, 1.0)) + # else: + # return current_action + + def restore(self, fn): + super().restore(fn) + if self._normalize_amp_input: + checkpoint = torch_ext.load_checkpoint(fn) + self._amp_input_mean_std.load_state_dict(checkpoint['amp_input_mean_std']) + + if self._normalize_input: + self.running_mean_std.load_state_dict(checkpoint['running_mean_std']) + + return + + def _build_net(self, config): + super()._build_net(config) + + if self._normalize_amp_input: + self._amp_input_mean_std = RunningMeanStd(config['amp_input_shape']).to(self.device) + self._amp_input_mean_std.eval() + + return + + def _eval_critic(self, input): + input = self._preproc_input(input) + return self.model.a2c_network.eval_critic(input) + + def _post_step(self, info): + super()._post_step(info) + if (self.env.task.viewer): + self._amp_debug(info) + + return + + def _eval_task_value(self, input): + input = self._preproc_input(input) + return self.model.a2c_network.eval_task_value(input) + + + def _build_net_config(self): + config = super()._build_net_config() + if (hasattr(self, 'env')): + config['amp_input_shape'] = self.env.amp_observation_space.shape + config['task_obs_size_detail'] = self.env.task.get_task_obs_size_detail() + if self.env.task.has_task: + config['self_obs_size'] = self.env.task.get_self_obs_size() + config['task_obs_size'] = self.env.task.get_task_obs_size() + + else: + config['amp_input_shape'] = self.env_info['amp_observation_space'] + config['task_obs_size_detail'] = self.vec_env.env.task.get_task_obs_size_detail() + if self.env.task.has_task: + config['self_obs_size'] = self.vec_env.env.task.get_self_obs_size() + config['task_obs_size'] = self.vec_env.env.task.get_task_obs_size() + + return config + + def _amp_debug(self, info): + return + + def _preproc_amp_obs(self, amp_obs): + if self._normalize_amp_input: + amp_obs = self._amp_input_mean_std(amp_obs) + return amp_obs + + def _eval_disc(self, amp_obs): + proc_amp_obs = self._preproc_amp_obs(amp_obs) + return self.model.a2c_network.eval_disc(proc_amp_obs) + + def _eval_actor(self, input): + input = self._preproc_input(input) + return self.model.a2c_network.eval_actor(input) + + def _preproc_obs(self, obs_batch): + + if type(obs_batch) is dict: + for k, v in obs_batch.items(): + obs_batch[k] = self._preproc_obs(v) + else: + if obs_batch.dtype == torch.uint8: + obs_batch = obs_batch.float() / 255.0 + if self.normalize_input: + obs_batch_proc = obs_batch[:, :self.running_mean_std.mean_size] + obs_batch_out = self.running_mean_std(obs_batch_proc) + obs_batch = torch.cat([obs_batch_out, obs_batch[:, self.running_mean_std.mean_size:]], dim=-1) + + return obs_batch + + def _calc_amp_rewards(self, amp_obs): + disc_r = self._calc_disc_rewards(amp_obs) + output = { + 'disc_rewards': disc_r + } + return output + + def _calc_disc_rewards(self, amp_obs): + with torch.no_grad(): + disc_logits = self._eval_disc(amp_obs) + prob = 1 / (1 + torch.exp(-disc_logits)) + disc_r = -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.device))) + disc_r *= self._disc_reward_scale + return disc_r diff --git a/phc/learning/ar_prior.py b/phc/learning/ar_prior.py new file mode 100644 index 0000000..ddd4557 --- /dev/null +++ b/phc/learning/ar_prior.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn + +class AR1Prior(nn.Module): + def __init__(self): + super(AR1Prior, self).__init__() + # Initializing phi as a learnable parameter + # self.phi = nn.Parameter(torch.tensor(0.5)) + self.phi = 0.95 + + def forward(self, series): + # Calculate the likelihood of the series given phi + # Ignoring the first term since it doesn't have a previous term + error = series[1:] - self.phi * series[:-1] + log_likelihood = -0.5 * torch.sum(error**2) + return log_likelihood diff --git a/phc/learning/common_agent.py b/phc/learning/common_agent.py new file mode 100644 index 0000000..46b3d0f --- /dev/null +++ b/phc/learning/common_agent.py @@ -0,0 +1,1124 @@ +import copy +from datetime import datetime +from gym import spaces +import numpy as np +import os +import time +import yaml +import glob +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +from rl_games.algos_torch import a2c_continuous, a2c_discrete +from rl_games.algos_torch import torch_ext +from rl_games.algos_torch import central_value +from phc.utils.running_mean_std import RunningMeanStd +from rl_games.common import a2c_common +from rl_games.common import datasets +from rl_games.common import schedulers +from rl_games.common import vecenv + +import torch +from torch import optim +from gym import spaces + +import learning.amp_datasets as amp_datasets + +from tensorboardX import SummaryWriter +import wandb + + +class CommonAgent(a2c_continuous.A2CAgent): + + def __init__(self, base_name, config): + a2c_common.A2CBase.__init__(self, base_name, config) + self.cfg = config + self.exp_name = self.cfg['train_dir'].split('/')[-1] + + self._load_config_params(config) + + self.is_discrete = False + self._setup_action_space() + self.bounds_loss_coef = config.get('bounds_loss_coef', None) + + self.clip_actions = config.get('clip_actions', True) + self._save_intermediate = config.get('save_intermediate', False) + + net_config = self._build_net_config() + + if self.normalize_input: + if "vec_env" in self.__dict__: + obs_shape = torch_ext.shape_whc_to_cwh(self.vec_env.env.task.get_running_mean_size()) + else: + obs_shape = self.obs_shape + self.running_mean_std = RunningMeanStd(obs_shape).to(self.ppo_device) + + net_config['mean_std'] = self.running_mean_std + self.model = self.network.build(net_config) + self.model.to(self.ppo_device) + self.states = None + + self.init_rnn_from_model(self.model) + self.last_lr = float(self.last_lr) + + self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay) + + if self.has_central_value: + cv_config = { + 'state_shape': torch_ext.shape_whc_to_cwh(self.state_shape), + 'value_size': self.value_size, + 'ppo_device': self.ppo_device, + 'num_agents': self.num_agents, + 'horizon_length': self.horizon_length, + 'num_actors': self.num_actors, + 'num_actions': self.actions_num, + 'seq_len': self.seq_len, + 'model': self.central_value_config['network'], + 'config': self.central_value_config, + 'writter': self.writer, + 'multi_gpu': self.multi_gpu + } + self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device) + + self.use_experimental_cv = self.config.get('use_experimental_cv', True) + self.dataset = amp_datasets.AMPDataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len) + self.algo_observer.after_init(self) + + return + + def init_tensors(self): + super().init_tensors() + self.experience_buffer.tensor_dict['next_obses'] = torch.zeros_like(self.experience_buffer.tensor_dict['obses']) + self.experience_buffer.tensor_dict['next_values'] = torch.zeros_like(self.experience_buffer.tensor_dict['values']) + + self.tensor_list += ['next_obses'] + return + + def train(self): + self.init_tensors() + self.last_mean_rewards = -100500 + start_time = time.time() + total_time = 0 + rep_count = 0 + self.frame = 0 + self.obs = self.env_reset() + self.curr_frames = self.batch_size_envs + + model_output_file = osp.join(self.network_path, self.config['name']) + + if self.multi_gpu: + self.hvd.setup_algo(self) + + self._init_train() + + while True: + epoch_start = time.time() + + epoch_num = self.update_epoch() + train_info = self.train_epoch() + + sum_time = train_info['total_time'] + total_time += sum_time + frame = self.frame + if self.multi_gpu: + self.hvd.sync_stats(self) + + if self.rank == 0: + scaled_time = sum_time + scaled_play_time = train_info['play_time'] + curr_frames = self.curr_frames + self.frame += curr_frames + fps_step = curr_frames / scaled_play_time + fps_total = curr_frames / scaled_time + + self.writer.add_scalar('performance/total_fps', curr_frames / scaled_time, frame) + self.writer.add_scalar('performance/step_fps', curr_frames / scaled_play_time, frame) + self.writer.add_scalar('info/epochs', epoch_num, frame) + train_info_dict = self._assemble_train_info(train_info, frame) + self.algo_observer.after_print_stats(frame, epoch_num, total_time) + if self.save_freq > 0: + + if epoch_num % min(50, self.save_best_after) == 0: + self.save(model_output_file) + + if (self._save_intermediate) and (epoch_num % (self.save_freq) == 0): + # Save intermediate model every save_freq epoches + int_model_output_file = model_output_file + '_' + str(epoch_num).zfill(8) + self.save(int_model_output_file) + + if self.game_rewards.current_size > 0: + mean_rewards = self._get_mean_rewards() + mean_lengths = self.game_lengths.get_mean() + + for i in range(self.value_size): + self.writer.add_scalar('rewards{0}/frame'.format(i), mean_rewards[i], frame) + self.writer.add_scalar('rewards{0}/iter'.format(i), mean_rewards[i], epoch_num) + self.writer.add_scalar('rewards{0}/time'.format(i), mean_rewards[i], total_time) + + self.writer.add_scalar('episode_lengths/frame', mean_lengths, frame) + self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num) + + if (self._save_intermediate) and (epoch_num % (self.save_freq) == 0): + eval_info = self.eval() + train_info_dict.update(eval_info) + + train_info_dict.update({"episode_lengths": mean_lengths, "mean_rewards": np.mean(mean_rewards)}) + self._log_train_info(train_info_dict, frame) + + epoch_end = time.time() + log_str = f"{self.exp_name}-Ep: {self.epoch_num}\trwd: {np.mean(mean_rewards):.1f}\tfps_step: {fps_step:.1f}\tfps_total: {fps_total:.1f}\tep_time:{epoch_end - epoch_start:.1f}\tframe: {self.frame}\teps_len: {mean_lengths:.1f}" + + print(log_str) + + if self.has_self_play_config: + self.self_play_manager.update(self) + + if epoch_num > self.max_epochs: + self.save(model_output_file) + print('MAX EPOCHS NUM!') + return self.last_mean_rewards, epoch_num + + update_time = 0 + return + + def eval(self): + print("evaluation routine not implemented") + return {} + + def train_epoch(self): + play_time_start = time.time() + with torch.no_grad(): + if self.is_rnn: + batch_dict = self.play_steps_rnn() + else: + batch_dict = self.play_steps() + + play_time_end = time.time() + update_time_start = time.time() + rnn_masks = batch_dict.get('rnn_masks', None) + + self.set_train() + + self.curr_frames = batch_dict.pop('played_frames') + self.prepare_dataset(batch_dict) + self.algo_observer.after_steps() + + if self.has_central_value: + self.train_central_value() + + train_info = None + + if self.is_rnn: + frames_mask_ratio = rnn_masks.sum().item() / (rnn_masks.nelement()) + print(frames_mask_ratio) + + for _ in range(0, self.mini_epochs_num): + ep_kls = [] + for i in range(len(self.dataset)): + curr_train_info = self.train_actor_critic(self.dataset[i]) + + if self.schedule_type == 'legacy': + if self.multi_gpu: + curr_train_info['kl'] = self.hvd.average_value(curr_train_info['kl'], 'ep_kls') + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, curr_train_info['kl'].item()) + self.update_lr(self.last_lr) + + if (train_info is None): + train_info = dict() + for k, v in curr_train_info.items(): + train_info[k] = [v] + else: + for k, v in curr_train_info.items(): + train_info[k].append(v) + + av_kls = torch_ext.mean_list(train_info['kl']) + + if self.schedule_type == 'standard': + if self.multi_gpu: + av_kls = self.hvd.average_value(av_kls, 'ep_kls') + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.update_lr(self.last_lr) + + if self.schedule_type == 'standard_epoch': + if self.multi_gpu: + av_kls = self.hvd.average_value(torch_ext.mean_list(kls), 'ep_kls') + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.update_lr(self.last_lr) + + update_time_end = time.time() + play_time = play_time_end - play_time_start + update_time = update_time_end - update_time_start + total_time = update_time_end - play_time_start + + train_info['play_time'] = play_time + train_info['update_time'] = update_time + train_info['total_time'] = total_time + self._record_train_batch_info(batch_dict, train_info) + return train_info + + def get_action_values(self, obs): + obs_orig = obs['obs'] + processed_obs = self._preproc_obs(obs['obs']) + self.model.eval() + input_dict = { + 'is_train': False, + 'prev_actions': None, + 'obs' : processed_obs, + "obs_orig": obs_orig, + 'rnn_states' : self.rnn_states + } + + with torch.no_grad(): + res_dict = self.model(input_dict) + if self.has_central_value: + states = obs['states'] + input_dict = { + 'is_train': False, + 'states' : states, + #'actions' : res_dict['action'], + #'rnn_states' : self.rnn_states + } + value = self.get_central_value(input_dict) + res_dict['values'] = value + if self.normalize_value: + res_dict['values'] = self.value_mean_std(res_dict['values'], True) + return res_dict + + def play_steps(self): + self.set_eval() + + epinfos = [] + done_indices = [] + update_list = self.update_list + + for n in range(self.horizon_length): + self.obs = self.env_reset(done_indices) + + self.experience_buffer.update_data('obses', n, self.obs['obs']) + + if self.use_action_masks: + masks = self.vec_env.get_action_masks() + res_dict = self.get_masked_action_values(self.obs, masks) + else: + res_dict = self.get_action_values(self.obs) + + for k in update_list: + self.experience_buffer.update_data(k, n, res_dict[k]) + + if self.has_central_value: + self.experience_buffer.update_data('states', n, self.obs['states']) + + self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions']) + shaped_rewards = self.rewards_shaper(rewards) + self.experience_buffer.update_data('rewards', n, shaped_rewards) + self.experience_buffer.update_data('next_obses', n, self.obs['obs']) + self.experience_buffer.update_data('dones', n, self.dones) + + terminated = infos['terminate'].float() + terminated = terminated.unsqueeze(-1) + + next_vals = self._eval_critic(self.obs) + next_vals *= (1.0 - terminated) + self.experience_buffer.update_data('next_values', n, next_vals) + + self.current_rewards += rewards + self.current_lengths += 1 + all_done_indices = self.dones.nonzero(as_tuple=False) + done_indices = all_done_indices[::self.num_agents] + + self.game_rewards.update(self.current_rewards[done_indices]) + self.game_lengths.update(self.current_lengths[done_indices]) + self.algo_observer.process_infos(infos, done_indices) + + not_dones = 1.0 - self.dones.float() + + self.current_rewards = self.current_rewards * not_dones.unsqueeze(1) + self.current_lengths = self.current_lengths * not_dones + + done_indices = done_indices[:, 0] + + mb_fdones = self.experience_buffer.tensor_dict['dones'].float() + mb_values = self.experience_buffer.tensor_dict['values'] + mb_next_values = self.experience_buffer.tensor_dict['next_values'] + mb_rewards = self.experience_buffer.tensor_dict['rewards'] + + mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values) + mb_returns = mb_advs + mb_values + + batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list) + batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns) + batch_dict['played_frames'] = self.batch_size + + return batch_dict + + def prepare_dataset(self, batch_dict): + obses = batch_dict['obses'] + returns = batch_dict['returns'] + dones = batch_dict['dones'] + values = batch_dict['values'] + actions = batch_dict['actions'] + neglogpacs = batch_dict['neglogpacs'] + mus = batch_dict['mus'] + sigmas = batch_dict['sigmas'] + rnn_states = batch_dict.get('rnn_states', None) + rnn_masks = batch_dict.get('rnn_masks', None) + + advantages = self._calc_advs(batch_dict) + + if self.normalize_value: + values = self.value_mean_std(values) + returns = self.value_mean_std(returns) + + dataset_dict = {} + dataset_dict['old_values'] = values + dataset_dict['old_logp_actions'] = neglogpacs + dataset_dict['advantages'] = advantages + dataset_dict['returns'] = returns + dataset_dict['actions'] = actions + dataset_dict['obs'] = obses + dataset_dict['rnn_states'] = rnn_states + dataset_dict['rnn_masks'] = rnn_masks + dataset_dict['mu'] = mus + dataset_dict['sigma'] = sigmas + + if self.has_central_value: + dataset_dict = {} + dataset_dict['old_values'] = values + dataset_dict['advantages'] = advantages + dataset_dict['returns'] = returns + dataset_dict['actions'] = actions + dataset_dict['obs'] = batch_dict['states'] + dataset_dict['rnn_masks'] = rnn_masks + self.central_value_net.update_dataset(dataset_dict) + + self.dataset.update_values_dict(dataset_dict) + return dataset_dict + + def calc_gradients(self, input_dict): + self.set_train() + + value_preds_batch = input_dict['old_values'] + old_action_log_probs_batch = input_dict['old_logp_actions'] + advantage = input_dict['advantages'] + old_mu_batch = input_dict['mu'] + old_sigma_batch = input_dict['sigma'] + return_batch = input_dict['returns'] + actions_batch = input_dict['actions'] + obs_batch = input_dict['obs'] + obs_batch = self._preproc_obs(obs_batch) + + lr = self.last_lr + kl = 1.0 + lr_mul = 1.0 + curr_e_clip = lr_mul * self.e_clip + + batch_dict = {'is_train': True, 'prev_actions': actions_batch, 'obs': obs_batch} + + rnn_masks = None + if self.is_rnn: + rnn_masks = input_dict['rnn_masks'] + batch_dict['rnn_states'] = input_dict['rnn_states'] + batch_dict['seq_length'] = self.seq_len + + with torch.cuda.amp.autocast(enabled=self.mixed_precision): + res_dict = self.model(batch_dict) + action_log_probs = res_dict['prev_neglogp'] + values = res_dict['values'] + entropy = res_dict['entropy'] + mu = res_dict['mus'] + sigma = res_dict['sigmas'] + + a_info = self._actor_loss(old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip) + a_loss = a_info['actor_loss'] + + c_info = self._critic_loss(value_preds_batch, values, curr_e_clip, return_batch, self.clip_value) + c_loss = c_info['critic_loss'] + + b_loss = self.bound_loss(mu) + + # gotta average + a_loss = torch.mean(a_loss) + c_loss = torch.mean(c_loss) + b_loss = torch.mean(b_loss) + entropy = torch.mean(entropy) + + loss = a_loss + self.critic_coef * c_loss - self.entropy_coef * entropy + self.bounds_loss_coef * b_loss + + a_clip_frac = torch.mean(a_info['actor_clipped'].float()) + + a_info['actor_loss'] = a_loss + a_info['actor_clip_frac'] = a_clip_frac + + if self.multi_gpu: + self.optimizer.zero_grad() + else: + for param in self.model.parameters(): + param.grad = None + + self.scaler.scale(loss).backward() + + #TODO: Refactor this ugliest code of the year + if self.truncate_grads: + if self.multi_gpu: + self.optimizer.synchronize() + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm) + with self.optimizer.skip_synchronize(): + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm) + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.scaler.step(self.optimizer) + self.scaler.update() + + with torch.no_grad(): + reduce_kl = not self.is_rnn + kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl) + if self.is_rnn: + kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel() #/ sum_mask + + self.train_result = {'entropy': entropy, 'kl': kl_dist, 'last_lr': self.last_lr, 'lr_mul': lr_mul, 'b_loss': b_loss} + self.train_result.update(a_info) + self.train_result.update(c_info) + + return + + def discount_values(self, mb_fdones, mb_values, mb_rewards, mb_next_values): + lastgaelam = 0 + mb_advs = torch.zeros_like(mb_rewards) + + for t in reversed(range(self.horizon_length)): + not_done = 1.0 - mb_fdones[t] + not_done = not_done.unsqueeze(1) + + delta = mb_rewards[t] + self.gamma * mb_next_values[t] - mb_values[t] + lastgaelam = delta + self.gamma * self.tau * not_done * lastgaelam + mb_advs[t] = lastgaelam + + return mb_advs + + def env_reset(self, env_ids=None): + obs = self.vec_env.reset(env_ids) + obs = self.obs_to_tensors(obs) + return obs + + def bound_loss(self, mu): + if self.bounds_loss_coef is not None: + soft_bound = 1.0 + mu_loss_high = torch.clamp_min(mu - soft_bound, 0.0)**2 + mu_loss_low = torch.clamp_max(mu + soft_bound, 0.0)**2 + b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1) + else: + b_loss = 0 + return b_loss + + def _get_mean_rewards(self): + return self.game_rewards.get_mean() + + def _load_config_params(self, config): + self.last_lr = config['learning_rate'] + return + + def _build_net_config(self): + obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape) + config = { + 'actions_num': self.actions_num, + 'input_shape': obs_shape, + 'num_seqs': self.num_actors * self.num_agents, + 'value_size': self.env_info.get('value_size', 1), + } + return config + + def _setup_action_space(self): + action_space = self.env_info['action_space'] + + self.actions_num = action_space.shape[0] + + # todo introduce device instead of cuda() + self.actions_low = torch.from_numpy(action_space.low.copy()).float().to(self.ppo_device) + self.actions_high = torch.from_numpy(action_space.high.copy()).float().to(self.ppo_device) + return + + def _init_train(self): + return + + def _eval_critic(self, obs_dict): + self.model.eval() + obs_dict['obs'] = self._preproc_obs(obs_dict['obs']) + if self.model.is_rnn(): + value, state = self.model.a2c_network.eval_critic(obs_dict) + else: + value = self.model.a2c_network.eval_critic(obs_dict) + + if self.normalize_value: + value = self.value_mean_std(value, True) + return value + + def _actor_loss(self, old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip): + ratio = torch.exp(old_action_log_probs_batch - action_log_probs) + surr1 = advantage * ratio + surr2 = advantage * torch.clamp(ratio, 1.0 - curr_e_clip, 1.0 + curr_e_clip) + a_loss = torch.max(-surr1, -surr2) + + clipped = torch.abs(ratio - 1.0) > curr_e_clip + clipped = clipped.detach() + + info = {'actor_loss': a_loss, 'actor_clipped': clipped.detach()} + return info + + def _critic_loss(self, value_preds_batch, values, curr_e_clip, return_batch, clip_value): + if clip_value: + value_pred_clipped = value_preds_batch + \ + (values - value_preds_batch).clamp(-curr_e_clip, curr_e_clip) + value_losses = (values - return_batch)**2 + value_losses_clipped = (value_pred_clipped - return_batch)**2 + c_loss = torch.max(value_losses, value_losses_clipped) + else: + c_loss = (return_batch - values)**2 + + info = {'critic_loss': c_loss} + return info + + def _calc_advs(self, batch_dict): + returns = batch_dict['returns'] + values = batch_dict['values'] + + advantages = returns - values + advantages = torch.sum(advantages, axis=1) + + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + return advantages + + def _record_train_batch_info(self, batch_dict, train_info): + return + + def _assemble_train_info(self, train_info, frame): + train_info_dict = { + "update_time": train_info['update_time'], + "play_time": train_info['play_time'], + "last_lr": train_info['last_lr'][-1] * train_info['lr_mul'][-1], + "lr_mul": train_info['lr_mul'][-1], + "e_clip": self.e_clip * train_info['lr_mul'][-1], + } + + if "actor_loss" in train_info: + train_info_dict.update( + { + "a_loss": torch_ext.mean_list(train_info['actor_loss']).item(), + "c_loss": torch_ext.mean_list(train_info['critic_loss']).item(), + "bounds_loss": torch_ext.mean_list(train_info['b_loss']).item(), + "entropy": torch_ext.mean_list(train_info['entropy']).item(), + "clip_frac": torch_ext.mean_list(train_info['actor_clip_frac']).item(), + "kl": torch_ext.mean_list(train_info['kl']).item(), + } + ) + + return train_info_dict + + def _log_train_info(self, train_info, frame): + + for k, v in train_info.items(): + self.writer.add_scalar(k, v, self.epoch_num) + + if not wandb.run is None: + wandb.log(train_info, step=self.epoch_num) + + return + + def post_epoch(self, epoch_num): + pass + + + +class CommonDiscreteAgent(a2c_discrete.DiscreteA2CAgent): + + def __init__(self, base_name, config): + a2c_common.DiscreteA2CBase.__init__(self, base_name, config) + self.cfg = config + self.exp_name = self.cfg['train_dir'].split('/')[-1] + + self._load_config_params(config) + + self._setup_action_space() + self.bounds_loss_coef = config.get('bounds_loss_coef', None) + + self.clip_actions = config.get('clip_actions', True) + self._save_intermediate = config.get('save_intermediate', False) + + net_config = self._build_net_config() + if self.normalize_input: + obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape) + self.running_mean_std = RunningMeanStd(obs_shape).to(self.ppo_device) + net_config['mean_std'] = self.running_mean_std + self.model = self.network.build(net_config) + self.model.to(self.ppo_device) + self.states = None + + self.init_rnn_from_model(self.model) + self.last_lr = float(self.last_lr) + + self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay) + + if self.has_central_value: + cv_config = { + 'state_shape': torch_ext.shape_whc_to_cwh(self.state_shape), + 'value_size': self.value_size, + 'ppo_device': self.ppo_device, + 'num_agents': self.num_agents, + 'horizon_length': self.horizon_length, + 'num_actors': self.num_actors, + 'num_actions': self.actions_num, + 'seq_len': self.seq_len, + 'model': self.central_value_config['network'], + 'config': self.central_value_config, + 'writter': self.writer, + 'multi_gpu': self.multi_gpu + } + self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device) + + self.use_experimental_cv = self.config.get('use_experimental_cv', True) + self.dataset = amp_datasets.AMPDataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len) + self.algo_observer.after_init(self) + + return + + def init_tensors(self): + super().init_tensors() + self.experience_buffer.tensor_dict['next_obses'] = torch.zeros_like(self.experience_buffer.tensor_dict['obses']) + self.experience_buffer.tensor_dict['next_values'] = torch.zeros_like(self.experience_buffer.tensor_dict['values']) + + self.tensor_list += ['next_obses'] + return + + def train(self): + self.init_tensors() + self.last_mean_rewards = -100500 + start_time = time.time() + total_time = 0 + rep_count = 0 + self.frame = 0 + self.obs = self.env_reset() + self.curr_frames = self.batch_size_envs + + model_output_file = osp.join(self.network_path, self.config['name']) + + if self.multi_gpu: + self.hvd.setup_algo(self) + + self._init_train() + + while True: + epoch_start = time.time() + + epoch_num = self.update_epoch() + train_info = self.train_epoch() + + sum_time = train_info['total_time'] + total_time += sum_time + frame = self.frame + if self.multi_gpu: + self.hvd.sync_stats(self) + + if self.rank == 0: + scaled_time = sum_time + scaled_play_time = train_info['play_time'] + curr_frames = self.curr_frames + self.frame += curr_frames + fps_step = curr_frames / scaled_play_time + fps_total = curr_frames / scaled_time + + self.writer.add_scalar('performance/total_fps', curr_frames / scaled_time, frame) + self.writer.add_scalar('performance/step_fps', curr_frames / scaled_play_time, frame) + self.writer.add_scalar('info/epochs', epoch_num, frame) + train_info_dict = self._assemble_train_info(train_info, frame) + self.algo_observer.after_print_stats(frame, epoch_num, total_time) + if self.save_freq > 0: + + if epoch_num % min(50, self.save_best_after) == 0: + self.save(model_output_file) + + if (self._save_intermediate) and (epoch_num % (self.save_freq) == 0): + # Save intermediate model every save_freq epoches + int_model_output_file = model_output_file + '_' + str(epoch_num).zfill(8) + self.save(int_model_output_file) + + if self.game_rewards.current_size > 0: + mean_rewards = self._get_mean_rewards() + mean_lengths = self.game_lengths.get_mean() + + for i in range(self.value_size): + self.writer.add_scalar('rewards{0}/frame'.format(i), mean_rewards[i], frame) + self.writer.add_scalar('rewards{0}/iter'.format(i), mean_rewards[i], epoch_num) + self.writer.add_scalar('rewards{0}/time'.format(i), mean_rewards[i], total_time) + + self.writer.add_scalar('episode_lengths/frame', mean_lengths, frame) + self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num) + + if (self._save_intermediate) and (epoch_num % (self.save_freq) == 0): + eval_info = self.eval() + train_info_dict.update(eval_info) + + train_info_dict.update({"episode_lengths": mean_lengths, "mean_rewards": np.mean(mean_rewards)}) + self._log_train_info(train_info_dict, frame) + + epoch_end = time.time() + log_str = f"{self.exp_name}-Ep: {self.epoch_num}\trwd: {np.mean(mean_rewards):.1f}\tfps_step: {fps_step:.1f}\tfps_total: {fps_total:.1f}\tep_time:{epoch_end - epoch_start:.1f}\tframe: {self.frame}\teps_len: {mean_lengths:.1f}" + print(log_str) + + if self.has_self_play_config: + self.self_play_manager.update(self) + + if epoch_num > self.max_epochs: + self.save(model_output_file) + print('MAX EPOCHS NUM!') + return self.last_mean_rewards, epoch_num + + update_time = 0 + return + + def eval(self): + print("evaluation routine not implemented") + return {} + + def train_epoch(self): + play_time_start = time.time() + with torch.no_grad(): + if self.is_rnn: + batch_dict = self.play_steps_rnn() + else: + batch_dict = self.play_steps() + + play_time_end = time.time() + update_time_start = time.time() + rnn_masks = batch_dict.get('rnn_masks', None) + + self.set_train() + + self.curr_frames = batch_dict.pop('played_frames') + self.prepare_dataset(batch_dict) + self.algo_observer.after_steps() + + if self.has_central_value: + self.train_central_value() + + train_info = None + + if self.is_rnn: + frames_mask_ratio = rnn_masks.sum().item() / (rnn_masks.nelement()) + print(frames_mask_ratio) + + for _ in range(0, self.mini_epochs_num): + ep_kls = [] + for i in range(len(self.dataset)): + curr_train_info = self.train_actor_critic(self.dataset[i]) + + if self.schedule_type == 'legacy': + if self.multi_gpu: + curr_train_info['kl'] = self.hvd.average_value(curr_train_info['kl'], 'ep_kls') + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, curr_train_info['kl'].item()) + self.update_lr(self.last_lr) + + if (train_info is None): + train_info = dict() + for k, v in curr_train_info.items(): + train_info[k] = [v] + else: + for k, v in curr_train_info.items(): + train_info[k].append(v) + + av_kls = torch_ext.mean_list(train_info['kl']) + + if self.schedule_type == 'standard': + if self.multi_gpu: + av_kls = self.hvd.average_value(av_kls, 'ep_kls') + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.update_lr(self.last_lr) + + if self.schedule_type == 'standard_epoch': + if self.multi_gpu: + av_kls = self.hvd.average_value(torch_ext.mean_list(kls), 'ep_kls') + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.update_lr(self.last_lr) + + update_time_end = time.time() + play_time = play_time_end - play_time_start + update_time = update_time_end - update_time_start + total_time = update_time_end - play_time_start + + train_info['play_time'] = play_time + train_info['update_time'] = update_time + train_info['total_time'] = total_time + self._record_train_batch_info(batch_dict, train_info) + return train_info + + def play_steps(self): + self.set_eval() + + epinfos = [] + done_indices = [] + update_list = self.update_list + + for n in range(self.horizon_length): + self.obs = self.env_reset(done_indices) + + self.experience_buffer.update_data('obses', n, self.obs['obs']) + + if self.use_action_masks: + masks = self.vec_env.get_action_masks() + res_dict = self.get_masked_action_values(self.obs, masks) + else: + res_dict = self.get_action_values(self.obs) + + for k in update_list: + self.experience_buffer.update_data(k, n, res_dict[k]) + + if self.has_central_value: + self.experience_buffer.update_data('states', n, self.obs['states']) + + self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions']) + shaped_rewards = self.rewards_shaper(rewards) + self.experience_buffer.update_data('rewards', n, shaped_rewards) + self.experience_buffer.update_data('next_obses', n, self.obs['obs']) + self.experience_buffer.update_data('dones', n, self.dones) + + terminated = infos['terminate'].float() + terminated = terminated.unsqueeze(-1) + + next_vals = self._eval_critic(self.obs) + next_vals *= (1.0 - terminated) + self.experience_buffer.update_data('next_values', n, next_vals) + + self.current_rewards += rewards + self.current_lengths += 1 + all_done_indices = self.dones.nonzero(as_tuple=False) + done_indices = all_done_indices[::self.num_agents] + + self.game_rewards.update(self.current_rewards[done_indices]) + self.game_lengths.update(self.current_lengths[done_indices]) + self.algo_observer.process_infos(infos, done_indices) + + not_dones = 1.0 - self.dones.float() + + self.current_rewards = self.current_rewards * not_dones.unsqueeze(1) + self.current_lengths = self.current_lengths * not_dones + + done_indices = done_indices[:, 0] + + mb_fdones = self.experience_buffer.tensor_dict['dones'].float() + mb_values = self.experience_buffer.tensor_dict['values'] + mb_next_values = self.experience_buffer.tensor_dict['next_values'] + mb_rewards = self.experience_buffer.tensor_dict['rewards'] + + mb_advs = self.discount_values(mb_fdones, mb_values, mb_rewards, mb_next_values) + mb_returns = mb_advs + mb_values + + batch_dict = self.experience_buffer.get_transformed_list(a2c_common.swap_and_flatten01, self.tensor_list) + batch_dict['returns'] = a2c_common.swap_and_flatten01(mb_returns) + batch_dict['played_frames'] = self.batch_size + + return batch_dict + + def prepare_dataset(self, batch_dict): + obses = batch_dict['obses'] + returns = batch_dict['returns'] + dones = batch_dict['dones'] + values = batch_dict['values'] + actions = batch_dict['actions'] + neglogpacs = batch_dict['neglogpacs'] + rnn_states = batch_dict.get('rnn_states', None) + rnn_masks = batch_dict.get('rnn_masks', None) + + advantages = self._calc_advs(batch_dict) + + if self.normalize_value: + values = self.value_mean_std(values) + returns = self.value_mean_std(returns) + + dataset_dict = {} + dataset_dict['old_values'] = values + dataset_dict['old_logp_actions'] = neglogpacs + dataset_dict['advantages'] = advantages + dataset_dict['returns'] = returns + dataset_dict['actions'] = actions + dataset_dict['obs'] = obses + dataset_dict['rnn_states'] = rnn_states + dataset_dict['rnn_masks'] = rnn_masks + + if self.has_central_value: + dataset_dict = {} + dataset_dict['old_values'] = values + dataset_dict['advantages'] = advantages + dataset_dict['returns'] = returns + dataset_dict['actions'] = actions + dataset_dict['obs'] = batch_dict['states'] + dataset_dict['rnn_masks'] = rnn_masks + self.central_value_net.update_dataset(dataset_dict) + + self.dataset.update_values_dict(dataset_dict) + return dataset_dict + + + def discount_values(self, mb_fdones, mb_values, mb_rewards, mb_next_values): + lastgaelam = 0 + mb_advs = torch.zeros_like(mb_rewards) + + for t in reversed(range(self.horizon_length)): + not_done = 1.0 - mb_fdones[t] + not_done = not_done.unsqueeze(1) + + delta = mb_rewards[t] + self.gamma * mb_next_values[t] - mb_values[t] + lastgaelam = delta + self.gamma * self.tau * not_done * lastgaelam + mb_advs[t] = lastgaelam + + return mb_advs + + def env_reset(self, env_ids=None): + obs = self.vec_env.reset(env_ids) + obs = self.obs_to_tensors(obs) + return obs + + def bound_loss(self, mu): + if self.bounds_loss_coef is not None: + soft_bound = 1.0 + mu_loss_high = torch.clamp_min(mu - soft_bound, 0.0)**2 + mu_loss_low = torch.clamp_max(mu + soft_bound, 0.0)**2 + b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1) + else: + b_loss = 0 + return b_loss + + def _get_mean_rewards(self): + return self.game_rewards.get_mean() + + def _load_config_params(self, config): + self.last_lr = config['learning_rate'] + return + + def _build_net_config(self): + obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape) + config = { + 'actions_num': self.actions_num, + 'input_shape': obs_shape, + 'num_seqs': self.num_actors * self.num_agents, + 'value_size': self.env_info.get('value_size', 1), + } + return config + + def _setup_action_space(self): + action_space = self.env_info['action_space'] + self.actions_num = action_space.shape + + batch_size = self.num_agents * self.num_actors + if type(action_space) is spaces.Discrete: + self.actions_shape = (self.horizon_length, batch_size) + self.actions_num = action_space.n + self.is_multi_discrete = False + if type(action_space) is spaces.Tuple: + self.actions_shape = (self.horizon_length, batch_size, len(action_space)) + self.actions_num = [action.n for action in action_space] + self.is_multi_discrete = True + return + + def _init_train(self): + return + + def _eval_critic(self, obs_dict): + self.model.eval() + obs_dict['obs'] = self._preproc_obs(obs_dict['obs']) + if self.model.is_rnn(): + value, state = self.model.a2c_network.eval_critic(obs_dict) + else: + value = self.model.a2c_network.eval_critic(obs_dict) + + if self.normalize_value: + value = self.value_mean_std(value, True) + return value + + def _actor_loss(self, old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip): + ratio = torch.exp(old_action_log_probs_batch - action_log_probs) + surr1 = advantage * ratio + surr2 = advantage * torch.clamp(ratio, 1.0 - curr_e_clip, 1.0 + curr_e_clip) + a_loss = torch.max(-surr1, -surr2) + + clipped = torch.abs(ratio - 1.0) > curr_e_clip + clipped = clipped.detach() + + info = {'actor_loss': a_loss, 'actor_clipped': clipped.detach()} + return info + + def _critic_loss(self, value_preds_batch, values, curr_e_clip, return_batch, clip_value): + if clip_value: + value_pred_clipped = value_preds_batch + \ + (values - value_preds_batch).clamp(-curr_e_clip, curr_e_clip) + value_losses = (values - return_batch)**2 + value_losses_clipped = (value_pred_clipped - return_batch)**2 + c_loss = torch.max(value_losses, value_losses_clipped) + else: + c_loss = (return_batch - values)**2 + + info = {'critic_loss': c_loss} + return info + + def _calc_advs(self, batch_dict): + returns = batch_dict['returns'] + values = batch_dict['values'] + + advantages = returns - values + advantages = torch.sum(advantages, axis=1) + + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + return advantages + + def _record_train_batch_info(self, batch_dict, train_info): + return + + def _assemble_train_info(self, train_info, frame): + train_info_dict = { + "update_time": train_info['update_time'], + "play_time": train_info['play_time'], + "a_loss": torch_ext.mean_list(train_info['actor_loss']).item(), + "c_loss": torch_ext.mean_list(train_info['critic_loss']).item(), + "entropy": torch_ext.mean_list(train_info['entropy']).item(), + "last_lr": train_info['last_lr'][-1] * train_info['lr_mul'][-1], + "lr_mul": train_info['lr_mul'][-1], + "e_clip": self.e_clip * train_info['lr_mul'][-1], + "clip_frac": torch_ext.mean_list(train_info['actor_clip_frac']).item(), + "kl": torch_ext.mean_list(train_info['kl']).item(), + } + + return train_info_dict + + def _log_train_info(self, train_info, frame): + + for k, v in train_info.items(): + self.writer.add_scalar(k, v, self.epoch_num) + + if not wandb.run is None: + wandb.log(train_info, step=self.epoch_num) + + return + + def post_epoch(self, epoch_num): + pass + + def _change_char_color(self, env_ids): + base_col = np.array([0.4, 0.4, 0.4]) + range_col = np.array([0.0706, 0.149, 0.2863]) + range_sum = np.linalg.norm(range_col) + + rand_col = np.random.uniform(0.0, 1.0, size=3) + rand_col = range_sum * rand_col / np.linalg.norm(rand_col) + rand_col += base_col + self.vec_env.env.task.set_char_color(rand_col, env_ids) + return \ No newline at end of file diff --git a/phc/learning/common_player.py b/phc/learning/common_player.py new file mode 100644 index 0000000..7e374fb --- /dev/null +++ b/phc/learning/common_player.py @@ -0,0 +1,418 @@ +import torch + +from rl_games.algos_torch import players +from rl_games.algos_torch import torch_ext +from phc.utils.running_mean_std import RunningMeanStd +from rl_games.common.player import BasePlayer + +import numpy as np +import gc +from gym import spaces + + +class CommonPlayer(players.PpoPlayerContinuous): + + def __init__(self, config): + BasePlayer.__init__(self, config) + self.network = config['network'] + + self._setup_action_space() + self.mask = [False] + + self.normalize_input = self.config['normalize_input'] + + net_config = self._build_net_config() + self._build_net(net_config) + self.first = True + return + + def run(self): + n_games = self.games_num + render = self.render_env + n_game_life = self.n_game_life + is_determenistic = self.is_determenistic + sum_rewards = 0 + sum_steps = 0 + sum_game_res = 0 + n_games = n_games * n_game_life + games_played = 0 + has_masks = False + has_masks_func = getattr(self.env, "has_action_mask", None) is not None + + op_agent = getattr(self.env, "create_agent", None) + if op_agent: + agent_inited = True + + if has_masks_func: + has_masks = self.env.has_action_mask() + + need_init_rnn = self.is_rnn + for t in range(n_games): + if games_played >= n_games: + break + + obs_dict = self.env_reset() + + batch_size = 1 + batch_size = self.get_batch_size(obs_dict['obs'], batch_size) + + if need_init_rnn: + self.init_rnn() + need_init_rnn = False + + cr = torch.zeros(batch_size, dtype=torch.float32, device=self.device) + steps = torch.zeros(batch_size, dtype=torch.float32, device=self.device) + + print_game_res = False + + done_indices = [] + + with torch.no_grad(): + for n in range(self.max_steps): + + obs_dict = self.env_reset(done_indices) + + if has_masks: + masks = self.env.get_action_mask() + action = self.get_masked_action(obs_dict, masks, is_determenistic) + else: + action = self.get_action(obs_dict, is_determenistic) + + # print(obs_dict[0].cpu().numpy()) + # print("needing a very very fine comb here. ") + # import joblib; joblib.dump(obs_dict[0].cpu().numpy(), "a.pkl") + # np.abs(joblib.load("a.pkl") - obs_dict[0].cpu().numpy()).sum() + + # import joblib; joblib.dump(obs_dict['obs'].detach().cpu().numpy(), "a.pkl") + # import joblib; np.abs(joblib.load("a.pkl")[0] - obs_dict['obs'][0].detach().cpu().numpy()).sum() + # joblib.dump(action, "a.pkl") + # joblib.load("a.pkl")[0] - action[0] + + obs_dict, r, done, info = self.env_step(self.env, action) + + cr += r + steps += 1 + + self._post_step(info) + + if render: + self.env.render(mode='human') + time.sleep(self.render_sleep) + + all_done_indices = done.nonzero(as_tuple=False) + done_indices = all_done_indices[::self.num_agents] + done_count = len(done_indices) + games_played += done_count + + if done_count > 0: + if self.is_rnn: + for s in self.states: + s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0 + + cur_rewards = cr[done_indices].sum().item() + cur_steps = steps[done_indices].sum().item() + + cr = cr * (1.0 - done.float()) + steps = steps * (1.0 - done.float()) + sum_rewards += cur_rewards + sum_steps += cur_steps + + game_res = 0.0 + if isinstance(info, dict): + if 'battle_won' in info: + print_game_res = True + game_res = info.get('battle_won', 0.5) + if 'scores' in info: + print_game_res = True + game_res = info.get('scores', 0.5) + if self.print_stats: + if print_game_res: + print('reward:', cur_rewards / done_count, 'steps:', cur_steps / done_count, 'w:', game_res) + else: + print('reward:', cur_rewards / done_count, 'steps:', cur_steps / done_count) + + sum_game_res += game_res + # if batch_size//self.num_agents == 1 or games_played >= n_games: + if games_played >= n_games: + break + + done_indices = done_indices[:, 0] + + print(sum_rewards) + if print_game_res: + print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life, 'winrate:', sum_game_res / games_played * n_game_life) + else: + print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life) + + return + + def obs_to_torch(self, obs): + obs = super().obs_to_torch(obs) + obs_dict = {'obs': obs} + return obs_dict + + def get_action(self, obs_dict, is_determenistic=False): + output = super().get_action(obs_dict['obs'], is_determenistic) + return output + + def env_step(self, env, actions): + if not self.is_tensor_obses: + actions = actions.cpu().numpy() + + obs, rewards, dones, infos = env.step(actions) + + if hasattr(obs, 'dtype') and obs.dtype == np.float64: + obs = np.float32(obs) + if self.value_size > 1: + rewards = rewards[0] + if self.is_tensor_obses: + return obs, rewards.to(self.device), dones.to(self.device), infos + else: + if np.isscalar(dones): + rewards = np.expand_dims(np.asarray(rewards), 0) + dones = np.expand_dims(np.asarray(dones), 0) + return self.obs_to_torch(obs), torch.from_numpy(rewards), torch.from_numpy(dones), infos + + def _build_net(self, config): + if self.normalize_input: + if "vec_env" in self.__dict__: + obs_shape = torch_ext.shape_whc_to_cwh(self.env.task.get_running_mean_size()) + else: + obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape) + self.running_mean_std = RunningMeanStd(obs_shape).to(self.device) + self.running_mean_std.eval() + config['mean_std'] = self.running_mean_std + self.model = self.network.build(config) + self.model.to(self.device) + self.model.eval() + self.is_rnn = self.model.is_rnn() + + return + + def env_reset(self, env_ids=None): + obs = self.env.reset(env_ids) + return self.obs_to_torch(obs) + + def _post_step(self, info): + return + + def _build_net_config(self): + obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape) + config = {'actions_num': self.actions_num, 'input_shape': obs_shape, 'num_seqs': self.num_agents} + return config + + def _setup_action_space(self): + self.actions_num = self.action_space.shape[0] + self.actions_low = torch.from_numpy(self.action_space.low.copy()).float().to(self.device) + self.actions_high = torch.from_numpy(self.action_space.high.copy()).float().to(self.device) + return + +class CommonPlayerDiscrete(players.PpoPlayerDiscrete): + + def __init__(self, config): + BasePlayer.__init__(self, config) + self.network = config['network'] + + self._setup_action_space() + self.mask = [False] + + self.normalize_input = self.config['normalize_input'] + + net_config = self._build_net_config() + self._build_net(net_config) + self.first = True + return + + def run(self): + n_games = self.games_num + render = self.render_env + n_game_life = self.n_game_life + is_determenistic = self.is_determenistic + sum_rewards = 0 + sum_steps = 0 + sum_game_res = 0 + n_games = n_games * n_game_life + games_played = 0 + has_masks = False + has_masks_func = getattr(self.env, "has_action_mask", None) is not None + + op_agent = getattr(self.env, "create_agent", None) + if op_agent: + agent_inited = True + + if has_masks_func: + has_masks = self.env.has_action_mask() + + need_init_rnn = self.is_rnn + for t in range(n_games): + if games_played >= n_games: + break + + obs_dict = self.env_reset() + + batch_size = 1 + batch_size = self.get_batch_size(obs_dict['obs'], batch_size) + + if need_init_rnn: + self.init_rnn() + need_init_rnn = False + + cr = torch.zeros(batch_size, dtype=torch.float32, device=self.device) + steps = torch.zeros(batch_size, dtype=torch.float32, device=self.device) + + print_game_res = False + + done_indices = [] + + with torch.no_grad(): + for n in range(self.max_steps): + + obs_dict = self.env_reset(done_indices) + + if has_masks: + masks = self.env.get_action_mask() + action = self.get_masked_action(obs_dict, masks, is_determenistic) + else: + action = self.get_action(obs_dict, is_determenistic) + + # print(obs_dict[0].cpu().numpy()) + # print("needing a very very fine comb here. ") + # import joblib; joblib.dump(obs_dict[0].cpu().numpy(), "a.pkl") + # np.abs(joblib.load("a.pkl") - obs_dict[0].cpu().numpy()).sum() + + # import joblib; joblib.dump(obs_dict['obs'].detach().cpu().numpy(), "a.pkl") + # import joblib; np.abs(joblib.load("a.pkl")[0] - obs_dict['obs'][0].detach().cpu().numpy()).sum() + # joblib.dump(action, "a.pkl") + # joblib.load("a.pkl")[0] - action[0] + obs_dict, r, done, info = self.env_step(self.env, action) + + cr += r + steps += 1 + + self._post_step(info) + + if render: + self.env.render(mode='human') + time.sleep(self.render_sleep) + + all_done_indices = done.nonzero(as_tuple=False) + done_indices = all_done_indices[::self.num_agents] + done_count = len(done_indices) + games_played += done_count + + if done_count > 0: + if self.is_rnn: + for s in self.states: + s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0 + + cur_rewards = cr[done_indices].sum().item() + cur_steps = steps[done_indices].sum().item() + + cr = cr * (1.0 - done.float()) + steps = steps * (1.0 - done.float()) + sum_rewards += cur_rewards + sum_steps += cur_steps + + game_res = 0.0 + if isinstance(info, dict): + if 'battle_won' in info: + print_game_res = True + game_res = info.get('battle_won', 0.5) + if 'scores' in info: + print_game_res = True + game_res = info.get('scores', 0.5) + if self.print_stats: + if print_game_res: + print('reward:', cur_rewards / done_count, 'steps:', cur_steps / done_count, 'w:', game_res) + else: + print('reward:', cur_rewards / done_count, 'steps:', cur_steps / done_count) + + sum_game_res += game_res + # if batch_size//self.num_agents == 1 or games_played >= n_games: + if games_played >= n_games: + break + + done_indices = done_indices[:, 0] + + print(sum_rewards) + if print_game_res: + print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life, 'winrate:', sum_game_res / games_played * n_game_life) + else: + print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life) + + return + + def obs_to_torch(self, obs): + obs = super().obs_to_torch(obs) + obs_dict = {'obs': obs} + return obs_dict + + def get_action(self, obs_dict, is_determenistic=False): + output = super().get_action(obs_dict['obs'], is_determenistic) + return output + + def env_step(self, env, actions): + if not self.is_tensor_obses: + actions = actions.cpu().numpy() + + obs, rewards, dones, infos = env.step(actions) + + if hasattr(obs, 'dtype') and obs.dtype == np.float64: + obs = np.float32(obs) + if self.value_size > 1: + rewards = rewards[0] + if self.is_tensor_obses: + return obs, rewards.to(self.device), dones.to(self.device), infos + else: + if np.isscalar(dones): + rewards = np.expand_dims(np.asarray(rewards), 0) + dones = np.expand_dims(np.asarray(dones), 0) + return self.obs_to_torch(obs), torch.from_numpy(rewards), torch.from_numpy(dones), infos + + def _build_net(self, config): + if self.normalize_input: + obs_shape = torch_ext.shape_whc_to_cwh(self.env.task.get_running_mean_size()) + self.running_mean_std = RunningMeanStd(obs_shape).to(self.device) + self.running_mean_std.eval() + config['mean_std'] = self.running_mean_std + self.model = self.network.build(config) + self.model.to(self.device) + self.model.eval() + self.is_rnn = self.model.is_rnn() + + return + + def env_reset(self, env_ids=None): + obs = self.env.reset(env_ids) + return self.obs_to_torch(obs) + + def _post_step(self, info): + return + + def _build_net_config(self): + obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape) + config = {'actions_num': self.actions_num, 'input_shape': obs_shape, 'num_seqs': self.num_agents} + return config + + def _setup_action_space(self): + action_space = self.env_info['action_space'] + self.actions_num = action_space.shape + + if type(action_space) is spaces.Discrete: + self.actions_num = action_space.n + self.is_multi_discrete = False + if type(action_space) is spaces.Tuple: + self.actions_num = [action.n for action in action_space] + self.is_multi_discrete = True + return + + def _change_char_color(self, env_ids): + base_col = np.array([0.4, 0.4, 0.4]) + range_col = np.array([0.0706, 0.149, 0.2863]) + range_sum = np.linalg.norm(range_col) + + rand_col = np.random.uniform(0.0, 1.0, size=3) + rand_col = range_sum * rand_col / np.linalg.norm(rand_col) + rand_col += base_col + self.vec_env.env.task.set_char_color(rand_col, env_ids) + return \ No newline at end of file diff --git a/phc/learning/im_amp.py b/phc/learning/im_amp.py new file mode 100644 index 0000000..0c1e940 --- /dev/null +++ b/phc/learning/im_amp.py @@ -0,0 +1,359 @@ + + +import glob +import os +import sys +import pdb +import os.path as osp +sys.path.append(os.getcwd()) + +from phc.utils.running_mean_std import RunningMeanStd +from rl_games.algos_torch import torch_ext +from rl_games.common import a2c_common +from rl_games.common import schedulers +from rl_games.common import vecenv + +from isaacgym.torch_utils import * + +import time +from datetime import datetime +import numpy as np +from torch import optim +import torch +from torch import nn +from phc.env.tasks.humanoid_amp_task import HumanoidAMPTask + +import learning.replay_buffer as replay_buffer +import phc.learning.amp_agent as amp_agent +from phc.utils.flags import flags +from rl_games.common.tr_helpers import unsqueeze_obs +from rl_games.algos_torch.players import rescale_actions + +from tensorboardX import SummaryWriter +import joblib +import gc +from smpl_sim.smpllib.smpl_eval import compute_metrics_lite +from tqdm import tqdm + + +class IMAmpAgent(amp_agent.AMPAgent): + def __init__(self, base_name, config): + super().__init__(base_name, config) + + + def get_action(self, obs_dict, is_determenistic=False): + obs = obs_dict["obs"] + + if self.has_batch_dimension == False: + obs = unsqueeze_obs(obs) + obs = self._preproc_obs(obs) + input_dict = { + "is_train": False, + "prev_actions": None, + "obs": obs, + "rnn_states": self.states, + } + with torch.no_grad(): + res_dict = self.model(input_dict) + mu = res_dict["mus"] + action = res_dict["actions"] + self.states = res_dict["rnn_states"] + if is_determenistic: + current_action = mu + else: + current_action = action + if self.has_batch_dimension == False: + current_action = torch.squeeze(current_action.detach()) + + if self.clip_actions: + return rescale_actions( + self.actions_low, + self.actions_high, + torch.clamp(current_action, -1.0, 1.0), + ) + else: + return current_action + + def env_eval_step(self, env, actions): + + if not self.is_tensor_obses: + actions = actions.cpu().numpy() + + obs, rewards, dones, infos = env.step(actions) + + if hasattr(obs, "dtype") and obs.dtype == np.float64: + obs = np.float32(obs) + if self.value_size > 1: + rewards = rewards[0] + if self.is_tensor_obses: + return obs, rewards.to(self.device), dones.to(self.device), infos + else: + if np.isscalar(dones): + rewards = np.expand_dims(np.asarray(rewards), 0) + dones = np.expand_dims(np.asarray(dones), 0) + return ( + self.obs_to_torch(obs), + torch.from_numpy(rewards), + torch.from_numpy(dones), + infos, + ) + + def restore(self, fn): + super().restore(fn) + + all_fails = glob.glob(osp.join(self.network_path, f"failed_*")) + if len(all_fails) > 0: + print("------------------------------------------------------ Restoring Termination History ------------------------------------------------------") + failed_pth = sorted(all_fails, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1] + print(f"loading: {failed_pth}") + termination_history = joblib.load(failed_pth)['termination_history'] + humanoid_env = self.vec_env.env.task + res = humanoid_env._motion_lib.update_sampling_prob(termination_history) + if res: + print("Successfully restored termination history") + else: + print("Termination history length does not match") + + return + + def init_rnn(self): + if self.is_rnn: + rnn_states = self.model.get_default_rnn_state() + self.states = [torch.zeros((s.size()[0], self.vec_env.env.task.num_envs, s.size( + )[2]), dtype=torch.float32).to(self.device) for s in rnn_states] + + + def update_training_data(self, failed_keys): + humanoid_env = self.vec_env.env.task + joblib.dump({"failed_keys": failed_keys, "termination_history": humanoid_env._motion_lib._termination_history}, osp.join(self.network_path, f"failed_{self.epoch_num:010d}.pkl")) + + + + def eval(self): + print("############################ Evaluation ############################") + if not flags.has_eval: + return {} + + self.set_eval() + + self.terminate_state = torch.zeros( + self.vec_env.env.task.num_envs, device=self.device + ) + self.terminate_memory = [] + self.mpjpe, self.mpjpe_all = [], [] + self.gt_pos, self.gt_pos_all = [], [] + self.pred_pos, self.pred_pos_all = [], [] + self.curr_stpes = 0 + + humanoid_env = self.vec_env.env.task + self.success_rate = 0 + self.pbar = tqdm( + range(humanoid_env._motion_lib._num_unique_motions // humanoid_env.num_envs) + ) + self.pbar.set_description("") + + ################## Save results first; ZL: Ugllllllllly code, refractor asap ################## + termination_distances, cycle_motion, zero_out_far, reset_ids = ( + humanoid_env._termination_distances.clone(), + humanoid_env.cycle_motion, + humanoid_env.zero_out_far, + humanoid_env._reset_bodies_id, + ) + + if "_recovery_episode_prob" in humanoid_env.__dict__: + recovery_episode_prob, fall_init_prob = ( + humanoid_env._recovery_episode_prob, + humanoid_env._fall_init_prob, + ) + humanoid_env._recovery_episode_prob, humanoid_env._fall_init_prob = 0, 0 + + humanoid_env._termination_distances[:] = 0.5 # if not humanoid_env.strict_eval else 0.25 # ZL: use UHC's termination distance + humanoid_env.cycle_motion = False + humanoid_env.zero_out_far = False + flags.test, flags.im_eval = (True, True,) # need to be test to have: motion_times[:] = 0 + humanoid_env._motion_lib = humanoid_env._motion_eval_lib + humanoid_env.begin_seq_motion_samples() + if len(humanoid_env._reset_bodies_id) > 15: + humanoid_env._reset_bodies_id = humanoid_env._eval_track_bodies_id # Following UHC. Only do it for full body, not for three point/two point trackings. + ################## Save results first; ZL: Ugllllllllly code, refractor asap ################## + + self.print_stats = False + self.has_batch_dimension = True + + need_init_rnn = self.is_rnn + obs_dict = self.env_reset() + batch_size = humanoid_env.num_envs + + if need_init_rnn: + self.init_rnn() + need_init_rnn = False + + cr = torch.zeros(batch_size, dtype=torch.float32, device=self.device) + steps = torch.zeros(batch_size, dtype=torch.float32, device=self.device) + + done_indices = [] + + with torch.no_grad(): + while True: + obs_dict = self.env_reset(done_indices) + + action = self.get_action(obs_dict, is_determenistic=True) + obs_dict, r, done, info = self.env_eval_step(self.vec_env.env, action) + cr += r + steps += 1 + done, info = self._post_step_eval(info, done.clone()) + + all_done_indices = done.nonzero(as_tuple=False) + done_indices = all_done_indices[:: self.num_agents] + done_count = len(done_indices) + if done_count > 0: + if self.is_rnn: + for s in self.states: + s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0 + done_indices = done_indices[:, 0] + + if info['end']: + break + + ################## Save results first; ZL: Ugllllllllly code, refractor asap ################## + humanoid_env._termination_distances[:] = termination_distances + humanoid_env.cycle_motion = cycle_motion + humanoid_env.zero_out_far = zero_out_far + flags.test, flags.im_eval = False, False + humanoid_env._motion_lib = humanoid_env._motion_train_lib + if "_recovery_episode_prob" in humanoid_env.__dict__: + humanoid_env._recovery_episode_prob, humanoid_env._fall_init_prob = ( + recovery_episode_prob, + fall_init_prob, + ) + humanoid_env._reset_bodies_id = reset_ids + self.env_reset() # Reset ALL environments, go back to training mode. + + ################## Save results first; ZL: Ugllllllllly code, refractor asap ################## + torch.cuda.empty_cache() + gc.collect() + + self.update_training_data(info['failed_keys']) + del self.terminate_state, self.terminate_memory, self.mpjpe, self.mpjpe_all + return info["eval_info"] + + def _post_step_eval(self, info, done): + end = False + eval_info = {} + # modify done such that games will exit and reset. + humanoid_env = self.vec_env.env.task + termination_state = torch.logical_and(self.curr_stpes <= humanoid_env._motion_lib.get_motion_num_steps() - 1, info["terminate"]) # if terminate after the last frame, then it is not a termination. curr_step is one step behind simulation. + # termination_state = info["terminate"] + self.terminate_state = torch.logical_or(termination_state, self.terminate_state) + if (~self.terminate_state).sum() > 0: + max_possible_id = humanoid_env._motion_lib._num_unique_motions - 1 + curr_ids = humanoid_env._motion_lib._curr_motion_ids + if (max_possible_id == curr_ids).sum() > 0: + bound = (max_possible_id == curr_ids).nonzero()[0] + 1 + if (~self.terminate_state[:bound]).sum() > 0: + curr_max = humanoid_env._motion_lib.get_motion_num_steps()[:bound][ + ~self.terminate_state[:bound] + ].max() + else: + curr_max = (self.curr_stpes - 1) # the ones that should be counted have teimrated + else: + curr_max = humanoid_env._motion_lib.get_motion_num_steps()[~self.terminate_state].max() + + if self.curr_stpes >= curr_max: curr_max = self.curr_stpes + 1 # For matching up the current steps and max steps. + else: + curr_max = humanoid_env._motion_lib.get_motion_num_steps().max() + + self.mpjpe.append(info["mpjpe"]) + self.gt_pos.append(info["body_pos_gt"]) + self.pred_pos.append(info["body_pos"]) + self.curr_stpes += 1 + + if self.curr_stpes >= curr_max or self.terminate_state.sum() == humanoid_env.num_envs: + self.curr_stpes = 0 + self.terminate_memory.append(self.terminate_state.cpu().numpy()) + self.success_rate = (1- np.concatenate(self.terminate_memory)[: humanoid_env._motion_lib._num_unique_motions].mean()) + + # MPJPE + all_mpjpe = torch.stack(self.mpjpe) + assert(all_mpjpe.shape[0] == curr_max or self.terminate_state.sum() == humanoid_env.num_envs) # Max should be the same as the number of frames in the motion. + all_mpjpe = [all_mpjpe[:(i - 1), idx].mean() for idx, i in enumerate(humanoid_env._motion_lib.get_motion_num_steps())] + all_body_pos_pred = np.stack(self.pred_pos) + all_body_pos_pred = [all_body_pos_pred[:(i - 1), idx] for idx, i in enumerate(humanoid_env._motion_lib.get_motion_num_steps())] + all_body_pos_gt = np.stack(self.gt_pos) + all_body_pos_gt = [all_body_pos_gt[:(i - 1), idx] for idx, i in enumerate(humanoid_env._motion_lib.get_motion_num_steps())] + + + self.mpjpe_all.append(all_mpjpe) + self.pred_pos_all += all_body_pos_pred + self.gt_pos_all += all_body_pos_gt + + + if (humanoid_env.start_idx + humanoid_env.num_envs >= humanoid_env._motion_lib._num_unique_motions): + self.pbar.clear() + terminate_hist = np.concatenate(self.terminate_memory) + succ_idxes = np.flatnonzero(~terminate_hist[: humanoid_env._motion_lib._num_unique_motions]).tolist() + + pred_pos_all_succ = [(self.pred_pos_all[:humanoid_env._motion_lib._num_unique_motions])[i] for i in succ_idxes] + gt_pos_all_succ = [(self.gt_pos_all[: humanoid_env._motion_lib._num_unique_motions])[i] for i in succ_idxes] + + pred_pos_all = self.pred_pos_all[:humanoid_env._motion_lib._num_unique_motions] + gt_pos_all = self.gt_pos_all[: humanoid_env._motion_lib._num_unique_motions] + + + # np.sum([i.shape[0] for i in self.pred_pos_all[:humanoid_env._motion_lib._num_unique_motions]]) + # humanoid_env._motion_lib.get_motion_num_steps().sum() + + failed_keys = humanoid_env._motion_lib._motion_data_keys[terminate_hist[: humanoid_env._motion_lib._num_unique_motions]] + success_keys = humanoid_env._motion_lib._motion_data_keys[~terminate_hist[: humanoid_env._motion_lib._num_unique_motions]] + # print("failed", humanoid_env._motion_lib._motion_data_keys[np.concatenate(self.terminate_memory)[:humanoid_env._motion_lib._num_unique_motions]]) + + metrics_all = compute_metrics_lite(pred_pos_all, gt_pos_all) + metrics_succ = compute_metrics_lite(pred_pos_all_succ, gt_pos_all_succ) + + metrics_all_print = {m: np.mean(v) for m, v in metrics_all.items()} + metrics_succ_print = {m: np.mean(v) for m, v in metrics_succ.items()} + + if len(metrics_succ_print) == 0: + print("No success!!!") + metrics_succ_print = metrics_all_print + + print("------------------------------------------") + print(f"Success Rate: {self.success_rate:.10f}") + print("All: ", " \t".join([f"{k}: {v:.3f}" for k, v in metrics_all_print.items()])) + print("Succ: "," \t".join([f"{k}: {v:.3f}" for k, v in metrics_succ_print.items()])) + print("Failed keys: ", len(failed_keys), failed_keys) + + end = True + + eval_info = { + "eval_success_rate": self.success_rate, + "eval_mpjpe_all": metrics_all_print['mpjpe_g'], + "eval_mpjpe_succ": metrics_succ_print['mpjpe_g'], + "accel_dist": metrics_succ_print['accel_dist'], + "vel_dist": metrics_succ_print['vel_dist'], + "mpjpel_all": metrics_all_print['mpjpe_l'], + "mpjpel_succ": metrics_succ_print['mpjpe_l'], + "mpjpe_pa": metrics_succ_print['mpjpe_pa'], + } + # failed_keys = humanoid_env._motion_lib._motion_data_keys[terminate_hist[:humanoid_env._motion_lib._num_unique_motions]] + # success_keys = humanoid_env._motion_lib._motion_data_keys[~terminate_hist[:humanoid_env._motion_lib._num_unique_motions]] + # print("failed", humanoid_env._motion_lib._motion_data_keys[np.concatenate(self.terminate_memory)[:humanoid_env._motion_lib._num_unique_motions]]) + # joblib.dump(failed_keys, "output/dgx/smpl_im_shape_long_1/failed_1.pkl") + # joblib.dump(success_keys, "output/dgx/smpl_im_fit_3_1/long_succ.pkl") + # print("....") + return done, {"end": end, "eval_info": eval_info, "failed_keys": failed_keys, "success_keys": success_keys} + + done[:] = 1 # Turning all of the sequences done and reset for the next batch of eval. + + humanoid_env.forward_motion_samples() + self.terminate_state = torch.zeros(self.vec_env.env.task.num_envs, device=self.device) + + self.pbar.update(1) + self.pbar.refresh() + self.mpjpe, self.gt_pos, self.pred_pos, = [], [], [] + + + update_str = f"Terminated: {self.terminate_state.sum().item()} | max frames: {curr_max} | steps {self.curr_stpes} | Start: {humanoid_env.start_idx} | Succ rate: {self.success_rate:.3f} | Mpjpe: {np.mean(self.mpjpe_all) * 1000:.3f}" + self.pbar.set_description(update_str) + + return done, {"end": end, "eval_info": eval_info, "failed_keys": [], "success_keys": []} diff --git a/phc/learning/im_amp_players.py b/phc/learning/im_amp_players.py new file mode 100644 index 0000000..cc0cbea --- /dev/null +++ b/phc/learning/im_amp_players.py @@ -0,0 +1,332 @@ + + +import glob +import os +import sys +import pdb +import os.path as osp +sys.path.append(os.getcwd()) + +import numpy as np +import torch +from phc.utils.flags import flags +from rl_games.algos_torch import torch_ext +from phc.utils.running_mean_std import RunningMeanStd +from rl_games.common.player import BasePlayer + +import learning.amp_players as amp_players +from tqdm import tqdm +import joblib +import time +from smpl_sim.smpllib.smpl_eval import compute_metrics_lite +from rl_games.common.tr_helpers import unsqueeze_obs + +COLLECT_Z = False + +class IMAMPPlayerContinuous(amp_players.AMPPlayerContinuous): + def __init__(self, config): + super().__init__(config) + + self.terminate_state = torch.zeros(self.env.task.num_envs, device=self.device) + self.terminate_memory = [] + self.mpjpe, self.mpjpe_all = [], [] + self.gt_pos, self.gt_pos_all = [], [] + self.pred_pos, self.pred_pos_all = [], [] + self.curr_stpes = 0 + + if COLLECT_Z: + self.zs, self.zs_all = [], [] + + humanoid_env = self.env.task + humanoid_env._termination_distances[:] = 0.5 # if not humanoid_env.strict_eval else 0.25 # ZL: use UHC's termination distance + humanoid_env._recovery_episode_prob, humanoid_env._fall_init_prob = 0, 0 + + if flags.im_eval: + self.success_rate = 0 + self.pbar = tqdm(range(humanoid_env._motion_lib._num_unique_motions // humanoid_env.num_envs)) + humanoid_env.zero_out_far = False + humanoid_env.zero_out_far_train = False + + if len(humanoid_env._reset_bodies_id) > 15: + humanoid_env._reset_bodies_id = humanoid_env._eval_track_bodies_id # Following UHC. Only do it for full body, not for three point/two point trackings. + + humanoid_env.cycle_motion = False + self.print_stats = False + + # joblib.dump({"mlp": self.model.a2c_network.actor_mlp, "mu": self.model.a2c_network.mu}, "single_model.pkl") # ZL: for saving part of the model. + return + + def _post_step(self, info, done): + super()._post_step(info) + + + # modify done such that games will exit and reset. + if flags.im_eval: + + humanoid_env = self.env.task + + termination_state = torch.logical_and(self.curr_stpes <= humanoid_env._motion_lib.get_motion_num_steps() - 1, info["terminate"]) # if terminate after the last frame, then it is not a termination. curr_step is one step behind simulation. + # termination_state = info["terminate"] + self.terminate_state = torch.logical_or(termination_state, self.terminate_state) + if (~self.terminate_state).sum() > 0: + max_possible_id = humanoid_env._motion_lib._num_unique_motions - 1 + curr_ids = humanoid_env._motion_lib._curr_motion_ids + if (max_possible_id == curr_ids).sum() > 0: # When you are running out of motions. + bound = (max_possible_id == curr_ids).nonzero()[0] + 1 + if (~self.terminate_state[:bound]).sum() > 0: + curr_max = humanoid_env._motion_lib.get_motion_num_steps()[:bound][~self.terminate_state[:bound]].max() + else: + curr_max = (self.curr_stpes - 1) # the ones that should be counted have teimrated + else: + curr_max = humanoid_env._motion_lib.get_motion_num_steps()[~self.terminate_state].max() + + if self.curr_stpes >= curr_max: curr_max = self.curr_stpes + 1 # For matching up the current steps and max steps. + else: + curr_max = humanoid_env._motion_lib.get_motion_num_steps().max() + + self.mpjpe.append(info["mpjpe"]) + self.gt_pos.append(info["body_pos_gt"]) + self.pred_pos.append(info["body_pos"]) + if COLLECT_Z: self.zs.append(info["z"]) + self.curr_stpes += 1 + + if self.curr_stpes >= curr_max or self.terminate_state.sum() == humanoid_env.num_envs: + + self.terminate_memory.append(self.terminate_state.cpu().numpy()) + self.success_rate = (1 - np.concatenate(self.terminate_memory)[: humanoid_env._motion_lib._num_unique_motions].mean()) + + # MPJPE + all_mpjpe = torch.stack(self.mpjpe) + try: + assert(all_mpjpe.shape[0] == curr_max or self.terminate_state.sum() == humanoid_env.num_envs) # Max should be the same as the number of frames in the motion. + except: + import ipdb; ipdb.set_trace() + print('??') + + all_mpjpe = [all_mpjpe[: (i - 1), idx].mean() for idx, i in enumerate(humanoid_env._motion_lib.get_motion_num_steps())] # -1 since we do not count the first frame. + all_body_pos_pred = np.stack(self.pred_pos) + all_body_pos_pred = [all_body_pos_pred[: (i - 1), idx] for idx, i in enumerate(humanoid_env._motion_lib.get_motion_num_steps())] + all_body_pos_gt = np.stack(self.gt_pos) + all_body_pos_gt = [all_body_pos_gt[: (i - 1), idx] for idx, i in enumerate(humanoid_env._motion_lib.get_motion_num_steps())] + + if COLLECT_Z: + all_zs = torch.stack(self.zs) + all_zs = [all_zs[: (i - 1), idx] for idx, i in enumerate(humanoid_env._motion_lib.get_motion_num_steps())] + self.zs_all += all_zs + + + self.mpjpe_all.append(all_mpjpe) + self.pred_pos_all += all_body_pos_pred + self.gt_pos_all += all_body_pos_gt + + + if (humanoid_env.start_idx + humanoid_env.num_envs >= humanoid_env._motion_lib._num_unique_motions): + terminate_hist = np.concatenate(self.terminate_memory) + succ_idxes = np.nonzero(~terminate_hist[: humanoid_env._motion_lib._num_unique_motions])[0].tolist() + + pred_pos_all_succ = [(self.pred_pos_all[:humanoid_env._motion_lib._num_unique_motions])[i] for i in succ_idxes] + gt_pos_all_succ = [(self.gt_pos_all[: humanoid_env._motion_lib._num_unique_motions])[i] for i in succ_idxes] + + pred_pos_all = self.pred_pos_all[:humanoid_env._motion_lib._num_unique_motions] + gt_pos_all = self.gt_pos_all[: humanoid_env._motion_lib._num_unique_motions] + + # np.sum([i.shape[0] for i in self.pred_pos_all[:humanoid_env._motion_lib._num_unique_motions]]) + # humanoid_env._motion_lib.get_motion_num_steps().sum() + + failed_keys = humanoid_env._motion_lib._motion_data_keys[terminate_hist[: humanoid_env._motion_lib._num_unique_motions]] + success_keys = humanoid_env._motion_lib._motion_data_keys[~terminate_hist[: humanoid_env._motion_lib._num_unique_motions]] + # print("failed", humanoid_env._motion_lib._motion_data_keys[np.concatenate(self.terminate_memory)[:humanoid_env._motion_lib._num_unique_motions]]) + if flags.real_traj: + pred_pos_all = [i[:, humanoid_env._reset_bodies_id] for i in pred_pos_all] + gt_pos_all = [i[:, humanoid_env._reset_bodies_id] for i in gt_pos_all] + pred_pos_all_succ = [i[:, humanoid_env._reset_bodies_id] for i in pred_pos_all_succ] + gt_pos_all_succ = [i[:, humanoid_env._reset_bodies_id] for i in gt_pos_all_succ] + + + + metrics = compute_metrics_lite(pred_pos_all, gt_pos_all) + metrics_succ = compute_metrics_lite(pred_pos_all_succ, gt_pos_all_succ) + + metrics_all_print = {m: np.mean(v) for m, v in metrics.items()} + metrics_print = {m: np.mean(v) for m, v in metrics_succ.items()} + + print("------------------------------------------") + print("------------------------------------------") + print(f"Success Rate: {self.success_rate:.10f}") + print("All: ", " \t".join([f"{k}: {v:.3f}" for k, v in metrics_all_print.items()])) + print("Succ: "," \t".join([f"{k}: {v:.3f}" for k, v in metrics_print.items()])) + # print(1 - self.terminate_state.sum() / self.terminate_state.shape[0]) + print(self.config['network_path']) + if COLLECT_Z: + zs_all = self.zs_all[:humanoid_env._motion_lib._num_unique_motions] + zs_dump = {k: zs_all[idx].cpu().numpy() for idx, k in enumerate(humanoid_env._motion_lib._motion_data_keys)} + joblib.dump(zs_dump, osp.join(self.config['network_path'], "zs_run.pkl")) + + import ipdb; ipdb.set_trace() + + # joblib.dump(np.concatenate(self.zs_all[: humanoid_env._motion_lib._num_unique_motions]), osp.join(self.config['network_path'], "zs.pkl")) + + joblib.dump(failed_keys, osp.join(self.config['network_path'], "failed.pkl")) + joblib.dump(success_keys, osp.join(self.config['network_path'], "long_succ.pkl")) + print("....") + + done[:] = 1 # Turning all of the sequences done and reset for the next batch of eval. + + humanoid_env.forward_motion_samples() + self.terminate_state = torch.zeros( + self.env.task.num_envs, device=self.device + ) + + self.pbar.update(1) + self.pbar.refresh() + self.mpjpe, self.gt_pos, self.pred_pos, = [], [], [] + if COLLECT_Z: self.zs = [] + self.curr_stpes = 0 + + + update_str = f"Terminated: {self.terminate_state.sum().item()} | max frames: {curr_max} | steps {self.curr_stpes} | Start: {humanoid_env.start_idx} | Succ rate: {self.success_rate:.3f} | Mpjpe: {np.mean(self.mpjpe_all) * 1000:.3f}" + self.pbar.set_description(update_str) + + return done + + def get_z(self, obs_dict): + obs = obs_dict['obs'] + if self.has_batch_dimension == False: + obs = unsqueeze_obs(obs) + obs = self._preproc_obs(obs) + input_dict = { + 'is_train': False, + 'prev_actions': None, + 'obs': obs, + 'rnn_states': self.states + } + with torch.no_grad(): + z = self.model.a2c_network.eval_z(input_dict) + return z + + def run(self): + n_games = self.games_num + render = self.render_env + n_game_life = self.n_game_life + is_determenistic = self.is_determenistic + sum_rewards = 0 + sum_steps = 0 + sum_game_res = 0 + n_games = n_games * n_game_life + games_played = 0 + has_masks = False + has_masks_func = getattr(self.env, "has_action_mask", None) is not None + + op_agent = getattr(self.env, "create_agent", None) + if op_agent: + agent_inited = True + + if has_masks_func: + has_masks = self.env.has_action_mask() + + need_init_rnn = self.is_rnn + for t in range(n_games): + if games_played >= n_games: + break + obs_dict = self.env_reset() + + batch_size = 1 + batch_size = self.get_batch_size(obs_dict["obs"], batch_size) + + if need_init_rnn: + self.init_rnn() + need_init_rnn = False + + cr = torch.zeros(batch_size, dtype=torch.float32, device=self.device) + steps = torch.zeros(batch_size, dtype=torch.float32, device=self.device) + + print_game_res = False + + done_indices = [] + + with torch.no_grad(): + for n in range(self.max_steps): + obs_dict = self.env_reset(done_indices) + + + if COLLECT_Z: z = self.get_z(obs_dict) + + + if has_masks: + masks = self.env.get_action_mask() + action = self.get_masked_action(obs_dict, masks, is_determenistic) + else: + action = self.get_action(obs_dict, is_determenistic) + + obs_dict, r, done, info = self.env_step(self.env, action) + + cr += r + steps += 1 + + if COLLECT_Z: info['z'] = z + done = self._post_step(info, done.clone()) + + if render: + self.env.render(mode="human") + time.sleep(self.render_sleep) + + all_done_indices = done.nonzero(as_tuple=False) + done_indices = all_done_indices[:: self.num_agents] + done_count = len(done_indices) + games_played += done_count + + if done_count > 0: + if self.is_rnn: + for s in self.states: + s[:, all_done_indices, :] = ( + s[:, all_done_indices, :] * 0.0 + ) + + cur_rewards = cr[done_indices].sum().item() + cur_steps = steps[done_indices].sum().item() + + cr = cr * (1.0 - done.float()) + steps = steps * (1.0 - done.float()) + sum_rewards += cur_rewards + sum_steps += cur_steps + + game_res = 0.0 + if isinstance(info, dict): + if "battle_won" in info: + print_game_res = True + game_res = info.get("battle_won", 0.5) + if "scores" in info: + print_game_res = True + game_res = info.get("scores", 0.5) + if self.print_stats: + if print_game_res: + print("reward:", cur_rewards / done_count, "steps:", cur_steps / done_count, "w:", game_res,) + else: + print("reward:", cur_rewards / done_count, "steps:", cur_steps / done_count,) + + sum_game_res += game_res + # if batch_size//self.num_agents == 1 or games_played >= n_games: + if games_played >= n_games: + break + + done_indices = done_indices[:, 0] + + print(sum_rewards) + if print_game_res: + print( + "av reward:", + sum_rewards / games_played * n_game_life, + "av steps:", + sum_steps / games_played * n_game_life, + "winrate:", + sum_game_res / games_played * n_game_life, + ) + else: + print( + "av reward:", + sum_rewards / games_played * n_game_life, + "av steps:", + sum_steps / games_played * n_game_life, + ) + + return diff --git a/phc/learning/loss_functions.py b/phc/learning/loss_functions.py new file mode 100644 index 0000000..30ab00b --- /dev/null +++ b/phc/learning/loss_functions.py @@ -0,0 +1,11 @@ +import torch + +def kl_multi(qm, qv, pm, pv): + """ + q: posterior + p: prior + ​ + """ + element_wise = 0.5 * (pv - qv + qv.exp() / pv.exp() + (qm - pm).pow(2) / pv.exp() - 1) + kl = element_wise.sum(-1) + return kl \ No newline at end of file diff --git a/phc/learning/network_builder.py b/phc/learning/network_builder.py new file mode 100644 index 0000000..dcacb69 --- /dev/null +++ b/phc/learning/network_builder.py @@ -0,0 +1,986 @@ +from rl_games.common import object_factory +from rl_games.algos_torch import torch_ext + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +import math +import numpy as np +from rl_games.algos_torch.d2rl import D2RLNet +from rl_games.algos_torch.sac_helper import SquashedNormal + + +def _create_initializer(func, **kwargs): + return lambda v: func(v, **kwargs) + + +def init_mlp(net, init_func): + if isinstance(net, nn.ModuleList): + for m in net: + if isinstance(m, nn.Linear): + init_func(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + + elif isinstance(net, nn.Linear): + init_func(net.weight) + if getattr(net, "bias", None) is not None: + torch.nn.init.zeros_(net.bias) + + +class NetworkBuilder: + + def __init__(self, **kwargs): + pass + + def load(self, params): + pass + + def build(self, name, **kwargs): + pass + + def __call__(self, name, **kwargs): + return self.build(name, **kwargs) + + class BaseNetwork(nn.Module): + + def __init__(self, **kwargs): + nn.Module.__init__(self, **kwargs) + + self.activations_factory = object_factory.ObjectFactory() + self.activations_factory.register_builder('relu', lambda **kwargs: nn.ReLU(**kwargs)) + self.activations_factory.register_builder('tanh', lambda **kwargs: nn.Tanh(**kwargs)) + self.activations_factory.register_builder('sigmoid', lambda **kwargs: nn.Sigmoid(**kwargs)) + self.activations_factory.register_builder('elu', lambda **kwargs: nn.ELU(**kwargs)) + self.activations_factory.register_builder('selu', lambda **kwargs: nn.SELU(**kwargs)) + self.activations_factory.register_builder('silu', lambda **kwargs: nn.SiLU(**kwargs)) + self.activations_factory.register_builder('gelu', lambda **kwargs: nn.GELU(**kwargs)) + self.activations_factory.register_builder('softplus', lambda **kwargs: nn.Softplus(**kwargs)) + self.activations_factory.register_builder('None', lambda **kwargs: nn.Identity()) + + self.init_factory = object_factory.ObjectFactory() + #self.init_factory.register_builder('normc_initializer', lambda **kwargs : normc_initializer(**kwargs)) + self.init_factory.register_builder('const_initializer', lambda **kwargs: _create_initializer(nn.init.constant_, **kwargs)) + self.init_factory.register_builder('orthogonal_initializer', lambda **kwargs: _create_initializer(nn.init.orthogonal_, **kwargs)) + self.init_factory.register_builder('glorot_normal_initializer', lambda **kwargs: _create_initializer(nn.init.xavier_normal_, **kwargs)) + self.init_factory.register_builder('glorot_uniform_initializer', lambda **kwargs: _create_initializer(nn.init.xavier_uniform_, **kwargs)) + self.init_factory.register_builder('variance_scaling_initializer', lambda **kwargs: _create_initializer(torch_ext.variance_scaling_initializer, **kwargs)) + self.init_factory.register_builder('random_uniform_initializer', lambda **kwargs: _create_initializer(nn.init.uniform_, **kwargs)) + self.init_factory.register_builder('kaiming_normal', lambda **kwargs: _create_initializer(nn.init.kaiming_normal_, **kwargs)) + self.init_factory.register_builder('orthogonal', lambda **kwargs: _create_initializer(nn.init.orthogonal_, **kwargs)) + self.init_factory.register_builder('default', lambda **kwargs: nn.Identity()) + + def is_separate_critic(self): + return False + + def is_rnn(self): + return False + + def get_default_rnn_state(self): + return None + + def _calc_input_size(self, input_shape, cnn_layers=None): + if cnn_layers is None: + assert (len(input_shape) == 1) + return input_shape[0] + else: + return nn.Sequential(*cnn_layers)(torch.rand(1, *(input_shape))).flatten(1).data.size(1) + + def _noisy_dense(self, inputs, units): + return layers.NoisyFactorizedLinear(inputs, units) + + def _build_rnn(self, name, input, units, layers): + if name == 'identity': + return torch_ext.IdentityRNN(input, units) + if name == 'lstm': + return torch.nn.LSTM(input, units, layers, batch_first=True) + if name == 'gru': + return torch.nn.GRU(input, units, layers, batch_first=True) + if name == 'sru': + from sru import SRU + return SRU(input, units, layers, dropout=0, layer_norm=False) + + def _build_res_mlp(self, input_size, units, activation, dense_func, norm_only_first_layer=False, norm_func_name=None): + print('build mlp:', input_size) + in_size = input_size + layers = [] + need_norm = True + for unit in units: + layers.append(dense_func(in_size, unit)) + layers.append(self.activations_factory.create(activation)) + + if not need_norm: + continue + if norm_only_first_layer and norm_func_name is not None: + need_norm = False + if norm_func_name == 'layer_norm': + layers.append(torch.nn.LayerNorm(unit)) + elif norm_func_name == 'batch_norm': + layers.append(torch.nn.BatchNorm1d(unit)) + in_size = unit + + return nn.Sequential(*layers) + + def _build_mlp(self, input_size, units, activation, dense_func, norm_only_first_layer=False, norm_func_name=None, d2rl=False): + if d2rl: + act_layers = [self.activations_factory.create(activation) for i in range(len(units))] + return D2RLNet(input_size, units, act_layers, norm_func_name) + else: + return self._build_res_mlp( + input_size, + units, + activation, + dense_func, + norm_func_name=norm_func_name, + ) + + def _build_conv(self, ctype, **kwargs): + print('conv_name:', ctype) + + if ctype == 'conv2d': + return self._build_cnn2d(**kwargs) + if ctype == 'coord_conv2d': + return self._build_cnn2d(conv_func=torch_ext.CoordConv2d, **kwargs) + if ctype == 'conv1d': + return self._build_cnn1d(**kwargs) + + def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None): + in_channels = input_shape[0] + layers = [] + for conv in convs: + layers.append(conv_func(in_channels=in_channels, out_channels=conv['filters'], kernel_size=conv['kernel_size'], stride=conv['strides'], padding=conv['padding'])) + conv_func = torch.nn.Conv2d + act = self.activations_factory.create(activation) + layers.append(act) + in_channels = conv['filters'] + if norm_func_name == 'layer_norm': + layers.append(torch_ext.LayerNorm2d(in_channels)) + elif norm_func_name == 'batch_norm': + layers.append(torch.nn.BatchNorm2d(in_channels)) + return nn.Sequential(*layers) + + def _build_cnn1d(self, input_shape, convs, activation, norm_func_name=None): + print('conv1d input shape:', input_shape) + in_channels = input_shape[0] + layers = [] + for conv in convs: + layers.append(torch.nn.Conv1d(in_channels, conv['filters'], conv['kernel_size'], conv['strides'], conv['padding'])) + act = self.activations_factory.create(activation) + layers.append(act) + in_channels = conv['filters'] + if norm_func_name == 'layer_norm': + layers.append(torch.nn.LayerNorm(in_channels)) + elif norm_func_name == 'batch_norm': + layers.append(torch.nn.BatchNorm2d(in_channels)) + return nn.Sequential(*layers) + + +class A2CBuilder(NetworkBuilder): + + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.params = params + + class Network(NetworkBuilder.BaseNetwork): + + def __init__(self, params, **kwargs): + actions_num = kwargs.pop('actions_num') + input_shape = kwargs.pop('input_shape') + self.value_size = kwargs.pop('value_size', 1) + self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1) + NetworkBuilder.BaseNetwork.__init__(self) + self.load(params) + self.actor_cnn = nn.Sequential() + self.critic_cnn = nn.Sequential() + self.actor_mlp = nn.Sequential() + self.critic_mlp = nn.Sequential() + + if self.has_cnn: + input_shape = torch_ext.shape_whc_to_cwh(input_shape) + cnn_args = { + 'ctype': self.cnn['type'], + 'input_shape': input_shape, + 'convs': self.cnn['convs'], + 'activation': self.cnn['activation'], + 'norm_func_name': self.normalization, + } + self.actor_cnn = self._build_conv(**cnn_args) + + if self.separate: + self.critic_cnn = self._build_conv(**cnn_args) + + mlp_input_shape = self._calc_input_size(input_shape, self.actor_cnn) + + in_mlp_shape = mlp_input_shape + if len(self.units) == 0: + out_size = mlp_input_shape + else: + out_size = self.units[-1] + + if self.has_rnn: + if not self.is_rnn_before_mlp: + rnn_in_size = out_size + out_size = self.rnn_units + if self.rnn_concat_input: + rnn_in_size += in_mlp_shape + else: + rnn_in_size = in_mlp_shape + in_mlp_shape = self.rnn_units + + if self.separate: + self.a_rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) + self.c_rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) + if self.rnn_ln: + self.a_layer_norm = torch.nn.LayerNorm(self.rnn_units) + self.c_layer_norm = torch.nn.LayerNorm(self.rnn_units) + else: + self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) + if self.rnn_ln: + self.layer_norm = torch.nn.LayerNorm(self.rnn_units) + + mlp_args = {'input_size': in_mlp_shape, 'units': self.units, 'activation': self.activation, 'norm_func_name': self.normalization, 'dense_func': torch.nn.Linear, 'd2rl': self.is_d2rl, 'norm_only_first_layer': self.norm_only_first_layer} + self.actor_mlp = self._build_mlp(**mlp_args) + if self.separate: + self.critic_mlp = self._build_mlp(**mlp_args) + + self.value = torch.nn.Linear(out_size, self.value_size) + self.value_act = self.activations_factory.create(self.value_activation) + + if self.is_discrete: + self.logits = torch.nn.Linear(out_size, actions_num) + ''' + for multidiscrete actions num is a tuple + ''' + if self.is_multi_discrete: + self.logits = torch.nn.ModuleList([torch.nn.Linear(out_size, num) for num in actions_num]) + if self.is_continuous: + self.mu = torch.nn.Linear(out_size, actions_num) + self.mu_act = self.activations_factory.create(self.space_config['mu_activation']) + mu_init = self.init_factory.create(**self.space_config['mu_init']) + self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) + sigma_init = self.init_factory.create(**self.space_config['sigma_init']) + + if self.space_config['fixed_sigma']: + self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True) + else: + self.sigma = torch.nn.Linear(out_size, actions_num) + + mlp_init = self.init_factory.create(**self.initializer) + if self.has_cnn: + cnn_init = self.init_factory.create(**self.cnn['initializer']) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): + cnn_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + if isinstance(m, nn.Linear): + mlp_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + + if self.is_continuous: + mu_init(self.mu.weight) + if self.space_config['fixed_sigma']: + sigma_init(self.sigma) + else: + sigma_init(self.sigma.weight) + + def forward(self, obs_dict): + obs = obs_dict['obs'] + states = obs_dict.get('rnn_states', None) + seq_length = obs_dict.get('seq_length', 1) + if self.has_cnn: + # for obs shape 4 + # input expected shape (B, W, H, C) + # convert to (B, C, W, H) + if len(obs.shape) == 4: + obs = obs.permute((0, 3, 1, 2)) + + if self.separate: + a_out = c_out = obs + a_out = self.actor_cnn(a_out) + a_out = a_out.contiguous().view(a_out.size(0), -1) + + c_out = self.critic_cnn(c_out) + c_out = c_out.contiguous().view(c_out.size(0), -1) + + if self.has_rnn: + if not self.is_rnn_before_mlp: + a_out_in = a_out + c_out_in = c_out + a_out = self.actor_mlp(a_out_in) + c_out = self.critic_mlp(c_out_in) + + if self.rnn_concat_input: + a_out = torch.cat([a_out, a_out_in], dim=1) + c_out = torch.cat([c_out, c_out_in], dim=1) + + batch_size = a_out.size()[0] + num_seqs = batch_size // seq_length + a_out = a_out.reshape(num_seqs, seq_length, -1) + c_out = c_out.reshape(num_seqs, seq_length, -1) + + if self.rnn_name == 'sru': + a_out = a_out.transpose(0, 1) + c_out = c_out.transpose(0, 1) + + if len(states) == 2: + a_states = states[0] + c_states = states[1] + else: + a_states = states[:2] + c_states = states[2:] + a_out, a_states = self.a_rnn(a_out, a_states) + c_out, c_states = self.c_rnn(c_out, c_states) + + if self.rnn_name == 'sru': + a_out = a_out.transpose(0, 1) + c_out = c_out.transpose(0, 1) + else: + if self.rnn_ln: + a_out = self.a_layer_norm(a_out) + c_out = self.c_layer_norm(c_out) + a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1) + c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1) + + if type(a_states) is not tuple: + a_states = (a_states,) + c_states = (c_states,) + states = a_states + c_states + + if self.is_rnn_before_mlp: + a_out = self.actor_mlp(a_out) + c_out = self.critic_mlp(c_out) + else: + a_out = self.actor_mlp(a_out) + c_out = self.critic_mlp(c_out) + + value = self.value_act(self.value(c_out)) + + if self.is_discrete: + logits = self.logits(a_out) + return logits, value, states + + if self.is_multi_discrete: + logits = [logit(a_out) for logit in self.logits] + return logits, value, states + + if self.is_continuous: + mu = self.mu_act(self.mu(a_out)) + if self.space_config['fixed_sigma']: + sigma = mu * 0.0 + self.sigma_act(self.sigma) + else: + sigma = self.sigma_act(self.sigma(a_out)) + + return mu, sigma, value, states + else: + out = obs + out = self.actor_cnn(out) + out = out.flatten(1) + + if self.has_rnn: + out_in = out + if not self.is_rnn_before_mlp: + out_in = out + out = self.actor_mlp(out) + if self.rnn_concat_input: + out = torch.cat([out, out_in], dim=1) + + batch_size = out.size()[0] + num_seqs = batch_size // seq_length + out = out.reshape(num_seqs, seq_length, -1) + + if len(states) == 1: + states = states[0] + + if self.rnn_name == 'sru': + out = out.transpose(0, 1) + + out, states = self.rnn(out, states) + out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1) + + if self.rnn_name == 'sru': + out = out.transpose(0, 1) + if self.rnn_ln: + out = self.layer_norm(out) + if self.is_rnn_before_mlp: + out = self.actor_mlp(out) + if type(states) is not tuple: + states = (states,) + else: + out = self.actor_mlp(out) + value = self.value_act(self.value(out)) + + if self.central_value: + return value, states + + if self.is_discrete: + logits = self.logits(out) + return logits, value, states + if self.is_multi_discrete: + logits = [logit(out) for logit in self.logits] + return logits, value, states + if self.is_continuous: + mu = self.mu_act(self.mu(out)) + if self.space_config['fixed_sigma']: + sigma = self.sigma_act(self.sigma) + else: + sigma = self.sigma_act(self.sigma(out)) + return mu, mu * 0 + sigma, value, states + + def is_separate_critic(self): + return self.separate + + def is_rnn(self): + return self.has_rnn + + def get_default_rnn_state(self): + if not self.has_rnn: + return None + num_layers = self.rnn_layers + if self.rnn_name == 'identity': + rnn_units = 1 + else: + rnn_units = self.rnn_units + if self.rnn_name == 'lstm': + if self.separate: + return (torch.zeros((num_layers, self.num_seqs, rnn_units)), torch.zeros((num_layers, self.num_seqs, rnn_units)), torch.zeros((num_layers, self.num_seqs, rnn_units)), torch.zeros((num_layers, self.num_seqs, rnn_units))) + else: + return (torch.zeros((num_layers, self.num_seqs, rnn_units)), torch.zeros((num_layers, self.num_seqs, rnn_units))) + else: + if self.separate: + return (torch.zeros((num_layers, self.num_seqs, rnn_units)), torch.zeros((num_layers, self.num_seqs, rnn_units))) + else: + return (torch.zeros((num_layers, self.num_seqs, rnn_units)),) + + def load(self, params): + self.separate = params.get('separate', False) + self.units = params['mlp']['units'] + self.activation = params['mlp']['activation'] + self.initializer = params['mlp']['initializer'] + self.is_d2rl = params['mlp'].get('d2rl', False) + self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False) + self.value_activation = params.get('value_activation', 'None') + self.normalization = params.get('normalization', None) + self.has_rnn = 'rnn' in params + self.has_space = 'space' in params + self.central_value = params.get('central_value', False) + self.joint_obs_actions_config = params.get('joint_obs_actions', None) + + if self.has_space: + self.is_multi_discrete = 'multi_discrete' in params['space'] + self.is_discrete = 'discrete' in params['space'] + self.is_continuous = 'continuous' in params['space'] + if self.is_continuous: + self.space_config = params['space']['continuous'] + elif self.is_discrete: + self.space_config = params['space']['discrete'] + elif self.is_multi_discrete: + self.space_config = params['space']['multi_discrete'] + else: + self.is_discrete = False + self.is_continuous = False + self.is_multi_discrete = False + + if self.has_rnn: + self.rnn_units = params['rnn']['units'] + self.rnn_layers = params['rnn']['layers'] + self.rnn_name = params['rnn']['name'] + self.rnn_ln = params['rnn'].get('layer_norm', False) + self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False) + self.rnn_concat_input = params['rnn'].get('concat_input', False) + + if 'cnn' in params: + self.has_cnn = True + self.cnn = params['cnn'] + else: + self.has_cnn = False + + def build(self, name, **kwargs): + net = A2CBuilder.Network(self.params, **kwargs) + return net + + +class Conv2dAuto(nn.Conv2d): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) # dynamic add padding based on the kernel_size + + +class ConvBlock(nn.Module): + + def __init__(self, in_channels, out_channels, use_bn=False): + super().__init__() + self.use_bn = use_bn + self.conv = Conv2dAuto(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, bias=not use_bn) + if use_bn: + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + if self.use_bn: + x = self.bn(x) + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, channels, activation='relu', use_bn=False, use_zero_init=True, use_attention=False): + super().__init__() + self.use_zero_init = use_zero_init + self.use_attention = use_attention + if use_zero_init: + self.alpha = nn.Parameter(torch.zeros(1)) + self.activation = activation + self.conv1 = ConvBlock(channels, channels, use_bn) + self.conv2 = ConvBlock(channels, channels, use_bn) + self.activate1 = nn.ELU() + self.activate2 = nn.ELU() + if use_attention: + self.ca = ChannelAttention(channels) + self.sa = SpatialAttention() + + def forward(self, x): + residual = x + x = self.activate1(x) + x = self.conv1(x) + x = self.activate2(x) + x = self.conv2(x) + if self.use_attention: + x = self.ca(x) * x + x = self.sa(x) * x + if self.use_zero_init: + x = x * self.alpha + residual + else: + x = x + residual + return x + + +class ImpalaSequential(nn.Module): + + def __init__(self, in_channels, out_channels, activation='elu', use_bn=True, use_zero_init=False): + super().__init__() + self.conv = ConvBlock(in_channels, out_channels, use_bn) + self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.res_block1 = ResidualBlock(out_channels, activation=activation, use_bn=use_bn, use_zero_init=use_zero_init) + self.res_block2 = ResidualBlock(out_channels, activation=activation, use_bn=use_bn, use_zero_init=use_zero_init) + + def forward(self, x): + x = self.conv(x) + x = self.max_pool(x) + x = self.res_block1(x) + x = self.res_block2(x) + return x + + +class A2CResnetBuilder(NetworkBuilder): + + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.params = params + + class Network(NetworkBuilder.BaseNetwork): + + def __init__(self, params, **kwargs): + actions_num = kwargs.pop('actions_num') + input_shape = kwargs.pop('input_shape') + input_shape = torch_ext.shape_whc_to_cwh(input_shape) + self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1) + self.value_size = kwargs.pop('value_size', 1) + + NetworkBuilder.BaseNetwork.__init__(self, **kwargs) + self.load(params) + + self.cnn = self._build_impala(input_shape, self.conv_depths) + mlp_input_shape = self._calc_input_size(input_shape, self.cnn) + + in_mlp_shape = mlp_input_shape + + if len(self.units) == 0: + out_size = mlp_input_shape + else: + out_size = self.units[-1] + + if self.has_rnn: + if not self.is_rnn_before_mlp: + rnn_in_size = out_size + out_size = self.rnn_units + else: + rnn_in_size = in_mlp_shape + in_mlp_shape = self.rnn_units + self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) + #self.layer_norm = torch.nn.LayerNorm(self.rnn_units) + + mlp_args = {'input_size': in_mlp_shape, 'units': self.units, 'activation': self.activation, 'norm_func_name': self.normalization, 'dense_func': torch.nn.Linear} + + self.mlp = self._build_mlp(**mlp_args) + + self.value = torch.nn.Linear(out_size, self.value_size) + self.value_act = self.activations_factory.create(self.value_activation) + self.flatten_act = self.activations_factory.create(self.activation) + if self.is_discrete: + self.logits = torch.nn.Linear(out_size, actions_num) + if self.is_continuous: + self.mu = torch.nn.Linear(out_size, actions_num) + self.mu_act = self.activations_factory.create(self.space_config['mu_activation']) + mu_init = self.init_factory.create(**self.space_config['mu_init']) + self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) + sigma_init = self.init_factory.create(**self.space_config['sigma_init']) + + if self.space_config['fixed_sigma']: + self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True) + else: + self.sigma = torch.nn.Linear(out_size, actions_num) + + mlp_init = self.init_factory.create(**self.initializer) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + #nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('elu')) + for m in self.mlp: + if isinstance(m, nn.Linear): + mlp_init(m.weight) + + if self.is_discrete: + mlp_init(self.logits.weight) + if self.is_continuous: + mu_init(self.mu.weight) + if self.space_config['fixed_sigma']: + sigma_init(self.sigma) + else: + sigma_init(self.sigma.weight) + + mlp_init(self.value.weight) + + def forward(self, obs_dict): + obs = obs_dict['obs'] + obs = obs.permute((0, 3, 1, 2)) + states = obs_dict.get('rnn_states', None) + seq_length = obs_dict.get('seq_length', 1) + out = obs + out = self.cnn(out) + out = out.flatten(1) + out = self.flatten_act(out) + + if self.has_rnn: + if not self.is_rnn_before_mlp: + out = self.mlp(out) + + batch_size = out.size()[0] + num_seqs = batch_size // seq_length + out = out.reshape(num_seqs, seq_length, -1) + if len(states) == 1: + states = states[0] + out, states = self.rnn(out, states) + out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1) + #out = self.layer_norm(out) + if type(states) is not tuple: + states = (states,) + + if self.is_rnn_before_mlp: + for l in self.mlp: + out = l(out) + else: + for l in self.mlp: + out = l(out) + + value = self.value_act(self.value(out)) + + if self.is_discrete: + logits = self.logits(out) + return logits, value, states + + if self.is_continuous: + mu = self.mu_act(self.mu(out)) + if self.space_config['fixed_sigma']: + sigma = self.sigma_act(self.sigma) + else: + sigma = self.sigma_act(self.sigma(out)) + return mu, mu * 0 + sigma, value, states + + def load(self, params): + self.separate = params['separate'] + self.units = params['mlp']['units'] + self.activation = params['mlp']['activation'] + self.initializer = params['mlp']['initializer'] + self.is_discrete = 'discrete' in params['space'] + self.is_continuous = 'continuous' in params['space'] + self.is_multi_discrete = 'multi_discrete' in params['space'] + self.value_activation = params.get('value_activation', 'None') + self.normalization = params.get('normalization', None) + if self.is_continuous: + self.space_config = params['space']['continuous'] + elif self.is_discrete: + self.space_config = params['space']['discrete'] + elif self.is_multi_discrete: + self.space_config = params['space']['multi_discrete'] + self.has_rnn = 'rnn' in params + if self.has_rnn: + self.rnn_units = params['rnn']['units'] + self.rnn_layers = params['rnn']['layers'] + self.rnn_name = params['rnn']['name'] + self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False) + + self.has_cnn = True + self.conv_depths = params['cnn']['conv_depths'] + + def _build_impala(self, input_shape, depths): + in_channels = input_shape[0] + layers = nn.ModuleList() + for d in depths: + layers.append(ImpalaSequential(in_channels, d)) + in_channels = d + return nn.Sequential(*layers) + + def is_separate_critic(self): + return False + + def is_rnn(self): + return self.has_rnn + + def get_default_rnn_state(self): + num_layers = self.rnn_layers + if self.rnn_name == 'lstm': + return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)), torch.zeros((num_layers, self.num_seqs, self.rnn_units))) + else: + return (torch.zeros((num_layers, self.num_seqs, self.rnn_units))) + + def build(self, name, **kwargs): + net = A2CResnetBuilder.Network(self.params, **kwargs) + return net + + +class DiagGaussianActor(NetworkBuilder.BaseNetwork): + """torch.distributions implementation of an diagonal Gaussian policy.""" + + def __init__(self, output_dim, log_std_bounds, **mlp_args): + super().__init__() + + self.log_std_bounds = log_std_bounds + + self.trunk = self._build_mlp(**mlp_args) + last_layer = list(self.trunk.children())[-2].out_features + self.trunk = nn.Sequential(*list(self.trunk.children()), nn.Linear(last_layer, output_dim)) + + def forward(self, obs): + mu, log_std = self.trunk(obs).chunk(2, dim=-1) + + # constrain log_std inside [log_std_min, log_std_max] + #log_std = torch.tanh(log_std) + log_std_min, log_std_max = self.log_std_bounds + log_std = torch.clamp(log_std, log_std_min, log_std_max) + #log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1) + + std = log_std.exp() + + # TODO: Refactor + + dist = SquashedNormal(mu, std) + # Modify to only return mu and std + return dist + + +class DoubleQCritic(NetworkBuilder.BaseNetwork): + """Critic network, employes double Q-learning.""" + + def __init__(self, output_dim, **mlp_args): + super().__init__() + + self.Q1 = self._build_mlp(**mlp_args) + last_layer = list(self.Q1.children())[-2].out_features + self.Q1 = nn.Sequential(*list(self.Q1.children()), nn.Linear(last_layer, output_dim)) + + self.Q2 = self._build_mlp(**mlp_args) + last_layer = list(self.Q2.children())[-2].out_features + self.Q2 = nn.Sequential(*list(self.Q2.children()), nn.Linear(last_layer, output_dim)) + + def forward(self, obs, action): + assert obs.size(0) == action.size(0) + + obs_action = torch.cat([obs, action], dim=-1) + q1 = self.Q1(obs_action) + q2 = self.Q2(obs_action) + + return q1, q2 + + +class SACBuilder(NetworkBuilder): + + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.params = params + + def build(self, name, **kwargs): + net = SACBuilder.Network(self.params, **kwargs) + return net + + class Network(NetworkBuilder.BaseNetwork): + + def __init__(self, params, **kwargs): + actions_num = kwargs.pop('actions_num') + input_shape = kwargs.pop('input_shape') + obs_dim = kwargs.pop('obs_dim') + action_dim = kwargs.pop('action_dim') + self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1) + NetworkBuilder.BaseNetwork.__init__(self) + self.load(params) + + mlp_input_shape = input_shape + + actor_mlp_args = {'input_size': obs_dim, 'units': self.units, 'activation': self.activation, 'norm_func_name': self.normalization, 'dense_func': torch.nn.Linear, 'd2rl': self.is_d2rl, 'norm_only_first_layer': self.norm_only_first_layer} + + critic_mlp_args = {'input_size': obs_dim + action_dim, 'units': self.units, 'activation': self.activation, 'norm_func_name': self.normalization, 'dense_func': torch.nn.Linear, 'd2rl': self.is_d2rl, 'norm_only_first_layer': self.norm_only_first_layer} + print("Building Actor") + self.actor = self._build_actor(2 * action_dim, self.log_std_bounds, **actor_mlp_args) + + if self.separate: + print("Building Critic") + self.critic = self._build_critic(1, **critic_mlp_args) + print("Building Critic Target") + self.critic_target = self._build_critic(1, **critic_mlp_args) + self.critic_target.load_state_dict(self.critic.state_dict()) + + mlp_init = self.init_factory.create(**self.initializer) + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): + cnn_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + if isinstance(m, nn.Linear): + mlp_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + + def _build_critic(self, output_dim, **mlp_args): + return DoubleQCritic(output_dim, **mlp_args) + + def _build_actor(self, output_dim, log_std_bounds, **mlp_args): + return DiagGaussianActor(output_dim, log_std_bounds, **mlp_args) + + def forward(self, obs_dict): + """TODO""" + obs = obs_dict['obs'] + mu, sigma = self.actor(obs) + return mu, sigma + + def is_separate_critic(self): + return self.separate + + def load(self, params): + self.separate = params.get('separate', True) + self.units = params['mlp']['units'] + self.activation = params['mlp']['activation'] + self.initializer = params['mlp']['initializer'] + self.is_d2rl = params['mlp'].get('d2rl', False) + self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False) + self.value_activation = params.get('value_activation', 'None') + self.normalization = params.get('normalization', None) + self.has_space = 'space' in params + self.value_shape = params.get('value_shape', 1) + self.central_value = params.get('central_value', False) + self.joint_obs_actions_config = params.get('joint_obs_actions', None) + self.log_std_bounds = params.get('log_std_bounds', None) + + if self.has_space: + self.is_discrete = 'discrete' in params['space'] + self.is_continuous = 'continuous' in params['space'] + if self.is_continuous: + self.space_config = params['space']['continuous'] + elif self.is_discrete: + self.space_config = params['space']['discrete'] + else: + self.is_discrete = False + self.is_continuous = False + + +''' +class DQNBuilder(NetworkBuilder): + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.units = params['mlp']['units'] + self.activation = params['mlp']['activation'] + self.initializer = params['mlp']['initializer'] + self.regularizer = params['mlp']['regularizer'] + self.is_dueling = params['dueling'] + self.atoms = params['atoms'] + self.is_noisy = params['noisy'] + self.normalization = params.get('normalization', None) + if 'cnn' in params: + self.has_cnn = True + self.cnn = params['cnn'] + else: + self.has_cnn = False + + def build(self, name, **kwargs): + actions_num = kwargs.pop('actions_num') + input = kwargs.pop('inputs') + reuse = kwargs.pop('reuse') + is_train = kwargs.pop('is_train', True) + if self.is_noisy: + dense_layer = self._noisy_dense + else: + dense_layer = torch.nn.Linear + with tf.variable_scope(name, reuse=reuse): + out = input + if self.has_cnn: + cnn_args = { + 'name' :'dqn_cnn', + 'ctype' : self.cnn['type'], + 'input' : input, + 'convs' :self.cnn['convs'], + 'activation' : self.cnn['activation'], + 'initializer' : self.cnn['initializer'], + 'regularizer' : self.cnn['regularizer'], + 'norm_func_name' : self.normalization, + 'is_train' : is_train + } + out = self._build_conv(**cnn_args) + out = tf.contrib.layers.flatten(out) + + mlp_args = { + 'name' :'dqn_mlp', + 'input' : out, + 'activation' : self.activation, + 'initializer' : self.initializer, + 'regularizer' : self.regularizer, + 'norm_func_name' : self.normalization, + 'is_train' : is_train, + 'dense_func' : dense_layer + } + if self.is_dueling: + if len(self.units) > 1: + mlp_args['units'] = self.units[:-1] + out = self._build_mlp(**mlp_args) + hidden_value = dense_layer(inputs=out, units=self.units[-1], kernel_initializer = self.init_factory.create(**self.initializer), activation=self.activations_factory.create(self.activation), kernel_regularizer = self.regularizer_factory.create(**self.regularizer), name='hidden_val') + hidden_advantage = dense_layer(inputs=out, units=self.units[-1], kernel_initializer = self.init_factory.create(**self.initializer), activation=self.activations_factory.create(self.activation), kernel_regularizer = self.regularizer_factory.create(**self.regularizer), name='hidden_adv') + + value = dense_layer(inputs=hidden_value, units=self.atoms, kernel_initializer = self.init_factory.create(**self.initializer), activation=tf.identity, kernel_regularizer = self.regularizer_factory.create(**self.regularizer), name='value') + advantage = dense_layer(inputs=hidden_advantage, units= actions_num * self.atoms, kernel_initializer = self.init_factory.create(**self.initializer), kernel_regularizer = self.regularizer_factory.create(**self.regularizer), activation=tf.identity, name='advantage') + advantage = tf.reshape(advantage, shape = [-1, actions_num, self.atoms]) + value = tf.reshape(value, shape = [-1, 1, self.atoms]) + q_values = value + advantage - tf.reduce_mean(advantage, reduction_indices=1, keepdims=True) + else: + mlp_args['units'] = self.units + out = self._build_mlp('dqn_mlp', out, self.units, self.activation, self.initializer, self.regularizer) + q_values = dense_layer(inputs=out, units=actions_num *self.atoms, kernel_initializer = self.init_factory.create(**self.initializer), kernel_regularizer = self.regularizer_factory.create(**self.regularizer), activation=tf.identity, name='q_vals') + q_values = tf.reshape(q_values, shape = [-1, actions_num, self.atoms]) + + if self.atoms == 1: + return tf.squeeze(q_values) + else: + return q_values + +''' diff --git a/phc/learning/network_loader.py b/phc/learning/network_loader.py new file mode 100644 index 0000000..96ef13a --- /dev/null +++ b/phc/learning/network_loader.py @@ -0,0 +1,176 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from phc.utils import torch_utils + +from easydict import EasyDict as edict +from phc.learning.vq_quantizer import EMAVectorQuantizer, Quantizer +from phc.learning.pnn import PNN + +def load_mcp_mlp(checkpoint, activation = "relu", device = "cpu", mlp_name = "actor_mlp"): + actvation_func = torch_utils.activation_facotry(activation) + key_name = f"a2c_network.{mlp_name}" + + loading_keys = [k for k in checkpoint['model'].keys() if k.startswith(key_name)] + if not mlp_name == "composer": + loading_keys += ["a2c_network.mu.weight", 'a2c_network.mu.bias'] + + loading_keys_linear = [k for k in loading_keys if k.endswith('weight')] + + nn_modules = [] + + for idx, key in enumerate(loading_keys_linear): + if len(checkpoint['model'][key].shape) == 1: # layernorm + layer = torch.nn.LayerNorm(*checkpoint['model'][key].shape[::-1]) + nn_modules.append(layer) + elif len(checkpoint['model'][key].shape) == 2: # nn + layer = nn.Linear(*checkpoint['model'][key].shape[::-1]) + nn_modules.append(layer) + if idx < len(loading_keys_linear) - 1: + nn_modules.append(actvation_func()) + else: + raise NotImplementedError + + mlp = nn.Sequential(*nn_modules) + + if mlp_name == "composer": + # ZL: shouldn't really have this here, but it's a quick fix for now. + mlp.append(actvation_func()) + + state_dict = mlp.state_dict() + + for idx, key_affix in enumerate(state_dict.keys()): + state_dict[key_affix].copy_(checkpoint['model'][loading_keys[idx]]) + + for param in mlp.parameters(): + param.requires_grad = False + + mlp.to(device) + mlp.eval() + + return mlp + +def load_pnn(checkpoint, num_prim, has_lateral, activation = "relu", device = "cpu"): + state_dict_load = checkpoint['model'] + + net_key_name = "a2c_network.pnn.actors.0" + loading_keys = [k for k in checkpoint['model'].keys() if k.startswith(net_key_name) and k.endswith('bias')] + layer_size = [] + for idx, key in enumerate(loading_keys): + layer_size.append(checkpoint['model'][key].shape[::-1][0]) + + mlp_args = {'input_size': state_dict_load['a2c_network.pnn.actors.0.0.weight'].shape[1], 'units':layer_size[:-1], 'activation': activation, 'dense_func': torch.nn.Linear} + pnn = PNN(mlp_args, output_size=checkpoint['model']['a2c_network.mu.bias'].shape[0], numCols=num_prim, has_lateral=has_lateral) + state_dict = pnn.state_dict() + for k in state_dict_load.keys(): + if "pnn" in k: + pnn_dict_key = k.split("pnn.")[1] + state_dict[pnn_dict_key].copy_(state_dict_load[k]) + + pnn.freeze_pnn(num_prim) + pnn.to(device) + return pnn + + +def load_z_encoder(checkpoint, activation = "relu", z_type = "sphere", device = "cpu"): + net_dict = edict() + + actvation_func = torch_utils.activation_facotry(activation) + if z_type == "sphere" or z_type == "uniform" or z_type == "vq_vae" or z_type == "vae": + net_key_name = "a2c_network._task_mlp" if "a2c_network._task_mlp.0.weight" in checkpoint['model'].keys() else "a2c_network.z_mlp" + elif z_type == "hyper": + net_key_name = "a2c_network.z_mlp" + else: + raise NotImplementedError + + loading_keys = [k for k in checkpoint['model'].keys() if k.startswith(net_key_name)] + actor = load_mlp(loading_keys, checkpoint, actvation_func) + + actor.to(device) + actor.eval() + + net_dict.encoder= actor + if "a2c_network.z_logvar.weight" in checkpoint['model'].keys(): + z_logvar = load_linear('a2c_network.z_logvar', checkpoint=checkpoint) + z_mu = load_linear('a2c_network.z_mu', checkpoint=checkpoint) + z_logvar.eval(); z_mu.eval() + net_dict.z_mu = z_mu.to(device) + net_dict.z_logvar = z_logvar.to(device) + + return net_dict + +def load_mlp(loading_keys, checkpoint, actvation_func): + + loading_keys_linear = [k for k in loading_keys if k.endswith('weight')] + nn_modules = [] + for idx, key in enumerate(loading_keys_linear): + if len(checkpoint['model'][key].shape) == 1: # layernorm + layer = torch.nn.LayerNorm(*checkpoint['model'][key].shape[::-1]) + nn_modules.append(layer) + elif len(checkpoint['model'][key].shape) == 2: # nn + layer = nn.Linear(*checkpoint['model'][key].shape[::-1]) + nn_modules.append(layer) + if idx < len(loading_keys_linear) - 1: + nn_modules.append(actvation_func()) + else: + raise NotImplementedError + + net = nn.Sequential(*nn_modules) + + state_dict = net.state_dict() + + for idx, key_affix in enumerate(state_dict.keys()): + state_dict[key_affix].copy_(checkpoint['model'][loading_keys[idx]]) + + for param in net.parameters(): + param.requires_grad = False + + return net + +def load_linear(net_name, checkpoint): + net = nn.Linear(checkpoint['model'][net_name + '.weight'].shape[1], checkpoint['model'][net_name + '.weight'].shape[0]) + state_dict = net.state_dict() + state_dict['weight'].copy_(checkpoint['model'][net_name + '.weight']) + state_dict['bias'].copy_(checkpoint['model'][net_name + '.bias']) + + return net + +def load_z_decoder(checkpoint, activation = "relu", z_type = "sphere", device = "cpu"): + actvation_func = torch_utils.activation_facotry(activation) + key_name = "a2c_network.actor_mlp" + loading_keys = [k for k in checkpoint['model'].keys() if k.startswith(key_name)] + ["a2c_network.mu.weight", 'a2c_network.mu.bias'] + + actor = load_mlp(loading_keys, checkpoint, actvation_func) + + actor.to(device) + actor.eval() + net_dict = edict() + + net_dict.decoder= actor + if z_type == "vq_vae": + quantizer_weights = checkpoint['model']['a2c_network.quantizer.embedding.weight'] + quantizer = Quantizer(quantizer_weights.shape[0], quantizer_weights.shape[1], beta = 0.25) + state_dict = quantizer.state_dict() + state_dict['embedding.weight'].copy_(quantizer_weights) + + quantizer.to(device) + quantizer.eval() + net_dict.quantizer = quantizer + + elif z_type == "vae" and "a2c_network.z_prior.0.weight" in checkpoint['model'].keys(): + prior_loading_keys = [k for k in checkpoint['model'].keys() if k.startswith("a2c_network.z_prior.")] + z_prior = load_mlp(prior_loading_keys, checkpoint, actvation_func) + z_prior.append(actvation_func()) + z_prior_mu = load_linear('a2c_network.z_prior_mu', checkpoint=checkpoint) + + z_prior.eval(); z_prior_mu.eval() + net_dict.z_prior = z_prior.to(device) + net_dict.z_prior_mu = z_prior_mu.to(device) + + if "a2c_network.z_prior_logvar.weight" in checkpoint['model'].keys(): + z_prior_logvar = load_linear('a2c_network.z_prior_logvar', checkpoint=checkpoint) + z_prior_logvar.eval() + net_dict.z_prior_logvar = z_prior_logvar.to(device) + + return net_dict \ No newline at end of file diff --git a/phc/learning/pnn.py b/phc/learning/pnn.py new file mode 100644 index 0000000..0e1808f --- /dev/null +++ b/phc/learning/pnn.py @@ -0,0 +1,131 @@ + + +import torch +import torch.nn as nn +from phc.learning.network_builder import NetworkBuilder +from collections import defaultdict +from rl_games.algos_torch import torch_ext +from tqdm import tqdm + + +class PNN(NetworkBuilder.BaseNetwork): + + def __init__(self, mlp_args, output_size=69, numCols=4, has_lateral=True): + super(PNN, self).__init__() + self.numCols = numCols + units = mlp_args['units'] + dense_func = mlp_args['dense_func'] + self.has_lateral = has_lateral + + self.actors = nn.ModuleList() + for i in range(numCols): + mlp = self._build_sequential_mlp(output_size, **mlp_args) + self.actors.append(mlp) + + if self.has_lateral: + + self.u = nn.ModuleList() + + for i in range(numCols - 1): + self.u.append(nn.ModuleList()) + for j in range(i + 1): + u = nn.Sequential() + in_size = units[0] + for unit in units[1:]: + u.append(dense_func(in_size, unit, bias=False)) + in_size = unit + u.append(dense_func(units[-1], output_size, bias=False)) + # torch.nn.init.zeros_(u[-1].weight) + self.u[i].append(u) + + def freeze_pnn(self, idx): + for param in self.actors[:idx].parameters(): + param.requires_grad = False + if self.has_lateral: + for param in self.u[:idx - 1].parameters(): + param.requires_grad = False + + def load_base_net(self, model_path, actors=1): + checkpoint = torch_ext.load_checkpoint(model_path) + for idx in range(actors): + self.load_actor(checkpoint, idx) + + def load_actor(self, checkpoint, idx=0): + state_dict = self.actors[idx].state_dict() + state_dict['0.weight'].copy_(checkpoint['model']['a2c_network.actor_mlp.0.weight']) + state_dict['0.bias'].copy_(checkpoint['model']['a2c_network.actor_mlp.0.bias']) + state_dict['2.weight'].copy_(checkpoint['model']['a2c_network.actor_mlp.2.weight']) + state_dict['2.bias'].copy_(checkpoint['model']['a2c_network.actor_mlp.2.bias']) + state_dict['4.weight'].copy_(checkpoint['model']['a2c_network.mu.weight']) + state_dict['4.bias'].copy_(checkpoint['model']['a2c_network.mu.bias']) + + def _build_sequential_mlp(self, actions_num, input_size, units, activation, dense_func, norm_only_first_layer=False, norm_func_name=None, need_norm = True): + print('build mlp:', input_size) + in_size = input_size + layers = [] + for unit in units: + layers.append(dense_func(in_size, unit)) + layers.append(self.activations_factory.create(activation)) + + if not need_norm: + continue + if norm_only_first_layer and norm_func_name is not None: + need_norm = False + if norm_func_name == 'layer_norm': + layers.append(torch.nn.LayerNorm(unit)) + elif norm_func_name == 'batch_norm': + layers.append(torch.nn.BatchNorm1d(unit)) + in_size = unit + + + layers.append(nn.Linear(units[-1], actions_num)) + return nn.Sequential(*layers) + + def forward(self, x, idx=-1): + if self.has_lateral: + # idx == -1: forward all, output all + # idx == others, forward till idx. + if idx == 0: + actions = self.actors[0](x) + return actions, [actions] + else: + if idx == -1: + idx = self.numCols - 1 + activation_cache = defaultdict(list) + + for curr_idx in range(0, idx + 1): + curr_actor = self.actors[curr_idx] + assert len(curr_actor) == 5 # Only support three MLPs right now + activation_1 = curr_actor[:2](x) + + acc_acts_1 = [self.u[curr_idx - 1][col_idx][0](activation_cache[0][col_idx]) for col_idx in range(len(activation_cache[0]))] # curr_idx - 1 as we need to go to the previous coloumn's index to activate the weight + activation_2 = curr_actor[3](curr_actor[2](activation_1) + sum(acc_acts_1)) # ReLU, full + + # acc_acts_2 = [self.u[curr_idx - 1][col_idx][1](activation_cache[1][col_idx]) for col_idx in range(len(activation_cache[1]))] + # actions = curr_actor[4](activation_2) + sum(acc_acts_2) + + actions = curr_actor[4](activation_2) # disable action space transfer. + + # acc_acts_1 = [] + # for col_idx in range(len(activation_cache[0])): + # acc_acts_1.append(self.u[curr_idx - 1][col_idx][0](activation_cache[0][col_idx])) + + # activation_2 = curr_actor[3](curr_actor[2](activation_1) + sum(acc_acts_1)) + + # acc_acts_2 = [] + # for col_idx in range(len(activation_cache[1])): + # acc_acts_2.append(self.u[curr_idx - 1][col_idx][1](activation_cache[1][col_idx])) + # actions = curr_actor[4](activation_2) + sum(acc_acts_2) + + activation_cache[0].append(activation_1) + activation_cache[1].append(activation_2) + activation_cache[2].append(actions) + + return actions, activation_cache[2] + else: + if idx != -1: + actions = self.actors[idx](x) + return actions, [actions] + else: + actions = [self.actors[idx](x) for idx in range(self.numCols)] + return actions, actions diff --git a/phc/learning/replay_buffer.py b/phc/learning/replay_buffer.py new file mode 100644 index 0000000..02cc957 --- /dev/null +++ b/phc/learning/replay_buffer.py @@ -0,0 +1,85 @@ +import torch + +class ReplayBuffer(): + def __init__(self, buffer_size, device): + self._head = 0 + self._total_count = 0 + self._buffer_size = buffer_size + self._device = device + self._data_buf = None + self._sample_idx = torch.randperm(buffer_size) + self._sample_head = 0 + + return + + def reset(self): + self._head = 0 + self._total_count = 0 + self._reset_sample_idx() + return + + def get_buffer_size(self): + return self._buffer_size + + def get_total_count(self): + return self._total_count + + def store(self, data_dict): + if (self._data_buf is None): + self._init_data_buf(data_dict) + + n = next(iter(data_dict.values())).shape[0] + buffer_size = self.get_buffer_size() + assert(n <= buffer_size) + + for key, curr_buf in self._data_buf.items(): + curr_n = data_dict[key].shape[0] + assert(n == curr_n) + + store_n = min(curr_n, buffer_size - self._head) + curr_buf[self._head:(self._head + store_n)] = data_dict[key][:store_n] + + remainder = n - store_n + if (remainder > 0): + curr_buf[0:remainder] = data_dict[key][store_n:] + + self._head = (self._head + n) % buffer_size + self._total_count += n + + return + + def sample(self, n): + total_count = self.get_total_count() + buffer_size = self.get_buffer_size() + + idx = torch.arange(self._sample_head, self._sample_head + n) + idx = idx % buffer_size + rand_idx = self._sample_idx[idx] + if (total_count < buffer_size): + rand_idx = rand_idx % self._head + + samples = dict() + for k, v in self._data_buf.items(): + samples[k] = v[rand_idx] + + self._sample_head += n + if (self._sample_head >= buffer_size): + self._reset_sample_idx() + + return samples + + def _reset_sample_idx(self): + buffer_size = self.get_buffer_size() + self._sample_idx[:] = torch.randperm(buffer_size) + self._sample_head = 0 + return + + def _init_data_buf(self, data_dict): + buffer_size = self.get_buffer_size() + self._data_buf = dict() + + for k, v in data_dict.items(): + v_shape = v.shape[1:] + self._data_buf[k] = torch.zeros((buffer_size,) + v_shape, device=self._device) + + return \ No newline at end of file diff --git a/phc/learning/running_norm.py b/phc/learning/running_norm.py new file mode 100644 index 0000000..6a62944 --- /dev/null +++ b/phc/learning/running_norm.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn + + +class RunningNorm(nn.Module): + """ + y = (x-mean)/std + using running estimates of mean,std + """ + + def __init__(self, dim, demean=True, destd=True, clip=5.0): + super().__init__() + self.dim = dim + self.demean = demean + self.destd = destd + self.clip = clip + self.register_buffer("n", torch.tensor(0, dtype=torch.long)) + self.register_buffer("mean", torch.zeros(dim)) + self.register_buffer("var", torch.zeros(dim)) + self.register_buffer("std", torch.zeros(dim)) + + def update(self, x): + var_x, mean_x = torch.var_mean(x, dim=0, unbiased=False) + m = x.shape[0] + w = self.n.to(x.dtype) / (m + self.n).to(x.dtype) + self.var[:] = ( + w * self.var + (1 - w) * var_x + w * (1 - w) * (mean_x - self.mean).pow(2) + ) + self.mean[:] = w * self.mean + (1 - w) * mean_x + self.std[:] = torch.sqrt(self.var) + self.n += m + + def forward(self, x): + if self.training: + with torch.no_grad(): + self.update(x) + if self.n > 0: + if self.demean: + x = x - self.mean + if self.destd: + x = x / (self.std + 1e-8) + if self.clip: + x = torch.clamp(x, -self.clip, self.clip) + return x diff --git a/phc/learning/transformer.py b/phc/learning/transformer.py new file mode 100644 index 0000000..18d9b0f --- /dev/null +++ b/phc/learning/transformer.py @@ -0,0 +1,314 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + + +# only for ablation / not used in the final model +class TimeEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(TimeEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, mask, lengths): + time = mask * 1 / (lengths[..., None] - 1) + time = time[:, None] * torch.arange(time.shape[1], + device=x.device)[None, :] + time = time[:, 0].T + # add the time encoding + x = x + time[..., None] + return self.dropout(x) + + +class Encoder_TRANSFORMER(nn.Module): + + def __init__(self, + modeltype, + njoints, + nfeats, + num_frames, + latent_dim=256, + ff_size=1024, + num_layers=4, + num_heads=8, + dropout=0.1, + ablation=None, + activation="gelu", + **kargs): + super().__init__() + + self.modeltype = modeltype + self.njoints = njoints + self.nfeats = nfeats + self.num_frames = num_frames + + self.latent_dim = latent_dim + + self.ff_size = ff_size + self.num_layers = num_layers + self.num_heads = num_heads + self.dropout = dropout + + self.ablation = ablation + self.activation = activation + + self.input_feats = self.njoints * self.nfeats + + self.mu_layer = nn.Linear(self.latent_dim, self.latent_dim) + self.sigma_layer = nn.Linear(self.latent_dim, self.latent_dim) + + + self.skelEmbedding = nn.Linear(self.input_feats, self.latent_dim) + + self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, + self.dropout) + + # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + + seqTransEncoderLayer = nn.TransformerEncoderLayer( + d_model=self.latent_dim, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation) + self.seqTransEncoder = nn.TransformerEncoder( + seqTransEncoderLayer, num_layers=self.num_layers) + + def forward(self, batch): + x, y, mask = batch["x"], batch["y"], batch["mask"] + bs, njoints, nfeats, nframes = x.shape + x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints * nfeats) + + # embedding of the skeleton + x = self.skelEmbedding(x) + + x = self.sequence_pos_encoder(x) + + # transformer layers + final = self.seqTransEncoder(x, src_key_padding_mask=~mask) + # get the average of the output + z = final.mean(axis=0) + + # extract mu and logvar + mu = self.mu_layer(z) + logvar = self.sigma_layer(z) + + + return {"mu": mu, "logvar": logvar} + + +class Decoder_TRANSFORMER(nn.Module): + + def __init__(self, + modeltype, + njoints, + nfeats, + num_frames, + latent_dim=256, + ff_size=1024, + num_layers=4, + num_heads=4, + dropout=0.1, + activation="gelu", + ablation=None, + **kargs): + super().__init__() + + self.modeltype = modeltype + self.njoints = njoints + self.nfeats = nfeats + self.num_frames = num_frames + + self.latent_dim = latent_dim + + self.ff_size = ff_size + self.num_layers = num_layers + self.num_heads = num_heads + self.dropout = dropout + + self.ablation = ablation + + self.activation = activation + + self.input_feats = self.njoints * self.nfeats + + # only for ablation / not used in the final model + if self.ablation == "time_encoding": + self.sequence_pos_encoder = TimeEncoding(self.dropout) + else: + self.sequence_pos_encoder = PositionalEncoding( + self.latent_dim, self.dropout) + + seqTransDecoderLayer = nn.TransformerDecoderLayer( + d_model=self.latent_dim, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=activation) + self.seqTransDecoder = nn.TransformerDecoder( + seqTransDecoderLayer, num_layers=self.num_layers) + + self.finallayer = nn.Linear(self.latent_dim, self.input_feats) + + def forward(self, batch): + z, y, mask, lengths = batch["z"], batch["y"], batch["mask"], batch[ + "lengths"] + + latent_dim = z.shape[1] + bs, nframes = mask.shape + njoints, nfeats = self.njoints, self.nfeats + + # only for ablation / not used in the final model + if self.ablation == "zandtime": + yoh = F.one_hot(y, self.num_classes) + z = torch.cat((z, yoh), axis=1) + z = self.ztimelinear(z) + z = z[None] # sequence of size 1 + else: + # only for ablation / not used in the final model + if self.ablation == "concat_bias": + # sequence of size 2 + z = torch.stack((z, self.actionBiases[y]), axis=0) + else: + # shift the latent noise vector to be the action noise + if self.ablation != "average_encoder": + z = z + self.actionBiases[y] + + z = z[None] # sequence of size 1 + + timequeries = torch.zeros(nframes, bs, latent_dim, device=z.device) + + # only for ablation / not used in the final model + if self.ablation == "time_encoding": + timequeries = self.sequence_pos_encoder(timequeries, mask, lengths) + else: + timequeries = self.sequence_pos_encoder(timequeries) + + output = self.seqTransDecoder(tgt=timequeries, + memory=z, + tgt_key_padding_mask=~mask) + + output = self.finallayer(output).reshape(nframes, bs, njoints, nfeats) + + # zero for padded area + output[~mask.T] = 0 + output = output.permute(1, 2, 3, 0) + + batch["output"] = output + return batch + +def PE1d_sincos(seq_length, dim): + """ + :param d_model: dimension of the model + :param length: length of positions + :return: length*d_model position matrix + """ + if dim % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(dim)) + pe = torch.zeros(seq_length, dim) + position = torch.arange(0, seq_length).unsqueeze(1) + div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * + -(math.log(10000.0) / dim))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + + return pe.unsqueeze(1) + + +class PositionEmbedding(nn.Module): + """ + Absolute pos embedding (standard), learned. + """ + def __init__(self, seq_length, dim, dropout, grad=False): + super().__init__() + self.embed = nn.Parameter(data=PE1d_sincos(seq_length, dim), requires_grad=grad) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x): + # x.shape: bs, seq_len, feat_dim + l = x.shape[1] + x = x.permute(1, 0, 2) + self.embed[:l].expand(x.permute(1, 0, 2).shape) + x = self.dropout(x.permute(1, 0, 2)) + return x + +class CausalAttention(nn.Module): + def __init__(self, dim, heads): + super().__init__() + self.heads = heads + self.scale = dim ** -0.5 + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim * 2, bias=False) + self.to_out = nn.Linear(dim, dim) + # self.attn_drop = nn.Dropout(0.1) + # self.resid_drop = nn.Dropout(0.1) + + def forward(self, x, mask=None, tgt_mask=None): + b, n, _, h = *x.shape, self.heads + q = self.to_q(x).reshape(b, n, h, -1).transpose(1, 2) + kv = self.to_kv(x).reshape(b, n, 2, h, -1).transpose(2, 3) + + k, v = kv[..., 0, :], kv[...,1, :] + k = k.transpose(1, 2) + v = v.transpose(1, 2) + dots = (q @ k.transpose(-2, -1)) * self.scale + + if mask is not None: + mask = mask[None, None, :, :].float() + dots.masked_fill_(mask==0, float('-inf')) + + # if tgt_mask is not None: + # tgt_mask = tgt_mask[:, None, :, :].float() + # tgt_mask = tgt_mask.transpose(2, 3) * tgt_mask + # dots.masked_fill_(tgt_mask==1, float('-inf')) + + + attn = dots.softmax(dim=-1) + #attn = self.attn_drop(attn) + out = attn @ v + out = out.transpose(1, 2).reshape(b, n, -1) + out = self.to_out(out) + #out = self.resid_drop(out) + return out + +class TransformerBlock(nn.Module): + def __init__(self, dim, heads): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.attn = CausalAttention(dim, heads=heads) + self.ff = nn.Sequential( + nn.Linear(dim, dim * 4), + nn.GELU(), + nn.Linear(dim * 4, dim), + ) + + def forward(self, x, tgt_mask): + b, s, h = x.shape + mask = torch.tril(torch.ones(s, s)).bool().to(x.device) + x = x + self.attn(self.norm1(x), mask=mask, tgt_mask=tgt_mask) + x = x + self.ff(self.norm2(x)) + return x \ No newline at end of file diff --git a/phc/learning/transformer_layers.py b/phc/learning/transformer_layers.py new file mode 100644 index 0000000..741e9f2 --- /dev/null +++ b/phc/learning/transformer_layers.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- +import math +import torch +import torch.nn as nn +from torch import Tensor + +# Took from https://github.com/joeynmt/joeynmt/blob/fb66afcbe1beef9acd59283bcc084c4d4c1e6343/joeynmt/transformer_layers.py + + +# pylint: disable=arguments-differ +class MultiHeadedAttention(nn.Module): + """ + Multi-Head Attention module from "Attention is All You Need" + + Implementation modified from OpenNMT-py. + https://github.com/OpenNMT/OpenNMT-py + """ + + def __init__(self, num_heads: int, size: int, dropout: float = 0.1): + """ + Create a multi-headed attention layer. + :param num_heads: the number of heads + :param size: model size (must be divisible by num_heads) + :param dropout: probability of dropping a unit + """ + super().__init__() + + assert size % num_heads == 0 + + self.head_size = head_size = size // num_heads + self.model_size = size + self.num_heads = num_heads + + self.k_layer = nn.Linear(size, num_heads * head_size) + self.v_layer = nn.Linear(size, num_heads * head_size) + self.q_layer = nn.Linear(size, num_heads * head_size) + + self.output_layer = nn.Linear(size, size) + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None): + """ + Computes multi-headed attention. + + :param k: keys [B, M, D] with M being the sentence length. + :param v: values [B, M, D] + :param q: query [B, M, D] + :param mask: optional mask [B, 1, M] or [B, M, M] + :return: + """ + batch_size = k.size(0) + num_heads = self.num_heads + + # project the queries (q), keys (k), and values (v) + k = self.k_layer(k) + v = self.v_layer(v) + q = self.q_layer(q) + + # reshape q, k, v for our computation to [batch_size, num_heads, ..] + k = k.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) + v = v.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) + q = q.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) + + # compute scores + q = q / math.sqrt(self.head_size) + + # batch x num_heads x query_len x key_len + scores = torch.matmul(q, k.transpose(2, 3)) + # torch.Size([48, 8, 183, 183]) + + # apply the mask (if we have one) + # we add a dimension for the heads to it below: [B, 1, 1, M] + if mask is not None: + scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf')) + + # apply attention dropout and compute context vectors. + attention = self.softmax(scores) + attention = self.dropout(attention) + # torch.Size([48, 8, 183, 183]) [bs, nheads, time, time] (for decoding) + + # v: torch.Size([48, 8, 183, 32]) (32 is 256/8) + # get context vector (select values with attention) and reshape + # back to [B, M, D] + context = torch.matmul(attention, v) # torch.Size([48, 8, 183, 32]) + context = context.transpose(1, 2).contiguous().view( + batch_size, -1, num_heads * self.head_size) + # torch.Size([48, 183, 256]) put back to 256 (combine the heads) + + output = self.output_layer(context) + # torch.Size([48, 183, 256]): 1 output per time step + + return output + + +# pylint: disable=arguments-differ +class PositionwiseFeedForward(nn.Module): + """ + Position-wise Feed-forward layer + Projects to ff_size and then back down to input_size. + """ + + def __init__(self, input_size, ff_size, dropout=0.1): + """ + Initializes position-wise feed-forward layer. + :param input_size: dimensionality of the input. + :param ff_size: dimensionality of intermediate representation + :param dropout: + """ + super().__init__() + self.layer_norm = nn.LayerNorm(input_size, eps=1e-6) + self.pwff_layer = nn.Sequential( + nn.Linear(input_size, ff_size), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(ff_size, input_size), + nn.Dropout(dropout), + ) + + def forward(self, x): + x_norm = self.layer_norm(x) + return self.pwff_layer(x_norm) + x + + +# pylint: disable=arguments-differ +class PositionalEncoding(nn.Module): + """ + Pre-compute position encodings (PE). + In forward pass, this adds the position-encodings to the + input for as many time steps as necessary. + + Implementation based on OpenNMT-py. + https://github.com/OpenNMT/OpenNMT-py + """ + + def __init__(self, + size: int = 0, + max_len: int = 5000): + """ + Positional Encoding with maximum length max_len + :param size: + :param max_len: + :param dropout: + """ + if size % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(size)) + pe = torch.zeros(max_len, size) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp((torch.arange(0, size, 2, dtype=torch.float) * + -(math.log(10000.0) / size))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0) # shape: [1, size, max_len] + super().__init__() + self.register_buffer('pe', pe) + self.dim = size + + def forward(self, emb): + """Embed inputs. + Args: + emb (FloatTensor): Sequence of word vectors + ``(seq_len, batch_size, self.dim)`` + """ + # Add position encodings + return emb + self.pe[:, :emb.size(1)] + + +class TransformerEncoderLayer(nn.Module): + """ + One Transformer encoder layer has a Multi-head attention layer plus + a position-wise feed-forward layer. + """ + + def __init__(self, + size: int = 0, + ff_size: int = 0, + num_heads: int = 0, + dropout: float = 0.1): + """ + A single Transformer layer. + :param size: + :param ff_size: + :param num_heads: + :param dropout: + """ + super().__init__() + + self.layer_norm = nn.LayerNorm(size, eps=1e-6) + self.src_src_att = MultiHeadedAttention(num_heads, size, + dropout=dropout) + self.feed_forward = PositionwiseFeedForward(size, ff_size=ff_size, + dropout=dropout) + self.dropout = nn.Dropout(dropout) + self.size = size + + # pylint: disable=arguments-differ + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + """ + Forward pass for a single transformer encoder layer. + First applies layer norm, then self attention, + then dropout with residual connection (adding the input to the result), + and then a position-wise feed-forward layer. + + :param x: layer input + :param mask: input mask + :return: output tensor + """ + x_norm = self.layer_norm(x) + h = self.src_src_att(x_norm, x_norm, x_norm, mask) + h = self.dropout(h) + x + o = self.feed_forward(h) + return o + + +class TransformerDecoderLayer(nn.Module): + """ + Transformer decoder layer. + + Consists of self-attention, source-attention, and feed-forward. + """ + + def __init__(self, + size: int = 0, + ff_size: int = 0, + num_heads: int = 0, + dropout: float = 0.1): + """ + Represents a single Transformer decoder layer. + + It attends to the source representation and the previous decoder states. + + :param size: model dimensionality + :param ff_size: size of the feed-forward intermediate layer + :param num_heads: number of heads + :param dropout: dropout to apply to input + """ + super().__init__() + self.size = size + + self.trg_trg_att = MultiHeadedAttention(num_heads, size, + dropout=dropout) + self.src_trg_att = MultiHeadedAttention(num_heads, size, + dropout=dropout) + + self.feed_forward = PositionwiseFeedForward(size, ff_size=ff_size, + dropout=dropout) + + self.x_layer_norm = nn.LayerNorm(size, eps=1e-6) + self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6) + + self.dropout = nn.Dropout(dropout) + + # pylint: disable=arguments-differ + def forward(self, + x: Tensor = None, + memory: Tensor = None, + src_mask: Tensor = None, + trg_mask: Tensor = None) -> Tensor: + """ + Forward pass of a single Transformer decoder layer. + + :param x: inputs + :param memory: source representations + :param src_mask: source mask + :param trg_mask: target mask (so as to not condition on future steps) + :return: output tensor + """ + # decoder/target self-attention + x_norm = self.x_layer_norm(x) # torch.Size([48, 183, 256]) + h1 = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask) + h1 = self.dropout(h1) + x + + # source-target attention + h1_norm = self.dec_layer_norm(h1) # torch.Size([48, 183, 256]) (same for memory) + h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask) + + # final position-wise feed-forward layer + o = self.feed_forward(self.dropout(h2) + h1) + + return o diff --git a/phc/learning/unrealego/__init__.py b/phc/learning/unrealego/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phc/learning/unrealego/base_model.py b/phc/learning/unrealego/base_model.py new file mode 100644 index 0000000..16e7698 --- /dev/null +++ b/phc/learning/unrealego/base_model.py @@ -0,0 +1,107 @@ +from operator import contains +import os +import torch +import torch.nn as nn +from collections import OrderedDict +from utils import util + + +class BaseModel(nn.Module): + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.save_dir = os.path.join(opt.log_dir, opt.experiment_name) + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.visual_pose_names = [] + self.image_paths = [] + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU + + def set_input(self, input): + self.input = input + + # update learning rate + def update_learning_rate(self): + old_lr = self.optimizers[0].param_groups[0]['lr'] + for scheduler in self.schedulers: + scheduler.step() + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate %.7f -> %.7f' % (old_lr, lr)) + + + # return training loss + def get_current_errors(self): + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = getattr(self, 'loss_' + name).item() + return errors_ret + + # return visualization images + def get_current_visuals(self): + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + value = getattr(self, name) + + if "heatmap" in name: + is_heatmap = True + else: + is_heatmap = False + + visual_ret[name] = util.tensor2im(value.data, is_heatmap=is_heatmap) + + # if isinstance(value, list): + # visual_ret[name] = util.tensor2im(value[-1].data, is_heatmap) + # else: + # visual_ret[name] = util.tensor2im(value.data, is_heatmap) + + return visual_ret + + # save models + def save_networks(self, which_epoch): + for name in self.model_names: + if isinstance(name, str): + save_filename = '%s_net_%s.pth' % (which_epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net_' + name) + torch.save(net.cpu().state_dict(), save_path) + if len(self.gpu_ids) > 0 and torch.cuda.is_available(): + net.cuda() + + # load models + def load_networks(self, which_epoch=None, net=None, path_to_trained_weights=None): + if which_epoch is not None: + for name in self.model_names: + print(name) + if isinstance(name, str): + save_filename = '%s_net_%s.pth' % (which_epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net_'+name) + state_dict = torch.load(save_path) + net.load_state_dict(state_dict) + # net.load_state_dict(self.fix_model_state_dict(state_dict)) + if not self.isTrain: + net.eval() + else: + state_dict = torch.load(path_to_trained_weights) + if self.opt.distributed: + net.load_state_dict(self.fix_model_state_dict(state_dict)) + else: + net.load_state_dict(state_dict) + print('Loaded pre_trained {}'.format(os.path.basename(path_to_trained_weights))) + + def fix_model_state_dict(self, state_dict): + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k + if name.startswith('module.'): + name = name[7:] # remove 'module.' of dataparallel + new_state_dict[name] = v + return new_state_dict \ No newline at end of file diff --git a/phc/learning/unrealego/egoglass_model.py b/phc/learning/unrealego/egoglass_model.py new file mode 100644 index 0000000..092ca8f --- /dev/null +++ b/phc/learning/unrealego/egoglass_model.py @@ -0,0 +1,207 @@ +from enum import auto +import torch +import torch.nn as nn +from torch.autograd import Variable +from torch.cuda.amp import autocast, GradScaler +from torch.nn import MSELoss + +import itertools +from .base_model import BaseModel +from . import network +from utils.loss import LossFuncLimb, LossFuncCosSim, LossFuncMPJPE +from utils.util import batch_compute_similarity_transform_torch + + +class EgoGlassModel(BaseModel): + def name(self): + return 'EgoGlass model' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + self.opt = opt + self.scaler = GradScaler(enabled=opt.use_amp) + + self.loss_names = [ + 'heatmap_left', 'heatmap_right', + 'heatmap_left_rec', 'heatmap_right_rec', + 'pose', 'cos_sim', + ] + + if self.isTrain: + self.visual_names = [ + 'input_rgb_left', 'input_rgb_right', + 'pred_heatmap_left', 'pred_heatmap_right', + 'gt_heatmap_left', 'gt_heatmap_right', + 'pred_heatmap_left_rec', 'pred_heatmap_right_rec' + ] + else: + self.visual_names = [ + # 'input_rgb_left', 'input_rgb_right', + 'pred_heatmap_left', 'pred_heatmap_right', + 'gt_heatmap_left', 'gt_heatmap_right', + ] + + self.visual_pose_names = [ + "pred_pose", "gt_pose" + ] + + if self.isTrain: + self.model_names = ['HeatMap_left', 'HeatMap_right', 'AutoEncoder'] + else: + self.model_names = ['HeatMap_left', 'HeatMap_right', 'AutoEncoder'] + + self.eval_key = "mpjpe" + self.cm2mm = 10 + + # define the transform network + self.net_HeatMap_left = network.define_HeatMap(opt, model=opt.model) + self.net_HeatMap_right = network.define_HeatMap(opt, model=opt.model) + self.net_AutoEncoder = network.define_AutoEncoder(opt, model=opt.model) + + # define loss functions + self.lossfunc_MSE = MSELoss() + self.lossfunc_limb = LossFuncLimb() + self.lossfunc_cos_sim = LossFuncCosSim() + self.lossfunc_MPJPE = LossFuncMPJPE() + + if self.isTrain: + # initialize optimizers + self.optimizer_HeatMap_left = torch.optim.Adam( + params=self.net_HeatMap_left.parameters(), + lr=opt.lr, + weight_decay=opt.weight_decay + ) + + self.optimizer_HeatMap_right = torch.optim.Adam( + params=self.net_HeatMap_right.parameters(), + lr=opt.lr, + weight_decay=opt.weight_decay + ) + + self.optimizer_AutoEncoder = torch.optim.Adam( + params=self.net_AutoEncoder.parameters(), + lr=opt.lr, + weight_decay=opt.weight_decay + ) + + self.optimizers = [] + self.schedulers = [] + self.optimizers.append(self.optimizer_HeatMap_left) + self.optimizers.append(self.optimizer_HeatMap_right) + self.optimizers.append(self.optimizer_AutoEncoder) + for optimizer in self.optimizers: + self.schedulers.append(network.get_scheduler(optimizer, opt)) + + # if not self.isTrain or opt.continue_train: + # self.load_networks(opt.which_epoch) + + def set_input(self, data): + self.data = data + self.input_rgb_left = data['input_rgb_left'].cuda(self.device) + self.input_rgb_right = data['input_rgb_right'].cuda(self.device) + self.gt_heatmap_left = data['gt_heatmap_left'].cuda(self.device) + self.gt_heatmap_right = data['gt_heatmap_right'].cuda(self.device) + self.gt_pose = data['gt_local_pose'].cuda(self.device) + + def forward(self): + with autocast(enabled=self.opt.use_amp): + self.pred_heatmap_left = self.net_HeatMap_left(self.input_rgb_left) + self.pred_heatmap_right = self.net_HeatMap_right(self.input_rgb_right) + + pred_heatmap_cat = torch.cat([self.pred_heatmap_left, self.pred_heatmap_right], dim=1) + + self.pred_pose, pred_heatmap_rec_cat = self.net_AutoEncoder(pred_heatmap_cat) + + self.pred_heatmap_left_rec, self.pred_heatmap_right_rec = torch.chunk(pred_heatmap_rec_cat, 2, dim=1) + + def backward_HeatMap(self): + with autocast(enabled=self.opt.use_amp): + loss_heatmap_left = self.lossfunc_MSE( + self.pred_heatmap_left, self.gt_heatmap_left + ) + loss_heatmap_right = self.lossfunc_MSE( + self.pred_heatmap_right, self.gt_heatmap_right + ) + + self.loss_heatmap_left = loss_heatmap_left * self.opt.lambda_heatmap + self.loss_heatmap_right = loss_heatmap_right * self.opt.lambda_heatmap + + loss_total = self.loss_heatmap_left + self.loss_heatmap_right + + self.scaler.scale(loss_total).backward(retain_graph=True) + + def backward_AutoEncoder(self): + with autocast(enabled=self.opt.use_amp): + loss_pose = self.lossfunc_MPJPE(self.pred_pose, self.gt_pose) + loss_cos_sim = self.lossfunc_cos_sim(self.pred_pose, self.gt_pose) + loss_heatmap_left_rec = self.lossfunc_MSE( + self.pred_heatmap_left_rec, self.pred_heatmap_left.detach() + ) + loss_heatmap_right_rec = self.lossfunc_MSE( + self.pred_heatmap_right_rec, self.pred_heatmap_right.detach() + ) + + self.loss_pose = loss_pose * self.opt.lambda_mpjpe + self.loss_cos_sim = loss_cos_sim * self.opt.lambda_cos_sim * self.opt.lambda_mpjpe + self.loss_heatmap_left_rec = loss_heatmap_left_rec * self.opt.lambda_heatmap_rec + self.loss_heatmap_right_rec = loss_heatmap_right_rec * self.opt.lambda_heatmap_rec + + loss_total = self.loss_pose + self.loss_cos_sim + \ + self.loss_heatmap_left_rec + self.loss_heatmap_right_rec + + self.scaler.scale(loss_total).backward() + + def optimize_parameters(self): + + # set model trainable + self.net_HeatMap_left.train() + self.net_HeatMap_right.train() + self.net_AutoEncoder.train() + + # set optimizer.zero_grad() + self.optimizer_HeatMap_left.zero_grad() + self.optimizer_HeatMap_right.zero_grad() + self.optimizer_AutoEncoder.zero_grad() + + # forward + self.forward() + + # backward + self.backward_HeatMap() + self.backward_AutoEncoder() + + # optimizer step + self.scaler.step(self.optimizer_HeatMap_left) + self.scaler.step(self.optimizer_HeatMap_right) + self.scaler.step(self.optimizer_AutoEncoder) + + self.scaler.update() + + def evaluate(self, runnning_average_dict): + # set evaluation mode + self.net_HeatMap_left.eval() + self.net_HeatMap_right.eval() + self.net_AutoEncoder.eval() + + # forward pass + self.pred_heatmap_left = self.net_HeatMap_left(self.input_rgb_left) + self.pred_heatmap_right = self.net_HeatMap_right(self.input_rgb_right) + pred_heatmap_cat = torch.cat([self.pred_heatmap_left, self.pred_heatmap_right], dim=1) + self.pred_pose = self.net_AutoEncoder.predict_pose(pred_heatmap_cat) + + S1_hat = batch_compute_similarity_transform_torch(self.pred_pose, self.gt_pose) + + # compute metrics + for id in range(self.pred_pose.size()[0]): # batch size + # calculate mpjpe and p_mpjpe # cm to mm + mpjpe = self.lossfunc_MPJPE(self.pred_pose[id], self.gt_pose[id]) * self.cm2mm + pa_mpjpe = self.lossfunc_MPJPE(S1_hat[id], self.gt_pose[id]) * self.cm2mm + + # update metrics dict + runnning_average_dict.update(dict( + mpjpe=mpjpe, + pa_mpjpe=pa_mpjpe) + ) + + return runnning_average_dict \ No newline at end of file diff --git a/phc/learning/unrealego/models.py b/phc/learning/unrealego/models.py new file mode 100644 index 0000000..c16384d --- /dev/null +++ b/phc/learning/unrealego/models.py @@ -0,0 +1,21 @@ + +def create_model(opt): + print(opt.model) + + if opt.model == 'egoglass': + from .egoglass_model import EgoGlassModel + model = EgoGlassModel() + + elif opt.model == "unrealego_heatmap_shared": + from .unrealego_heatmap_shared_model import UnrealEgoHeatmapSharedModel + model = UnrealEgoHeatmapSharedModel() + + elif opt.model == 'unrealego_autoencoder': + from .unrealego_autoencoder_model import UnrealEgoAutoEncoderModel + model = UnrealEgoAutoEncoderModel() + + else: + raise ValueError('Model [%s] not recognized.' % opt.model) + model.initialize(opt) + print("model [%s] was created." % (model.name())) + return model \ No newline at end of file diff --git a/phc/learning/unrealego/network.py b/phc/learning/unrealego/network.py new file mode 100644 index 0000000..24418d4 --- /dev/null +++ b/phc/learning/unrealego/network.py @@ -0,0 +1,619 @@ +from re import X +from turtle import forward +import torch +import torch.nn as nn +from torch.nn import init +from torch.nn.utils import weight_norm +import functools +from torchvision import models +import torch.nn.functional as F +from torch.optim import lr_scheduler +from collections import OrderedDict +import math +from easydict import EasyDict as edict +from torch.nn import MSELoss +from tqdm import tqdm + +###################################################################################### +# Functions +###################################################################################### +def get_norm_layer(norm_type='batch'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def get_nonlinearity_layer(activation_type='PReLU'): + if activation_type == 'ReLU': + nonlinearity_layer = nn.ReLU(True) + elif activation_type == 'SELU': + nonlinearity_layer = nn.SELU(True) + elif activation_type == 'LeakyReLU': + nonlinearity_layer = nn.LeakyReLU(0.2, True) + elif activation_type == 'PReLU': + nonlinearity_layer = nn.PReLU() + else: + raise NotImplementedError('activation layer [%s] is not found' % activation_type) + return nonlinearity_layer + + +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'lambda': + def lambda_rule(epoch): + # lr_l = 1.0 - max(0, epoch+1+1+opt.epoch_count-opt.niter) / float(opt.niter_decay+1) + lr_l = 1.0 - max(0, epoch+opt.epoch_count-opt.niter) / float(opt.niter_decay+1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters_step, gamma=0.1) + elif opt.lr_policy == 'exponent': + scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.95) + else: + raise NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def init_weights(net, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.uniform_(m.weight.data, gain, 1.0) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) + + +def print_network_param(net, name): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + + print('total number of parameters of {}: {:.3f} M'.format(name, num_params / 1e6)) + + +def init_net(net, init_type='normal', gpu_ids=[], init_ImageNet=True): + + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + # net = torch.nn.DataParallel(net, gpu_ids) + net.cuda() + + if init_ImageNet is False: + init_weights(net, init_type) + else: + init_weights(net.after_backbone, init_type) + print(' ... also using ImageNet initialization for the backbone') + + return net + + +def _freeze(*args): + for module in args: + if module: + for p in module.parameters(): + p.requires_grad = False + + +def _unfreeze(*args): + for module in args: + if module: + for p in module.parameters(): + p.requires_grad = True + +def freeze_bn(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.eval() + m.weight.requires_grad = False + m.bias.requires_grad = False + +def unfreeze_bn(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.train() + m.weight.requires_grad = True + m.bias.requires_grad = True + +def freeze_bn_affine(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.weight.requires_grad = False + m.bias.requires_grad = False + + +###################################################################################### +# Define networks +###################################################################################### + +def define_HeatMap(opt, model): + + if model == 'egoglass': + net = HeatMap_EgoGlass(opt) + elif model == "unrealego_heatmap_shared": + net = HeatMap_UnrealEgo_Shared(opt) + elif model == "unrealego_autoencoder": + net = HeatMap_UnrealEgo_Shared(opt) + + print_network_param(net, 'HeatMap_Estimator for {}'.format(model)) + + return init_net(net, opt.init_type, opt.gpu_ids, opt.init_ImageNet) + +def define_AutoEncoder(opt, model): + + if model == 'egoglass': + net = AutoEncoder(opt, input_channel_scale=2) + elif model == "unrealego_autoencoder": + net = AutoEncoder(opt, input_channel_scale=2) + + print_network_param(net, 'AutoEncoder for {}'.format(model)) + + return init_net(net, opt.init_type, opt.gpu_ids, False) + + +###################################################################################### +# Basic Operation +###################################################################################### + + +def make_conv_layer(in_channels, out_channels, kernel_size, stride, padding, with_bn=True): + conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding) + # torch.nn.init.xavier_normal_(conv.weight) + # conv = weight_norm(conv) + bn = torch.nn.BatchNorm2d(num_features=out_channels) + relu = torch.nn.LeakyReLU(negative_slope=0.2) + if with_bn: + return torch.nn.Sequential(conv, bn, relu) + else: + return torch.nn.Sequential(conv, relu) + +def make_deconv_layer(in_channels, out_channels, kernel_size, stride, padding, with_bn=True): + conv = torch.nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding) + # torch.nn.init.xavier_normal_(conv.weight) + # conv = weight_norm(conv) + bn = torch.nn.BatchNorm2d(num_features=out_channels) + relu = torch.nn.LeakyReLU(negative_slope=0.2) + if with_bn: + return torch.nn.Sequential(conv, bn, relu) + else: + return torch.nn.Sequential(conv, relu) + +def make_fc_layer(in_feature, out_feature, with_relu=True, with_bn=True): + modules = OrderedDict() + fc = torch.nn.Linear(in_feature, out_feature) + # torch.nn.init.xavier_normal_(fc.weight) + # fc = weight_norm(fc) + modules['fc'] = fc + bn = torch.nn.BatchNorm1d(num_features=out_feature) + relu = torch.nn.LeakyReLU(negative_slope=0.2) + + if with_bn is True: + modules['bn'] = bn + else: + print('no bn') + + if with_relu is True: + modules['relu'] = relu + else: + print('no pose relu') + + return torch.nn.Sequential(modules) + +def convrelu(in_channels, out_channels, kernel, padding): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel, padding=padding), + nn.ReLU(inplace=True), + ) + +###################################################################################### +# Network structure +###################################################################################### + + +############################## EgoGlass ############################## + + +class HeatMap_EgoGlass(nn.Module): + def __init__(self, opt, model_name='resnet18'): + super(HeatMap_EgoGlass, self).__init__() + + self.backbone = HeatMap_EgoGlass_Backbone(opt, model_name=model_name) + self.after_backbone = HeatMap_EgoGlass_AfterBackbone(opt) + + def forward(self, input): + + x = self.backbone(input) + output = self.after_backbone(x) + + return output + + +class HeatMap_EgoGlass_Backbone(nn.Module): + def __init__(self, opt, model_name='resnet18'): + super(HeatMap_EgoGlass_Backbone, self).__init__() + + if model_name == 'resnet18': + self.backbone = models.resnet18(pretrained=opt.init_ImageNet) + elif model_name == "resnet34": + self.backbone = models.resnet34(pretrained=opt.init_ImageNet) + elif model_name == "resnet50": + self.backbone = models.resnet50(pretrained=opt.init_ImageNet) + elif model_name == "resnet101": + self.backbone = models.resnet101(pretrained=opt.init_ImageNet) + else: + raise NotImplementedError('model type [%s] is invalid', model_name) + + self.base_layers = list(self.backbone.children()) + self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) + self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) + self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) + self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) + self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) + + def forward(self, input): + + layer0 = self.layer0(input) + layer1 = self.layer1(layer0) + layer2 = self.layer2(layer1) + layer3 = self.layer3(layer2) + layer4 = self.layer4(layer3) + + output = [input, layer0, layer1, layer2, layer3, layer4] + + return output + + +class HeatMap_EgoGlass_AfterBackbone(nn.Module): + def __init__(self, opt): + super(HeatMap_EgoGlass_AfterBackbone, self).__init__() + + self.num_heatmap = opt.num_heatmap + + self.layer0_1x1 = convrelu(64, 64, 1, 0) + self.layer1_1x1 = convrelu(64, 64, 1, 0) + self.layer2_1x1 = convrelu(128, 128, 1, 0) + self.layer3_1x1 = convrelu(256, 256, 1, 0) + self.layer4_1x1 = convrelu(512, 512, 1, 0) + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + + self.conv_up3 = convrelu(256 + 512, 512, 3, 1) + self.conv_up2 = convrelu(128 + 512, 256, 3, 1) + self.conv_up1 = convrelu(64 + 256, 256, 3, 1) + self.conv_up0 = convrelu(64 + 256, 128, 3, 1) + + self.conv_original_size0 = convrelu(3, 64, 3, 1) + self.conv_original_size1 = convrelu(64, 64, 3, 1) + self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1) + + self.conv_heatmap = nn.Conv2d(256, self.num_heatmap, 1) + + + def forward(self, list_input): + + input = list_input[0] + layer0 = list_input[1] + layer1 = list_input[2] + layer2 = list_input[3] + layer3 = list_input[4] + layer4 = list_input[5] + + layer4 = self.layer4_1x1(layer4) + x = self.upsample(layer4) + layer3 = self.layer3_1x1(layer3) + x = torch.cat([x, layer3], dim=1) + x = self.conv_up3(x) + + x = self.upsample(x) + layer2 = self.layer2_1x1(layer2) + x = torch.cat([x, layer2], dim=1) + x = self.conv_up2(x) + + x = self.upsample(x) + layer1 = self.layer1_1x1(layer1) + x = torch.cat([x, layer1], dim=1) + x = self.conv_up1(x) + + output = self.conv_heatmap(x) + + return output + + +############################## UnrealEgo ############################## + +class HeatMap_UnrealEgo_Shared(nn.Module): + def __init__(self, opt, model_name='resnet18'): + super(HeatMap_UnrealEgo_Shared, self).__init__() + + self.backbone = HeatMap_UnrealEgo_Shared_Backbone(opt, model_name=model_name) + self.after_backbone = HeatMap_UnrealEgo_AfterBackbone(opt, model_name=model_name) + + def forward(self, input_img): + output_0 = self.backbone(input_img[:, 0:1]) + output_1 = self.backbone(input_img[:, 1:2]) + output_2 = self.backbone(input_img[:, 2:3]) + output_3 = self.backbone(input_img[:, 3:4]) + cat_features = [torch.cat([output_0[id], output_1[id], output_2[id], output_3[id]], dim=1) for id in range(len(output_0))] + output = self.after_backbone(cat_features) + # import ipdb; ipdb.set_trace() + + # cat_features = [torch.cat([output_0[id], output_3[id]], dim=1) for id in range(len(output_0))] + # output = self.after_backbone(cat_features) + + return output + + def forward_feat(self, input_img): + output_0 = self.backbone(input_img[:, 0:1]) + output_1 = self.backbone(input_img[:, 1:2]) + output_2 = self.backbone(input_img[:, 2:3]) + output_3 = self.backbone(input_img[:, 3:4]) + cat_features = [torch.cat([output_0[id], output_1[id], output_2[id], output_3[id]], dim=1) for id in range(len(output_0))] + output = self.after_backbone(cat_features) + + return output, cat_features + + + def forward_feat_full(self, input_img): + output_0 = self.backbone(input_img[:, 0:1]) + output_1 = self.backbone(input_img[:, 1:2]) + output_2 = self.backbone(input_img[:, 2:3]) + output_3 = self.backbone(input_img[:, 3:4]) + cat_features = [torch.cat([output_0[id], output_1[id], output_2[id], output_3[id]], dim=1) for id in range(len(output_0))] + output, full_feat = self.after_backbone(cat_features, return_feat=True) + + return output, full_feat + + +class HeatMap_UnrealEgo_Shared_Backbone(nn.Module): + def __init__(self, opt, model_name='resnet18'): + super(HeatMap_UnrealEgo_Shared_Backbone, self).__init__() + self.backbone = Encoder_Block(opt, model_name=model_name) + + def forward(self, input): + output = self.backbone(input) + return output + +class Encoder_Block(nn.Module): + def __init__(self, opt, model_name='resnet18'): + super(Encoder_Block, self).__init__() + if model_name == 'resnet18': + self.backbone = models.resnet18(pretrained=opt.init_ImageNet) + self.backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + elif model_name == "resnet34": + self.backbone = models.resnet34(pretrained=opt.init_ImageNet) + elif model_name == "resnet50": + self.backbone = models.resnet50(pretrained=opt.init_ImageNet) + elif model_name == "resnet101": + self.backbone = models.resnet101(pretrained=opt.init_ImageNet) + else: + raise NotImplementedError('model type [%s] is invalid', model_name) + + self.base_layers = list(self.backbone.children()) + self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) + self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) + self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) + self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) + self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) + + def forward(self, input): + + layer0 = self.layer0(input) + layer1 = self.layer1(layer0) + layer2 = self.layer2(layer1) + layer3 = self.layer3(layer2) + layer4 = self.layer4(layer3) + + output = [input, layer0, layer1, layer2, layer3, layer4] + + return output + + +class HeatMap_UnrealEgo_AfterBackbone(nn.Module): + def __init__(self, opt, model_name="resnet18"): + super(HeatMap_UnrealEgo_AfterBackbone, self).__init__() + + if model_name == 'resnet18': + feature_scale = 1 + elif model_name == "resnet34": + feature_scale = 1 + elif model_name == "resnet50": + feature_scale = 4 + elif model_name == "resnet101": + feature_scale = 4 + else: + raise NotImplementedError('model type [%s] is invalid', model_name) + scale = 2 + self.num_heatmap = opt.num_heatmap + + # self.layer0_1x1 = convrelu(128, 128, 1, 0) + self.layer1_1x1 = convrelu(128 * feature_scale * scale, 128 * feature_scale * scale, 1, 0) + self.layer2_1x1 = convrelu(256 * feature_scale * scale, 256 * feature_scale * scale, 1, 0) + self.layer3_1x1 = convrelu(512 * feature_scale * scale, 516 * feature_scale * scale, 1, 0) + self.layer4_1x1 = convrelu(1024 * feature_scale * scale, 1024 * feature_scale * scale, 1, 0) + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + + self.conv_up3 = convrelu(516 * feature_scale * scale + 1024 * feature_scale * scale, 1024 * feature_scale * scale, 3, 1) + self.conv_up2 = convrelu(256 * feature_scale * scale + 1024 * feature_scale * scale, 512 * feature_scale * scale, 3, 1) + self.conv_up1 = convrelu(128 * feature_scale * scale + 512 * feature_scale * scale, 512 * feature_scale * scale, 3, 1) + + self.conv_heatmap = nn.Conv2d(512 * feature_scale * scale, self.num_heatmap * 2 * scale, 1) + + def forward(self, list_stereo_feature, return_feat = False): + # UNet skip connections + + input = list_stereo_feature[0] + layer0 = list_stereo_feature[1] + layer1 = list_stereo_feature[2] + layer2 = list_stereo_feature[3] + layer3 = list_stereo_feature[4] + layer4 = list_stereo_feature[5] + + layer4 = self.layer4_1x1(layer4) + x = self.upsample(layer4) + layer3 = self.layer3_1x1(layer3) + x = torch.cat([x, layer3], dim=1) + x = self.conv_up3(x) + + x = self.upsample(x) + layer2 = self.layer2_1x1(layer2) + x = torch.cat([x, layer2], dim=1) + x = self.conv_up2(x) + + x = self.upsample(x) + layer1 = self.layer1_1x1(layer1) + x = torch.cat([x, layer1], dim=1) + x_prev = self.conv_up1(x) + + output = self.conv_heatmap(x_prev) + if return_feat: + return output, x_prev + else: + return output + + +############################## AutoEncoder ############################## + + +class AutoEncoder(nn.Module): + + def __init__(self, opt, input_channel_scale=1, fc_dim=16384, num_add_joints = 2): + super(AutoEncoder, self).__init__() + + self.hidden_size = opt.ae_hidden_size + self.with_bn = True + self.with_pose_relu = True + + self.num_heatmap = opt.num_heatmap + self.channels_heatmap = self.num_heatmap * input_channel_scale + self.fc_dim = fc_dim + self.num_add_joints = num_add_joints + + self.conv1 = make_conv_layer(in_channels=self.channels_heatmap, out_channels=64, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + self.conv2 = make_conv_layer(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + self.conv3 = make_conv_layer(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + + # self.fc1 = make_fc_layer(in_feature=18432, out_feature=2048, with_bn=self.with_bn) + self.fc1 = make_fc_layer(in_feature=self.fc_dim, out_feature=2048, with_bn=self.with_bn) + self.fc2 = make_fc_layer(in_feature=2048, out_feature=512, with_bn=self.with_bn) + self.fc3 = make_fc_layer(in_feature=512, out_feature=self.hidden_size, with_bn=self.with_bn) + + ## pose decoder + self.pose_fc1 = make_fc_layer(self.hidden_size, 32, with_relu=self.with_pose_relu, with_bn=self.with_bn) + self.pose_fc2 = make_fc_layer(32, 32, with_relu=self.with_pose_relu, with_bn=self.with_bn) + self.pose_fc3 = torch.nn.Linear(32, (self.num_heatmap + self.num_add_joints) * 3) + + # heatmap decoder + self.heatmap_fc1 = make_fc_layer(self.hidden_size, 512, with_bn=self.with_bn) + self.heatmap_fc2 = make_fc_layer(512, 2048, with_bn=self.with_bn) + # self.heatmap_fc3 = make_fc_layer(2048, 18432, with_bn=self.with_bn) + self.heatmap_fc3 = make_fc_layer(2048, self.fc_dim, with_bn=self.with_bn) + + self.deconv1 = make_deconv_layer(256, 128, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + self.deconv2 = make_deconv_layer(128, 64, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + self.deconv3 = make_deconv_layer(64, self.channels_heatmap, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + + def predict_pose(self, input): + batch_size = input.size()[0] + + # encode heatmap + x = self.conv1(input) + x = self.conv2(x) + x = self.conv3(x) + batch_size = x.shape[0] + x = x.view(batch_size, -1) + x = self.fc1(x) + x = self.fc2(x) + z = self.fc3(x) + + # decode pose + x_pose = self.pose_fc1(z) + x_pose = self.pose_fc2(x_pose) + output_pose = self.pose_fc3(x_pose) + + return output_pose.view(batch_size, self.num_heatmap, 3) + + + def forward(self, input): + batch_size, C, W, H = input.shape + + # encode heatmap + x = self.conv1(input) + x = self.conv2(x) + x = self.conv3(x) + x = x.view(batch_size, -1) + x = self.fc1(x) + x = self.fc2(x) + z = self.fc3(x) + + # decode pose + x_pose = self.pose_fc1(z) + x_pose = self.pose_fc2(x_pose) + output_pose = self.pose_fc3(x_pose) + + # decode heatmap + x_hm = self.heatmap_fc1(z) + x_hm = self.heatmap_fc2(x_hm) + x_hm = self.heatmap_fc3(x_hm) + x_hm = x_hm.view(batch_size, 256, W//8, H//8) + x_hm = self.deconv1(x_hm) + x_hm = self.deconv2(x_hm) + + output_hm = self.deconv3(x_hm) + + return output_pose.view(batch_size, self.num_heatmap + self.num_add_joints , 3), output_hm + +if __name__ == "__main__": + + opt = edict() + opt.init_ImageNet = True + opt.num_heatmap = 25 + opt.ae_hidden_size = 20 + + net_heatmap = HeatMap_UnrealEgo_Shared(opt=opt, model_name='resnet18') + optimizer = torch.optim.Adam(net_heatmap.parameters(), lr=1e-4) + net_heatmap.train() + ae = AutoEncoder(opt, input_channel_scale = 4, fc_dim= 5120) + + rand_input = torch.rand(3, 4, 128, 160) + lossfunc_MSE = MSELoss() + + for _ in tqdm(range(1)): + heatmaps = net_heatmap(rand_input) + # loss = lossfunc_MSE(heatmaps, torch.zeros_like(heatmaps)) + loss = torch.sqrt((heatmaps - torch.zeros_like(heatmaps)) ** 2).mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + print(loss.item()) + + + + + pose, output_hm = ae(heatmaps) + import ipdb; ipdb.set_trace() + + heatmaps = torch.chunk(outputs, 4, dim=1) + print(pred_heatmap_left.size()) + print(pred_heatmap_right.size()) diff --git a/phc/learning/unrealego/network_debug.py b/phc/learning/unrealego/network_debug.py new file mode 100644 index 0000000..87a5c3e --- /dev/null +++ b/phc/learning/unrealego/network_debug.py @@ -0,0 +1,571 @@ +from re import X +from turtle import forward +import torch +import torch.nn as nn +from torch.nn import init +from torch.nn.utils import weight_norm +import functools +from torchvision import models +import torch.nn.functional as F +from torch.optim import lr_scheduler +from collections import OrderedDict +import math + + +###################################################################################### +# Functions +###################################################################################### +def get_norm_layer(norm_type='batch'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def get_nonlinearity_layer(activation_type='PReLU'): + if activation_type == 'ReLU': + nonlinearity_layer = nn.ReLU(True) + elif activation_type == 'SELU': + nonlinearity_layer = nn.SELU(True) + elif activation_type == 'LeakyReLU': + nonlinearity_layer = nn.LeakyReLU(0.2, True) + elif activation_type == 'PReLU': + nonlinearity_layer = nn.PReLU() + else: + raise NotImplementedError('activation layer [%s] is not found' % activation_type) + return nonlinearity_layer + + +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'lambda': + def lambda_rule(epoch): + # lr_l = 1.0 - max(0, epoch+1+1+opt.epoch_count-opt.niter) / float(opt.niter_decay+1) + lr_l = 1.0 - max(0, epoch+opt.epoch_count-opt.niter) / float(opt.niter_decay+1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters_step, gamma=0.1) + elif opt.lr_policy == 'exponent': + scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.95) + else: + raise NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def init_weights(net, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.uniform_(m.weight.data, gain, 1.0) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) + + +def print_network_param(net, name): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + + print('total number of parameters of {}: {:.3f} M'.format(name, num_params / 1e6)) + + +def init_net(net, init_type='normal', gpu_ids=[], init_ImageNet=True): + + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + # net = torch.nn.DataParallel(net, gpu_ids) + net.cuda() + + if init_ImageNet is False: + init_weights(net, init_type) + else: + init_weights(net.after_backbone, init_type) + print(' ... also using ImageNet initialization for the backbone') + + return net + + +def _freeze(*args): + for module in args: + if module: + for p in module.parameters(): + p.requires_grad = False + + +def _unfreeze(*args): + for module in args: + if module: + for p in module.parameters(): + p.requires_grad = True + +def freeze_bn(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.eval() + m.weight.requires_grad = False + m.bias.requires_grad = False + +def unfreeze_bn(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.train() + m.weight.requires_grad = True + m.bias.requires_grad = True + +def freeze_bn_affine(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.weight.requires_grad = False + m.bias.requires_grad = False + + +###################################################################################### +# Define networks +###################################################################################### + +def define_HeatMap(opt, model): + + if model == 'egoglass': + net = HeatMap_EgoGlass(opt) + elif model == "unrealego_heatmap_shared": + net = HeatMap_UnrealEgo_Shared(opt) + elif model == "unrealego_autoencoder": + net = HeatMap_UnrealEgo_Shared(opt) + + print_network_param(net, 'HeatMap_Estimator for {}'.format(model)) + + return init_net(net, opt.init_type, opt.gpu_ids, opt.init_ImageNet) + +def define_AutoEncoder(opt, model): + + if model == 'egoglass': + net = AutoEncoder(opt, input_channel_scale=2) + elif model == "unrealego_autoencoder": + net = AutoEncoder(opt, input_channel_scale=2) + + print_network_param(net, 'AutoEncoder for {}'.format(model)) + + return init_net(net, opt.init_type, opt.gpu_ids, False) + + +###################################################################################### +# Basic Operation +###################################################################################### + + +def make_conv_layer(in_channels, out_channels, kernel_size, stride, padding, with_bn=True): + conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding) + # torch.nn.init.xavier_normal_(conv.weight) + # conv = weight_norm(conv) + bn = torch.nn.BatchNorm2d(num_features=out_channels) + relu = torch.nn.LeakyReLU(negative_slope=0.2) + if with_bn: + return torch.nn.Sequential(conv, bn, relu) + else: + return torch.nn.Sequential(conv, relu) + +def make_deconv_layer(in_channels, out_channels, kernel_size, stride, padding, with_bn=True): + conv = torch.nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding) + # torch.nn.init.xavier_normal_(conv.weight) + # conv = weight_norm(conv) + bn = torch.nn.BatchNorm2d(num_features=out_channels) + relu = torch.nn.LeakyReLU(negative_slope=0.2) + if with_bn: + return torch.nn.Sequential(conv, bn, relu) + else: + return torch.nn.Sequential(conv, relu) + +def make_fc_layer(in_feature, out_feature, with_relu=True, with_bn=True): + modules = OrderedDict() + fc = torch.nn.Linear(in_feature, out_feature) + # torch.nn.init.xavier_normal_(fc.weight) + # fc = weight_norm(fc) + modules['fc'] = fc + bn = torch.nn.BatchNorm1d(num_features=out_feature) + relu = torch.nn.LeakyReLU(negative_slope=0.2) + + if with_bn is True: + modules['bn'] = bn + else: + print('no bn') + + if with_relu is True: + modules['relu'] = relu + else: + print('no pose relu') + + return torch.nn.Sequential(modules) + +def convrelu(in_channels, out_channels, kernel, padding): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel, padding=padding), + nn.ReLU(inplace=True), + ) + +###################################################################################### +# Network structure +###################################################################################### + + +############################## EgoGlass ############################## + + +class HeatMap_EgoGlass(nn.Module): + def __init__(self, opt, model_name='resnet18'): + super(HeatMap_EgoGlass, self).__init__() + + self.backbone = HeatMap_EgoGlass_Backbone(opt, model_name=model_name) + self.after_backbone = HeatMap_EgoGlass_AfterBackbone(opt) + + def forward(self, input): + + x = self.backbone(input) + output = self.after_backbone(x) + + return output + + +class HeatMap_EgoGlass_Backbone(nn.Module): + def __init__(self, opt, model_name='resnet18'): + super(HeatMap_EgoGlass_Backbone, self).__init__() + + if model_name == 'resnet18': + self.backbone = models.resnet18(pretrained=opt.init_ImageNet) + elif model_name == "resnet34": + self.backbone = models.resnet34(pretrained=opt.init_ImageNet) + elif model_name == "resnet50": + self.backbone = models.resnet50(pretrained=opt.init_ImageNet) + elif model_name == "resnet101": + self.backbone = models.resnet101(pretrained=opt.init_ImageNet) + else: + raise NotImplementedError('model type [%s] is invalid', model_name) + + self.base_layers = list(self.backbone.children()) + self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) + self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) + self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) + self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) + self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) + + def forward(self, input): + + layer0 = self.layer0(input) + layer1 = self.layer1(layer0) + layer2 = self.layer2(layer1) + layer3 = self.layer3(layer2) + layer4 = self.layer4(layer3) + + output = [input, layer0, layer1, layer2, layer3, layer4] + + return output + + +class HeatMap_EgoGlass_AfterBackbone(nn.Module): + def __init__(self, opt): + super(HeatMap_EgoGlass_AfterBackbone, self).__init__() + + self.num_heatmap = opt.num_heatmap + + self.layer0_1x1 = convrelu(64, 64, 1, 0) + self.layer1_1x1 = convrelu(64, 64, 1, 0) + self.layer2_1x1 = convrelu(128, 128, 1, 0) + self.layer3_1x1 = convrelu(256, 256, 1, 0) + self.layer4_1x1 = convrelu(512, 512, 1, 0) + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + + self.conv_up3 = convrelu(256 + 512, 512, 3, 1) + self.conv_up2 = convrelu(128 + 512, 256, 3, 1) + self.conv_up1 = convrelu(64 + 256, 256, 3, 1) + self.conv_up0 = convrelu(64 + 256, 128, 3, 1) + + self.conv_original_size0 = convrelu(3, 64, 3, 1) + self.conv_original_size1 = convrelu(64, 64, 3, 1) + self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1) + + self.conv_heatmap = nn.Conv2d(256, self.num_heatmap, 1) + + + def forward(self, list_input): + + input = list_input[0] + layer0 = list_input[1] + layer1 = list_input[2] + layer2 = list_input[3] + layer3 = list_input[4] + layer4 = list_input[5] + + layer4 = self.layer4_1x1(layer4) + x = self.upsample(layer4) + layer3 = self.layer3_1x1(layer3) + x = torch.cat([x, layer3], dim=1) + x = self.conv_up3(x) + + x = self.upsample(x) + layer2 = self.layer2_1x1(layer2) + x = torch.cat([x, layer2], dim=1) + x = self.conv_up2(x) + + x = self.upsample(x) + layer1 = self.layer1_1x1(layer1) + x = torch.cat([x, layer1], dim=1) + x = self.conv_up1(x) + + output = self.conv_heatmap(x) + + return output + + +############################## UnrealEgo ############################## + +class HeatMap_UnrealEgo_Shared(nn.Module): + def __init__(self, opt, model_name='resnet18'): + super(HeatMap_UnrealEgo_Shared, self).__init__() + + self.backbone = HeatMap_UnrealEgo_Shared_Backbone(opt, model_name=model_name) + self.after_backbone = HeatMap_UnrealEgo_AfterBackbone(opt, model_name=model_name) + + def forward(self, input_left, input_right): + + x_left, x_right = self.backbone(input_left, input_right) + output = self.after_backbone(x_left, x_right) + + return output + + +class HeatMap_UnrealEgo_Shared_Backbone(nn.Module): + def __init__(self, opt, model_name='resnet18'): + super(HeatMap_UnrealEgo_Shared_Backbone, self).__init__() + + self.backbone = Encoder_Block(opt, model_name=model_name) + + def forward(self, input_left, input_right): + output_left = self.backbone(input_left) + output_right = self.backbone(input_right) + + return output_left, output_right + +class Encoder_Block(nn.Module): + def __init__(self, opt, model_name='resnet18'): + super(Encoder_Block, self).__init__() + + if model_name == 'resnet18': + self.backbone = models.resnet18(pretrained=opt.init_ImageNet) + self.backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + elif model_name == "resnet34": + self.backbone = models.resnet34(pretrained=opt.init_ImageNet) + elif model_name == "resnet50": + self.backbone = models.resnet50(pretrained=opt.init_ImageNet) + elif model_name == "resnet101": + self.backbone = models.resnet101(pretrained=opt.init_ImageNet) + else: + raise NotImplementedError('model type [%s] is invalid', model_name) + + self.base_layers = list(self.backbone.children()) + self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) + self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) + self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) + self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) + self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) + + def forward(self, input): + + layer0 = self.layer0(input) + layer1 = self.layer1(layer0) + layer2 = self.layer2(layer1) + layer3 = self.layer3(layer2) + layer4 = self.layer4(layer3) + + output = [input, layer0, layer1, layer2, layer3, layer4] + + return output + + +class HeatMap_UnrealEgo_AfterBackbone(nn.Module): + def __init__(self, opt, model_name="resnet18"): + super(HeatMap_UnrealEgo_AfterBackbone, self).__init__() + + if model_name == 'resnet18': + feature_scale = 1 + elif model_name == "resnet34": + feature_scale = 1 + elif model_name == "resnet50": + feature_scale = 4 + elif model_name == "resnet101": + feature_scale = 4 + else: + raise NotImplementedError('model type [%s] is invalid', model_name) + + self.num_heatmap = opt.num_heatmap + + # self.layer0_1x1 = convrelu(128, 128, 1, 0) + self.layer1_1x1 = convrelu(128 * feature_scale, 128 * feature_scale, 1, 0) + self.layer2_1x1 = convrelu(256 * feature_scale, 256 * feature_scale, 1, 0) + self.layer3_1x1 = convrelu(512 * feature_scale, 516 * feature_scale, 1, 0) + self.layer4_1x1 = convrelu(1024 * feature_scale, 1024 * feature_scale, 1, 0) + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + + self.conv_up3 = convrelu(516 * feature_scale + 1024 * feature_scale, 1024 * feature_scale, 3, 1) + self.conv_up2 = convrelu(256 * feature_scale + 1024 * feature_scale, 512 * feature_scale, 3, 1) + self.conv_up1 = convrelu(128 * feature_scale + 512 * feature_scale, 512 * feature_scale, 3, 1) + + self.conv_heatmap = nn.Conv2d(512 * feature_scale, self.num_heatmap * 2, 1) + + def forward(self, list_input_left, list_input_right): + list_stereo_feature = [ + torch.cat([list_input_left[id], list_input_right[id]], dim=1) for id in range(len(list_input_left)) + ] + + input = list_stereo_feature[0] + layer0 = list_stereo_feature[1] + layer1 = list_stereo_feature[2] + layer2 = list_stereo_feature[3] + layer3 = list_stereo_feature[4] + layer4 = list_stereo_feature[5] + + layer4 = self.layer4_1x1(layer4) + x = self.upsample(layer4) + layer3 = self.layer3_1x1(layer3) + x = torch.cat([x, layer3], dim=1) + x = self.conv_up3(x) + + x = self.upsample(x) + layer2 = self.layer2_1x1(layer2) + x = torch.cat([x, layer2], dim=1) + x = self.conv_up2(x) + + x = self.upsample(x) + layer1 = self.layer1_1x1(layer1) + x = torch.cat([x, layer1], dim=1) + x = self.conv_up1(x) + + output = self.conv_heatmap(x) + + return output + + +############################## AutoEncoder ############################## + + +class AutoEncoder(nn.Module): + + def __init__(self, opt, input_channel_scale=1, fc_dim=16384): + super(AutoEncoder, self).__init__() + + self.hidden_size = opt.ae_hidden_size + self.with_bn = True + self.with_pose_relu = True + + self.num_heatmap = opt.num_heatmap + self.channels_heatmap = self.num_heatmap * input_channel_scale + self.fc_dim = fc_dim + + self.conv1 = make_conv_layer(in_channels=self.channels_heatmap, out_channels=64, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + self.conv2 = make_conv_layer(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + self.conv3 = make_conv_layer(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + + # self.fc1 = make_fc_layer(in_feature=18432, out_feature=2048, with_bn=self.with_bn) + self.fc1 = make_fc_layer(in_feature=self.fc_dim, out_feature=2048, with_bn=self.with_bn) + self.fc2 = make_fc_layer(in_feature=2048, out_feature=512, with_bn=self.with_bn) + self.fc3 = make_fc_layer(in_feature=512, out_feature=self.hidden_size, with_bn=self.with_bn) + + ## pose decoder + self.pose_fc1 = make_fc_layer(self.hidden_size, 32, with_relu=self.with_pose_relu, with_bn=self.with_bn) + self.pose_fc2 = make_fc_layer(32, 32, with_relu=self.with_pose_relu, with_bn=self.with_bn) + self.pose_fc3 = torch.nn.Linear(32, (self.num_heatmap + 1) * 3) + + # heatmap decoder + self.heatmap_fc1 = make_fc_layer(self.hidden_size, 512, with_bn=self.with_bn) + self.heatmap_fc2 = make_fc_layer(512, 2048, with_bn=self.with_bn) + # self.heatmap_fc3 = make_fc_layer(2048, 18432, with_bn=self.with_bn) + self.heatmap_fc3 = make_fc_layer(2048, self.fc_dim, with_bn=self.with_bn) + self.WH = int(math.sqrt(self.fc_dim/256)) + + self.deconv1 = make_deconv_layer(256, 128, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + self.deconv2 = make_deconv_layer(128, 64, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + self.deconv3 = make_deconv_layer(64, self.channels_heatmap, kernel_size=4, stride=2, padding=1, with_bn=self.with_bn) + + def predict_pose(self, input): + batch_size = input.size()[0] + + # encode heatmap + x = self.conv1(input) + x = self.conv2(x) + x = self.conv3(x) + batch_size = x.shape[0] + x = x.view(batch_size, -1) + x = self.fc1(x) + x = self.fc2(x) + z = self.fc3(x) + + # decode pose + x_pose = self.pose_fc1(z) + x_pose = self.pose_fc2(x_pose) + output_pose = self.pose_fc3(x_pose) + + return output_pose.view(batch_size, self.num_heatmap + 1, 3) + + + def forward(self, input): + batch_size = input.size()[0] + + # encode heatmap + x = self.conv1(input) + x = self.conv2(x) + x = self.conv3(x) + x = x.view(batch_size, -1) + x = self.fc1(x) + x = self.fc2(x) + z = self.fc3(x) + + # decode pose + x_pose = self.pose_fc1(z) + x_pose = self.pose_fc2(x_pose) + output_pose = self.pose_fc3(x_pose) + + # decode heatmap + x_hm = self.heatmap_fc1(z) + x_hm = self.heatmap_fc2(x_hm) + x_hm = self.heatmap_fc3(x_hm) + x_hm = x_hm.view(batch_size, 256, self.WH, self.WH) + x_hm = self.deconv1(x_hm) + x_hm = self.deconv2(x_hm) + output_hm = self.deconv3(x_hm) + + return output_pose.view(batch_size, self.num_heatmap + 1, 3), output_hm + + +if __name__ == "__main__": + + model = HeatMap_UnrealEgo_Shared(opt=None, model_name='resnet50') + + input = torch.rand(3, 3, 256, 256) + outputs = model(input, input) + pred_heatmap_left, pred_heatmap_right = torch.chunk(outputs, 2, dim=1) + + print(pred_heatmap_left.size()) + print(pred_heatmap_right.size()) + diff --git a/phc/learning/unrealego/unrealego_autoencoder_model.py b/phc/learning/unrealego/unrealego_autoencoder_model.py new file mode 100644 index 0000000..1aabb0a --- /dev/null +++ b/phc/learning/unrealego/unrealego_autoencoder_model.py @@ -0,0 +1,173 @@ +from cProfile import run +from enum import auto +import torch +import torch.nn as nn +from torch.autograd import Variable +from torch.cuda.amp import autocast, GradScaler +from torch.nn import MSELoss + +import itertools +from .base_model import BaseModel +from . import network +from utils.loss import LossFuncLimb, LossFuncCosSim, LossFuncMPJPE +from utils.util import batch_compute_similarity_transform_torch + + +class UnrealEgoAutoEncoderModel(BaseModel): + def name(self): + return 'UnrealEgo AutoEncoder model' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + self.opt = opt + self.scaler = GradScaler(enabled=opt.use_amp) + + self.loss_names = [ + 'heatmap_left_rec', 'heatmap_right_rec', + 'pose', 'cos_sim', + ] + + if self.isTrain: + self.visual_names = [ + 'input_rgb_left', 'input_rgb_right', + 'pred_heatmap_left', 'pred_heatmap_right', + 'gt_heatmap_left', 'gt_heatmap_right', + 'pred_heatmap_left_rec', 'pred_heatmap_right_rec' + ] + else: + self.visual_names = [ + # 'input_rgb_left', 'input_rgb_right', + 'pred_heatmap_left', 'pred_heatmap_right', + 'gt_heatmap_left', 'gt_heatmap_right', + ] + + self.visual_pose_names = [ + "pred_pose", "gt_pose" + ] + + if self.isTrain: + self.model_names = ['HeatMap', 'AutoEncoder'] + else: + self.model_names = ['HeatMap', 'AutoEncoder'] + + self.eval_key = "mpjpe" + self.cm2mm = 10 + + + # define the transform network + self.net_HeatMap = network.define_HeatMap(opt, model=opt.model) + self.net_AutoEncoder = network.define_AutoEncoder(opt, model=opt.model) + + self.load_networks( + net=self.net_HeatMap, + path_to_trained_weights=opt.path_to_trained_heatmap + ) + network._freeze(self.net_HeatMap) + + # define loss functions + self.lossfunc_MSE = MSELoss() + self.lossfunc_limb = LossFuncLimb() + self.lossfunc_cos_sim = LossFuncCosSim() + self.lossfunc_MPJPE = LossFuncMPJPE() + + if self.isTrain: + # initialize optimizers + self.optimizer_AutoEncoder = torch.optim.Adam( + params=self.net_AutoEncoder.parameters(), + lr=opt.lr, + weight_decay=opt.weight_decay + ) + + self.optimizers = [] + self.schedulers = [] + self.optimizers.append(self.optimizer_AutoEncoder) + for optimizer in self.optimizers: + self.schedulers.append(network.get_scheduler(optimizer, opt)) + + def set_input(self, data): + self.data = data + self.input_rgb_left = data['input_rgb_left'].cuda(self.device) + self.input_rgb_right = data['input_rgb_right'].cuda(self.device) + self.gt_heatmap_left = data['gt_heatmap_left'].cuda(self.device) + self.gt_heatmap_right = data['gt_heatmap_right'].cuda(self.device) + self.gt_pose = data['gt_local_pose'].cuda(self.device) + + def forward(self): + with autocast(enabled=self.opt.use_amp): + # estimate stereo heatmaps + with torch.no_grad(): + pred_heatmap_cat = self.net_HeatMap(self.input_rgb_left, self.input_rgb_right) + self.pred_heatmap_left, self.pred_heatmap_right = torch.chunk(pred_heatmap_cat, 2, dim=1) + + # estimate pose and reconstruct stereo heatmaps + self.pred_pose, pred_heatmap_rec_cat = self.net_AutoEncoder(pred_heatmap_cat) + self.pred_heatmap_left_rec, self.pred_heatmap_right_rec = torch.chunk(pred_heatmap_rec_cat, 2, dim=1) + + def backward_AutoEncoder(self): + with autocast(enabled=self.opt.use_amp): + loss_pose = self.lossfunc_MPJPE(self.pred_pose, self.gt_pose) + loss_cos_sim = self.lossfunc_cos_sim(self.pred_pose, self.gt_pose) + loss_heatmap_left_rec = self.lossfunc_MSE( + self.pred_heatmap_left_rec, self.pred_heatmap_left.detach() + ) + loss_heatmap_right_rec = self.lossfunc_MSE( + self.pred_heatmap_right_rec, self.pred_heatmap_right.detach() + ) + + self.loss_pose = loss_pose * self.opt.lambda_mpjpe + self.loss_cos_sim = loss_cos_sim * self.opt.lambda_cos_sim * self.opt.lambda_mpjpe + self.loss_heatmap_left_rec = loss_heatmap_left_rec * self.opt.lambda_heatmap_rec + self.loss_heatmap_right_rec = loss_heatmap_right_rec * self.opt.lambda_heatmap_rec + + loss_total = self.loss_pose + self.loss_cos_sim + \ + self.loss_heatmap_left_rec + self.loss_heatmap_right_rec + + self.scaler.scale(loss_total).backward() + + def optimize_parameters(self): + + # set model trainable + self.net_AutoEncoder.train() + + # set optimizer.zero_grad() + self.optimizer_AutoEncoder.zero_grad() + + # forward + self.forward() + + # backward + self.backward_AutoEncoder() + + # optimizer step + self.scaler.step(self.optimizer_AutoEncoder) + + self.scaler.update() + + def evaluate(self, runnning_average_dict): + # set evaluation mode + self.net_HeatMap.eval() + self.net_AutoEncoder.eval() + + # forward pass + pred_heatmap_cat = self.net_HeatMap(self.input_rgb_left, self.input_rgb_right) + self.pred_heatmap_left, self.pred_heatmap_right = torch.chunk(pred_heatmap_cat, 2, dim=1) + self.pred_pose = self.net_AutoEncoder.predict_pose(pred_heatmap_cat) + + S1_hat = batch_compute_similarity_transform_torch(self.pred_pose, self.gt_pose) + + # compute metrics + for id in range(self.pred_pose.size()[0]): # batch size + # calculate mpjpe and p_mpjpe # cm to mm + mpjpe = self.lossfunc_MPJPE(self.pred_pose[id], self.gt_pose[id]) * self.cm2mm + pa_mpjpe = self.lossfunc_MPJPE(S1_hat[id], self.gt_pose[id]) * self.cm2mm + + # update metrics dict + runnning_average_dict.update(dict( + mpjpe=mpjpe, + pa_mpjpe=pa_mpjpe) + ) + + return runnning_average_dict + + diff --git a/phc/learning/unrealego/unrealego_heatmap_shared_model.py b/phc/learning/unrealego/unrealego_heatmap_shared_model.py new file mode 100644 index 0000000..6e1c87e --- /dev/null +++ b/phc/learning/unrealego/unrealego_heatmap_shared_model.py @@ -0,0 +1,146 @@ +from cProfile import run +from enum import auto +import torch +import torch.nn as nn +from torch.autograd import Variable +from torch.cuda.amp import autocast, GradScaler +from torch.nn import MSELoss + +import itertools +from .base_model import BaseModel +from . import network +from utils.loss import LossFuncLimb, LossFuncCosSim, LossFuncMPJPE +from utils.util import batch_compute_similarity_transform_torch + + +class UnrealEgoHeatmapSharedModel(BaseModel): + def name(self): + return 'UnrealEgo Heatmap Shared model' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + self.opt = opt + self.scaler = GradScaler(enabled=opt.use_amp) + + self.loss_names = [ + 'heatmap_left', 'heatmap_right', + ] + + self.visual_names = [ + 'input_rgb_left', 'input_rgb_right', + 'pred_heatmap_left', 'pred_heatmap_right', + 'gt_heatmap_left', 'gt_heatmap_right', + ] + + self.visual_pose_names = [ + ] + + if self.isTrain: + self.model_names = ['HeatMap'] + else: + self.model_names = ['HeatMap'] + + self.eval_key = "mse_heatmap" + self.cm2mm = 10 + + + # define the transform network + print(opt.model) + self.net_HeatMap = network.define_HeatMap(opt, model=opt.model) + + if self.isTrain: + # define loss functions + self.lossfunc_MSE = MSELoss() + + # initialize optimizers + self.optimizer_HeatMap = torch.optim.Adam( + params=self.net_HeatMap.parameters(), + lr=opt.lr, + weight_decay=opt.weight_decay + ) + + self.optimizers = [] + self.schedulers = [] + self.optimizers.append(self.optimizer_HeatMap) + for optimizer in self.optimizers: + self.schedulers.append(network.get_scheduler(optimizer, opt)) + + # if not self.isTrain or opt.continue_train: + # self.load_networks(opt.which_epoch) + + def set_input(self, data): + self.data = data + self.input_rgb_left = data['input_rgb_left'].cuda(self.device) + self.input_rgb_right = data['input_rgb_right'].cuda(self.device) + self.gt_heatmap_left = data['gt_heatmap_left'].cuda(self.device) + self.gt_heatmap_right = data['gt_heatmap_right'].cuda(self.device) + + def forward(self): + with autocast(enabled=self.opt.use_amp): + # estimate stereo heatmaps + pred_heatmap_cat = self.net_HeatMap(self.input_rgb_left, self.input_rgb_right) + self.pred_heatmap_left, self.pred_heatmap_right = torch.chunk(pred_heatmap_cat, 2, dim=1) + + def backward_HeatMap(self): + with autocast(enabled=self.opt.use_amp): + loss_heatmap_left = self.lossfunc_MSE( + self.pred_heatmap_left, self.gt_heatmap_left + ) + loss_heatmap_right = self.lossfunc_MSE( + self.pred_heatmap_right, self.gt_heatmap_right + ) + + self.loss_heatmap_left = loss_heatmap_left * self.opt.lambda_heatmap + self.loss_heatmap_right = loss_heatmap_right * self.opt.lambda_heatmap + + loss_total = self.loss_heatmap_left + self.loss_heatmap_right + + self.scaler.scale(loss_total).backward() + + def optimize_parameters(self): + + # set model trainable + self.net_HeatMap.train() + + # set optimizer.zero_grad() + self.optimizer_HeatMap.zero_grad() + + # forward + self.forward() + + # backward + self.backward_HeatMap() + + # optimizer step + self.scaler.step(self.optimizer_HeatMap) + + self.scaler.update() + + def evaluate(self, runnning_average_dict): + # set evaluation mode + self.net_HeatMap.eval() + + # forward pass + pred_heatmap_cat = self.net_HeatMap(self.input_rgb_left, self.input_rgb_right) + self.pred_heatmap_left, self.pred_heatmap_right = torch.chunk(pred_heatmap_cat, 2, dim=1) + + # compute metrics + for id in range(self.pred_heatmap_left.size()[0]): # batch size + # calculate mse loss for heatmap + loss_heatmap_left_id = self.lossfunc_MSE( + self.pred_heatmap_left[id], self.gt_heatmap_left[id] + ) + loss_heatmap_right_id = self.lossfunc_MSE( + self.pred_heatmap_right[id], self.gt_heatmap_right[id] + ) + + mse_heatmap = loss_heatmap_left_id + loss_heatmap_right_id + + # update metrics dict + runnning_average_dict.update(dict( + mse_heatmap=mse_heatmap + ) + ) + + return runnning_average_dict \ No newline at end of file diff --git a/phc/learning/vq_quantizer.py b/phc/learning/vq_quantizer.py new file mode 100644 index 0000000..369f835 --- /dev/null +++ b/phc/learning/vq_quantizer.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Quantizer(nn.Module): + def __init__(self, n_e, e_dim, beta): + super(Quantizer, self).__init__() + + self.e_dim = e_dim + self.n_e = n_e + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + # self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + # self.embedding.weight.data.uniform_(-1.0 / 2, 1.0 / 2) + self.embedding.weight.data.uniform_(-1.0 / 256, 1.0 / 256) + # self.embedding.weight.data = self.embedding.weight.data/self.embedding.weight.data.norm(dim = -1, keepdim=True) # project to sphere + # self.embedding.weight.data[:] *= 10 + + + def forward(self, z, return_perplexity=False, return_loss = True): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vectort that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + :param z (B, seq_len, channel): + :return z_q: + """ + assert z.shape[-1] == self.e_dim + z_flattened = z.contiguous().view(-1, self.e_dim) + + # B x V + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + # B x 1 + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + + # compute loss for embedding + if return_loss: + loss = torch.mean((z_q - z.detach())**2) + self.beta * torch.mean((z_q.detach() - z)**2) + # loss = self.beta * torch.mean((z_q.detach() - z)**2) + + # preserve gradients + z_q = z + (z_q - z).detach() + else: + loss = torch.tensor(0.0).to(z.device) + + if return_perplexity: + min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype) # measuring utilization + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10))) + return loss, z_q, min_encoding_indices, perplexity + else: + return loss, z_q, min_encoding_indices + + def map2index(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vectort that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + :param z (B, seq_len, channel): + :return z_q: + """ + assert z.shape[-1] == self.e_dim + z_flattened = z.contiguous().view(-1, self.e_dim) + + # B x V + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + # B x 1 + min_encoding_indices = torch.argmin(d, dim=1) + return min_encoding_indices + + def get_codebook_entry(self, indices): + """ + + :param indices(B, seq_len): + :return z_q(B, seq_len, e_dim): + """ + index_flattened = indices.view(-1) + z_q = self.embedding(index_flattened) + z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous() + return z_q + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super(EmbeddingEMA, self).__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + + # weight = weight/weight.norm(dim = -1, keepdim=True) # project to sphere + + self.weight = nn.Parameter(weight, requires_grad=False) + # self.weight.data.uniform_(-1.0 / num_tokens, 1.0 / num_tokens) + self.weight.data.uniform_(-1.0, 1.0) + + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) # counts for how many times the code is used. + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_emb_avg): + self.update_idxes = new_emb_avg.abs().sum(dim = -1) > 0 + self.embed_avg.data[self.update_idxes] = self.embed_avg.data[self.update_idxes].mul_(self.decay).add(new_emb_avg[self.update_idxes], alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ((self.cluster_size + self.eps) / (n + num_tokens*self.eps) * n) + embed_normalized = self.embed_avg + embed_normalized[self.update_idxes] = self.embed_avg[self.update_idxes] / smoothed_cluster_size.unsqueeze(1)[self.update_idxes] + self.weight.data.copy_(embed_normalized) + + + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5): + super(EMAVectorQuantizer, self).__init__() + + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + def forward(self, z, return_perplexity=False): + z_flattened = z.view(-1, self.codebook_dim) + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + + min_encodings = F.one_hot(min_encoding_indices, self.num_tokens).type(z.dtype) + + if self.training and self.embedding.update: + encoding_sum = min_encodings.sum(0) + embed_sum = min_encodings.transpose(0, 1) @ z_flattened + + self.embedding.cluster_size_ema_update(encoding_sum) + self.embedding.embed_avg_ema_update(embed_sum) + self.embedding.weight_update(self.num_tokens) + + loss = self.beta * F.mse_loss(z_q.detach(), z) + + z_q = z + (z_q - z).detach() + + if return_perplexity: + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + return loss, z_q, min_encoding_indices, perplexity + else: + return loss, z_q, min_encoding_indices + diff --git a/phc/run.py b/phc/run.py new file mode 100644 index 0000000..c33d65c --- /dev/null +++ b/phc/run.py @@ -0,0 +1,297 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +from phc.utils.config import set_np_formatting, set_seed, get_args, parse_sim_params, load_cfg +from phc.utils.parse_task import parse_task + +from rl_games.algos_torch import players +from rl_games.algos_torch import torch_ext +from rl_games.common import env_configurations, experiment, vecenv +from rl_games.common.algo_observer import AlgoObserver +from rl_games.torch_runner import Runner + +from phc.utils.flags import flags + +import numpy as np +import copy +import torch +import wandb + +from learning import im_amp +from learning import im_amp_players +from learning import amp_agent +from learning import amp_players +from learning import amp_models +from learning import amp_network_builder +from learning import amp_network_mcp_builder +from learning import amp_network_pnn_builder + + +from env.tasks import humanoid_amp_task + +args = None +cfg = None +cfg_train = None + + +def create_rlgpu_env(**kwargs): + use_horovod = cfg_train['params']['config'].get('multi_gpu', False) + if use_horovod: + import horovod.torch as hvd + + rank = hvd.rank() + print("Horovod rank: ", rank) + + cfg_train['params']['seed'] = cfg_train['params']['seed'] + rank + + args.device = 'cuda' + args.device_id = rank + args.rl_device = 'cuda:' + str(rank) + + cfg['rank'] = rank + cfg['rl_device'] = 'cuda:' + str(rank) + + sim_params = parse_sim_params(args, cfg, cfg_train) + task, env = parse_task(args, cfg, cfg_train, sim_params) + + print(env.num_envs) + print(env.num_actions) + print(env.num_obs) + print(env.num_states) + + frames = kwargs.pop('frames', 1) + if frames > 1: + env = wrappers.FrameStack(env, frames, False) + return env + + +class RLGPUAlgoObserver(AlgoObserver): + + def __init__(self, use_successes=True): + self.use_successes = use_successes + return + + def after_init(self, algo): + self.algo = algo + self.consecutive_successes = torch_ext.AverageMeter(1, self.algo.games_to_track).to(self.algo.ppo_device) + self.writer = self.algo.writer + return + + def process_infos(self, infos, done_indices): + if isinstance(infos, dict): + if (self.use_successes == False) and 'consecutive_successes' in infos: + cons_successes = infos['consecutive_successes'].clone() + self.consecutive_successes.update(cons_successes.to(self.algo.ppo_device)) + if self.use_successes and 'successes' in infos: + successes = infos['successes'].clone() + self.consecutive_successes.update(successes[done_indices].to(self.algo.ppo_device)) + return + + def after_clear_stats(self): + self.mean_scores.clear() + return + + def after_print_stats(self, frame, epoch_num, total_time): + if self.consecutive_successes.current_size > 0: + mean_con_successes = self.consecutive_successes.get_mean() + self.writer.add_scalar('successes/consecutive_successes/mean', mean_con_successes, frame) + self.writer.add_scalar('successes/consecutive_successes/iter', mean_con_successes, epoch_num) + self.writer.add_scalar('successes/consecutive_successes/time', mean_con_successes, total_time) + return + + +class RLGPUEnv(vecenv.IVecEnv): + + def __init__(self, config_name, num_actors, **kwargs): + self.env = env_configurations.configurations[config_name]['env_creator'](**kwargs) + self.use_global_obs = (self.env.num_states > 0) + + self.full_state = {} + self.full_state["obs"] = self.reset() + if self.use_global_obs: + self.full_state["states"] = self.env.get_state() + return + + def step(self, action): + next_obs, reward, is_done, info = self.env.step(action) + + # todo: improve, return only dictinary + self.full_state["obs"] = next_obs + if self.use_global_obs: + self.full_state["states"] = self.env.get_state() + return self.full_state, reward, is_done, info + else: + return self.full_state["obs"], reward, is_done, info + + def reset(self, env_ids=None): + self.full_state["obs"] = self.env.reset(env_ids) + if self.use_global_obs: + self.full_state["states"] = self.env.get_state() + return self.full_state + else: + return self.full_state["obs"] + + def get_number_of_agents(self): + return self.env.get_number_of_agents() + + def get_env_info(self): + info = {} + info['action_space'] = self.env.action_space + info['observation_space'] = self.env.observation_space + info['amp_observation_space'] = self.env.amp_observation_space + + info['enc_amp_observation_space'] = self.env.enc_amp_observation_space + + if isinstance(self.env.task, humanoid_amp_task.HumanoidAMPTask): + info['task_obs_size'] = self.env.task.get_task_obs_size() + else: + info['task_obs_size'] = 0 + + if self.use_global_obs: + info['state_space'] = self.env.state_space + print(info['action_space'], info['observation_space'], info['state_space']) + else: + print(info['action_space'], info['observation_space']) + + return info + + +vecenv.register('RLGPU', lambda config_name, num_actors, **kwargs: RLGPUEnv(config_name, num_actors, **kwargs)) +env_configurations.register('rlgpu', {'env_creator': lambda **kwargs: create_rlgpu_env(**kwargs), 'vecenv_type': 'RLGPU'}) + + +def build_alg_runner(algo_observer): + runner = Runner(algo_observer) + runner.player_factory.register_builder('amp_discrete', lambda **kwargs: amp_players.AMPPlayerDiscrete(**kwargs)) + + runner.algo_factory.register_builder('amp', lambda **kwargs: amp_agent.AMPAgent(**kwargs)) + runner.player_factory.register_builder('amp', lambda **kwargs: amp_players.AMPPlayerContinuous(**kwargs)) + + runner.model_builder.model_factory.register_builder('amp', lambda network, **kwargs: amp_models.ModelAMPContinuous(network)) + runner.model_builder.network_factory.register_builder('amp', lambda **kwargs: amp_network_builder.AMPBuilder()) + runner.model_builder.network_factory.register_builder('amp_mcp', lambda **kwargs: amp_network_mcp_builder.AMPMCPBuilder()) + runner.model_builder.network_factory.register_builder('amp_pnn', lambda **kwargs: amp_network_pnn_builder.AMPPNNBuilder()) + + runner.algo_factory.register_builder('im_amp', lambda **kwargs: im_amp.IMAmpAgent(**kwargs)) + runner.player_factory.register_builder('im_amp', lambda **kwargs: im_amp_players.IMAMPPlayerContinuous(**kwargs)) + + return runner + + +def main(): + global args + global cfg + global cfg_train + + set_np_formatting() + args = get_args() + cfg_env_name = args.cfg_env.split("/")[-1].split(".")[0] + + args.logdir = args.network_path + cfg, cfg_train, logdir = load_cfg(args) + flags.debug, flags.follow, flags.fixed, flags.divide_group, flags.no_collision_check, flags.fixed_path, flags.real_path, flags.small_terrain, flags.show_traj, flags.server_mode, flags.slow, flags.real_traj, flags.im_eval, flags.no_virtual_display, flags.render_o3d = \ + args.debug, args.follow, False, False, False, False, False, args.small_terrain, True, args.server_mode, False, False, args.im_eval, args.no_virtual_display, args.render_o3d + + flags.add_proj = args.add_proj + flags.has_eval = args.has_eval + flags.trigger_input = False + flags.demo = args.demo + + if args.server_mode: + flags.follow = args.follow = True + flags.fixed = args.fixed = True + flags.no_collision_check = True + flags.show_traj = True + cfg['env']['episodeLength'] = 99999999999999 + + if args.test and not flags.small_terrain: + cfg['env']['episodeLength'] = 99999999999999 + + if args.real_traj: + cfg['env']['episodeLength'] = 99999999999999 + flags.real_traj = True + + + project_name = cfg.get("project_name", "pulse") + if (not args.no_log) and (not args.test) and (not args.debug): + wandb.init( + project=project_name, + resume=not args.resume_str is None, + id=args.resume_str, + notes=cfg.get("notes", "no notes"), + ) + wandb.config.update(cfg, allow_val_change=True) + wandb.run.name = cfg_env_name + wandb.run.save() + + cfg_train['params']['seed'] = set_seed(cfg_train['params'].get("seed", -1), cfg_train['params'].get("torch_deterministic", False)) + + if args.horovod: + cfg_train['params']['config']['multi_gpu'] = args.horovod + + if args.horizon_length != -1: + cfg_train['params']['config']['horizon_length'] = args.horizon_length + + if args.minibatch_size != -1: + cfg_train['params']['config']['minibatch_size'] = args.minibatch_size + + if args.motion_file: + cfg['env']['motion_file'] = args.motion_file + flags.test = args.test + + # Create default directories for weights and statistics + cfg_train['params']['config']['network_path'] = args.network_path + args.log_path = osp.join(args.log_path, cfg['name'], cfg_env_name) + cfg_train['params']['config']['log_path'] = args.log_path + cfg_train['params']['config']['train_dir'] = args.log_path + + os.makedirs(args.network_path, exist_ok=True) + os.makedirs(args.log_path, exist_ok=True) + + vargs = vars(args) + + algo_observer = RLGPUAlgoObserver() + + runner = build_alg_runner(algo_observer) + runner.load(cfg_train) + runner.reset() + runner.run(vargs) + + return + + +if __name__ == '__main__': + main() diff --git a/phc/run_hydra.py b/phc/run_hydra.py new file mode 100644 index 0000000..1068a93 --- /dev/null +++ b/phc/run_hydra.py @@ -0,0 +1,345 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import glob +import os +import sys +import pdb +import os.path as osp +os.environ["OMP_NUM_THREADS"] = "1" + +sys.path.append(os.getcwd()) + +from phc.utils.config import set_np_formatting, set_seed, SIM_TIMESTEP +from phc.utils.parse_task import parse_task +from isaacgym import gymapi +from isaacgym import gymutil + + +from rl_games.algos_torch import players +from rl_games.algos_torch import torch_ext +from rl_games.common import env_configurations, experiment, vecenv +from rl_games.common.algo_observer import AlgoObserver +from rl_games.torch_runner import Runner + +from phc.utils.flags import flags + +import numpy as np +import copy +import torch +import wandb + +from learning import im_amp +from learning import im_amp_players +from learning import amp_agent +from learning import amp_players +from learning import amp_models +from learning import amp_network_builder +from learning import amp_network_mcp_builder +from learning import amp_network_pnn_builder +from learning import amp_network_z_builder + +from env.tasks import humanoid_amp_task +import hydra +from omegaconf import DictConfig, OmegaConf +from easydict import EasyDict + +args = None +cfg = None +cfg_train = None + + +def parse_sim_params(cfg): + # initialize sim + sim_params = gymapi.SimParams() + sim_params.dt = SIM_TIMESTEP + sim_params.num_client_threads = cfg.sim.slices + + if cfg.sim.use_flex: + if cfg.sim.pipeline in ["gpu"]: + print("WARNING: Using Flex with GPU instead of PHYSX!") + sim_params.use_flex.shape_collision_margin = 0.01 + sim_params.use_flex.num_outer_iterations = 4 + sim_params.use_flex.num_inner_iterations = 10 + else : # use gymapi.SIM_PHYSX + sim_params.physx.solver_type = 1 + sim_params.physx.num_position_iterations = 4 + sim_params.physx.num_velocity_iterations = 1 + sim_params.physx.num_threads = 4 + sim_params.physx.use_gpu = cfg.sim.pipeline in ["gpu"] + sim_params.physx.num_subscenes = cfg.sim.subscenes + if flags.test and not flags.im_eval: + sim_params.physx.max_gpu_contact_pairs = 4 * 1024 * 1024 + else: + sim_params.physx.max_gpu_contact_pairs = 16 * 1024 * 1024 + + sim_params.use_gpu_pipeline = cfg.sim.pipeline in ["gpu"] + sim_params.physx.use_gpu = cfg.sim.pipeline in ["gpu"] + + # if sim options are provided in cfg, parse them and update/override above: + if "sim" in cfg: + gymutil.parse_sim_config(cfg["sim"], sim_params) + + # Override num_threads if passed on the command line + if not cfg.sim.use_flex and cfg.sim.physx.num_threads > 0: + sim_params.physx.num_threads = cfg.sim.physx.num_threads + + return sim_params + +def create_rlgpu_env(**kwargs): + use_horovod = cfg_train['params']['config'].get('multi_gpu', False) + if use_horovod: + import horovod.torch as hvd + + rank = hvd.rank() + print("Horovod rank: ", rank) + + cfg_train['params']['seed'] = cfg_train['params']['seed'] + rank + + args.device = 'cuda' + args.device_id = rank + args.rl_device = 'cuda:' + str(rank) + + cfg['rank'] = rank + cfg['rl_device'] = 'cuda:' + str(rank) + + sim_params = parse_sim_params(cfg) + args = EasyDict({ + "task": cfg.env.task, + "device_id": cfg.device_id, + "rl_device": cfg.rl_device, + "physics_engine": gymapi.SIM_PHYSX if not cfg.sim.use_flex else gymapi.SIM_FLEX, + "headless": cfg.headless, + "device": cfg.device, + }) #### ZL: patch + task, env = parse_task(args, cfg, cfg_train, sim_params) + + print(env.num_envs) + print(env.num_actions) + print(env.num_obs) + print(env.num_states) + + frames = kwargs.pop('frames', 1) + if frames > 1: + env = wrappers.FrameStack(env, frames, False) + return env + + +class RLGPUAlgoObserver(AlgoObserver): + + def __init__(self, use_successes=True): + self.use_successes = use_successes + return + + def after_init(self, algo): + self.algo = algo + self.consecutive_successes = torch_ext.AverageMeter(1, self.algo.games_to_track).to(self.algo.ppo_device) + self.writer = self.algo.writer + return + + def process_infos(self, infos, done_indices): + if isinstance(infos, dict): + if (self.use_successes == False) and 'consecutive_successes' in infos: + cons_successes = infos['consecutive_successes'].clone() + self.consecutive_successes.update(cons_successes.to(self.algo.ppo_device)) + if self.use_successes and 'successes' in infos: + successes = infos['successes'].clone() + self.consecutive_successes.update(successes[done_indices].to(self.algo.ppo_device)) + return + + def after_clear_stats(self): + self.mean_scores.clear() + return + + def after_print_stats(self, frame, epoch_num, total_time): + if self.consecutive_successes.current_size > 0: + mean_con_successes = self.consecutive_successes.get_mean() + self.writer.add_scalar('successes/consecutive_successes/mean', mean_con_successes, frame) + self.writer.add_scalar('successes/consecutive_successes/iter', mean_con_successes, epoch_num) + self.writer.add_scalar('successes/consecutive_successes/time', mean_con_successes, total_time) + return + + +class RLGPUEnv(vecenv.IVecEnv): + + def __init__(self, config_name, num_actors, **kwargs): + self.env = env_configurations.configurations[config_name]['env_creator'](**kwargs) + self.use_global_obs = (self.env.num_states > 0) + + self.full_state = {} + self.full_state["obs"] = self.reset() + if self.use_global_obs: + self.full_state["states"] = self.env.get_state() + return + + def step(self, action): + next_obs, reward, is_done, info = self.env.step(action) + + # todo: improve, return only dictinary + self.full_state["obs"] = next_obs + if self.use_global_obs: + self.full_state["states"] = self.env.get_state() + return self.full_state, reward, is_done, info + else: + return self.full_state["obs"], reward, is_done, info + + def reset(self, env_ids=None): + self.full_state["obs"] = self.env.reset(env_ids) + if self.use_global_obs: + self.full_state["states"] = self.env.get_state() + return self.full_state + else: + return self.full_state["obs"] + + def get_number_of_agents(self): + return self.env.get_number_of_agents() + + def get_env_info(self): + info = {} + info['action_space'] = self.env.action_space + info['observation_space'] = self.env.observation_space + info['amp_observation_space'] = self.env.amp_observation_space + + info['enc_amp_observation_space'] = self.env.enc_amp_observation_space + + if isinstance(self.env.task, humanoid_amp_task.HumanoidAMPTask): + info['task_obs_size'] = self.env.task.get_task_obs_size() + else: + info['task_obs_size'] = 0 + + if self.use_global_obs: + info['state_space'] = self.env.state_space + print(info['action_space'], info['observation_space'], info['state_space']) + else: + print(info['action_space'], info['observation_space']) + + return info + + +vecenv.register('RLGPU', lambda config_name, num_actors, **kwargs: RLGPUEnv(config_name, num_actors, **kwargs)) +env_configurations.register('rlgpu', {'env_creator': lambda **kwargs: create_rlgpu_env(**kwargs), 'vecenv_type': 'RLGPU'}) + + +def build_alg_runner(algo_observer): + runner = Runner(algo_observer) + runner.player_factory.register_builder('amp_discrete', lambda **kwargs: amp_players.AMPPlayerDiscrete(**kwargs)) + + runner.algo_factory.register_builder('amp', lambda **kwargs: amp_agent.AMPAgent(**kwargs)) + runner.player_factory.register_builder('amp', lambda **kwargs: amp_players.AMPPlayerContinuous(**kwargs)) + + runner.model_builder.model_factory.register_builder('amp', lambda network, **kwargs: amp_models.ModelAMPContinuous(network)) + runner.model_builder.network_factory.register_builder('amp', lambda **kwargs: amp_network_builder.AMPBuilder()) + runner.model_builder.network_factory.register_builder('amp_mcp', lambda **kwargs: amp_network_mcp_builder.AMPMCPBuilder()) + runner.model_builder.network_factory.register_builder('amp_pnn', lambda **kwargs: amp_network_pnn_builder.AMPPNNBuilder()) + runner.model_builder.network_factory.register_builder('amp_z', lambda **kwargs: amp_network_z_builder.AMPZBuilder()) + + runner.algo_factory.register_builder('im_amp', lambda **kwargs: im_amp.IMAmpAgent(**kwargs)) + runner.player_factory.register_builder('im_amp', lambda **kwargs: im_amp_players.IMAMPPlayerContinuous(**kwargs)) + + return runner + +@hydra.main( + version_base=None, + config_path="../phc/data/cfg", + config_name="config", +) +def main(cfg_hydra: DictConfig) -> None: + global cfg_train + global cfg + + cfg = EasyDict(OmegaConf.to_container(cfg_hydra, resolve=True)) + + set_np_formatting() + + # cfg, cfg_train, logdir = load_cfg(args) + flags.debug, flags.follow, flags.fixed, flags.divide_group, flags.no_collision_check, flags.fixed_path, flags.real_path, flags.show_traj, flags.server_mode, flags.slow, flags.real_traj, flags.im_eval, flags.no_virtual_display, flags.render_o3d = \ + cfg.debug, cfg.follow, False, False, False, False, False, True, cfg.server_mode, False, False, cfg.im_eval, cfg.no_virtual_display, cfg.render_o3d + + flags.test = cfg.test + flags.add_proj = cfg.add_proj + flags.has_eval = cfg.has_eval + flags.trigger_input = False + + if cfg.server_mode: + flags.follow = cfg.follow = True + flags.fixed = cfg.fixed = True + flags.no_collision_check = True + flags.show_traj = True + cfg['env']['episode_length'] = 99999999999999 + + if cfg.real_traj: + cfg['env']['episode_length'] = 99999999999999 + flags.real_traj = True + + cfg.train = not cfg.test + project_name = cfg.get("project_name", "PULSE") + if (not cfg.no_log) and (not cfg.test) and (not cfg.debug): + wandb.init( + project=project_name, + resume=not cfg.resume_str is None, + id=cfg.resume_str, + notes=cfg.get("notes", "no notes"), + ) + wandb.config.update(cfg, allow_val_change=True) + wandb.run.name = cfg.exp_name + wandb.run.save() + + set_seed(cfg.get("seed", -1), cfg.get("torch_deterministic", False)) + + # Create default directories for weights and statistics + cfg_train = cfg.learning + cfg_train['params']['config']['network_path'] = cfg.output_path + cfg_train['params']['config']['train_dir'] = cfg.output_path + cfg_train["params"]["config"]["num_actors"] = cfg.env.num_envs + + if cfg.epoch > 0: + cfg_train["params"]["load_checkpoint"] = True + cfg_train["params"]["load_path"] = osp.join(cfg.output_path, cfg_train["params"]["config"]['name'] + "_" + str(cfg.epoch).zfill(8) + '.pth') + elif cfg.epoch == -1: + path = osp.join(cfg.output_path, cfg_train["params"]["config"]['name'] + '.pth') + if osp.exists(path): + cfg_train["params"]["load_path"] = path + cfg_train["params"]["load_checkpoint"] = True + else: + print(path) + print("no file to resume!!!!") + + + os.makedirs(cfg.output_path, exist_ok=True) + + algo_observer = RLGPUAlgoObserver() + runner = build_alg_runner(algo_observer) + runner.load(cfg_train) + runner.reset() + runner.run(cfg) + + return + + +if __name__ == '__main__': + main() diff --git a/phc/utils/__init__.py b/phc/utils/__init__.py new file mode 100644 index 0000000..d79b55c --- /dev/null +++ b/phc/utils/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/phc/utils/benchmarking.py b/phc/utils/benchmarking.py new file mode 100644 index 0000000..85c497f --- /dev/null +++ b/phc/utils/benchmarking.py @@ -0,0 +1,71 @@ +from contextlib import contextmanager +import time +from collections import defaultdict +import re +import sys + +average_times = defaultdict(lambda: (0,0)) + +@contextmanager +def timeit(name): + start = time.time() + yield + end = time.time() + total_time, num_calls = average_times[name] + total_time += end-start + num_calls += 1 + print("TIME:", name, end-start, "| AVG", total_time / num_calls, f"| TOTAL {total_time} {num_calls}") + average_times[name] = (total_time, num_calls) + +def time_decorator(func): + def with_times(*args, **kwargs): + with timeit(func.__name__): + return func(*args, **kwargs) + return with_times + + +def recover_map(lines): + info = {} + pattern = re.compile(".* (.*) .* \| .* (.*\\b) .*\| .* (.*) (.*)") + + for l in lines: + if not l.startswith("TIME"): + continue + + match = pattern.match(l) + + name = match.group(1) + avg = float(match.group(2)) + total_time = float(match.group(3)) + total_calls = float(match.group(4)) + info[name] = (avg, total_time, total_calls) + + return info + +def compare_files(fileA, fileB): + with open(fileA) as fA: + linesA = fA.readlines() + + with open(fileB) as fB: + linesB = fB.readlines() + + mapA = recover_map(linesA) + mapB = recover_map(linesB) + + keysA = set(mapA.keys()) + keysB = set(mapB.keys()) + + inter = keysA.intersection(keysB) + print("Missing A", keysA.difference(inter)) + print("Missing B", keysB.difference(inter)) + + keys_ordered = list(sorted([(mapA[k][1], k) for k in inter], reverse=True)) + + for _, k in keys_ordered: + print(f"{k} {mapA[k]} {mapB[k]}") + + +if __name__ == "__main__": + fA = sys.argv[1] + fB = sys.argv[2] + compare_files(fA, fB) \ No newline at end of file diff --git a/phc/utils/config.py b/phc/utils/config.py new file mode 100644 index 0000000..4cd5181 --- /dev/null +++ b/phc/utils/config.py @@ -0,0 +1,487 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +import yaml + +from isaacgym import gymapi +from isaacgym import gymutil + +import numpy as np +import random +import torch +from phc.utils.flags import flags + +SIM_TIMESTEP = 1.0 / 60.0 + + +def set_np_formatting(): + np.set_printoptions(edgeitems=30, infstr='inf', linewidth=4000, nanstr='nan', precision=2, suppress=False, threshold=10000, formatter=None) + + +def warn_task_name(): + raise Exception("Unrecognized task!\nTask should be one of: [BallBalance, Cartpole, CartpoleYUp, Ant, Humanoid, Anymal, FrankaCabinet, Quadcopter, ShadowHand, ShadowHandLSTM, ShadowHandFFOpenAI, ShadowHandFFOpenAITest, ShadowHandOpenAI, ShadowHandOpenAITest, Ingenuity]") + + +def set_seed(seed, torch_deterministic=False): + print("torch_deterministic:", torch_deterministic) + print("torch_deterministic:", torch_deterministic) + print("torch_deterministic:", torch_deterministic) + if seed == -1 and torch_deterministic: + seed = 42 + elif seed == -1: + seed = np.random.randint(0, 10000) + print("Setting seed: {}".format(seed)) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if torch_deterministic: + # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + else: + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + return seed + + +def load_cfg(args): + with open(os.path.join(os.getcwd(), args.cfg_train), 'r') as f: + cfg_train = yaml.load(f, Loader=yaml.SafeLoader) + + with open(os.path.join(os.getcwd(), args.cfg_env), 'r') as f: + cfg = yaml.load(f, Loader=yaml.SafeLoader) + + # Override number of environments if passed on the command line + if args.num_envs > 0: + cfg["env"]["numEnvs"] = args.num_envs + + if args.episode_length > 0: + cfg["env"]["episodeLength"] = args.episode_length + + cfg["name"] = args.task + cfg["headless"] = args.headless + + # Set physics domain randomization + if "task" in cfg: + if "randomize" not in cfg["task"]: + cfg["task"]["randomize"] = args.randomize + else: + cfg["task"]["randomize"] = args.randomize or cfg["task"]["randomize"] + else: + cfg["task"] = {"randomize": False} + + logdir = args.logdir + # Set deterministic mode + if args.torch_deterministic: + cfg_train["params"]["torch_deterministic"] = True + + exp_name = cfg_train["params"]["config"]['name'] + + if args.experiment != 'Base': + if args.metadata: + exp_name = "{}_{}_{}_{}".format(args.experiment, args.task_type, args.device, str(args.physics_engine).split("_")[-1]) + + if cfg["task"]["randomize"]: + exp_name += "_DR" + else: + exp_name = args.experiment + + # Override config name + cfg_train["params"]["config"]['name'] = exp_name + + if args.epoch > 0: + cfg_train["params"]["load_checkpoint"] = True + cfg_train["params"]["load_path"] = osp.join(args.network_path, exp_name + "_" + str(args.epoch).zfill(8) + '.pth') + args.checkpoint = cfg_train["params"]["load_path"] + elif args.epoch == -1: + path = osp.join(args.network_path, exp_name + '.pth') + if osp.exists(path): + cfg_train["params"]["load_path"] = path + cfg_train["params"]["load_checkpoint"] = True + args.checkpoint = cfg_train["params"]["load_path"] + else: + print("no file to resume!!!!") + + + # if args.checkpoint != "Base": + # cfg_train["params"]["load_path"] = osp.join(args.network_path, exp_name + "_" + str(args.epoch).zfill(8) + '.pth') + + if args.llc_checkpoint != "": + cfg_train["params"]["config"]["llc_checkpoint"] = args.llc_checkpoint + + # Set maximum number of training iterations (epochs) + if args.max_iterations > 0: + cfg_train["params"]["config"]['max_epochs'] = args.max_iterations + + cfg_train["params"]["config"]["num_actors"] = cfg["env"]["numEnvs"] + + seed = cfg_train["params"].get("seed", -1) + if args.seed is not None: + seed = args.seed + cfg["seed"] = seed + cfg_train["params"]["seed"] = seed + + cfg["args"] = args + + return cfg, cfg_train, logdir + + +def parse_sim_params(args, cfg, cfg_train): + # initialize sim + sim_params = gymapi.SimParams() + sim_params.dt = SIM_TIMESTEP + sim_params.num_client_threads = args.slices + + if args.physics_engine == gymapi.SIM_FLEX: + if args.device != "cpu": + print("WARNING: Using Flex with GPU instead of PHYSX!") + sim_params.flex.shape_collision_margin = 0.01 + sim_params.flex.num_outer_iterations = 4 + sim_params.flex.num_inner_iterations = 10 + elif args.physics_engine == gymapi.SIM_PHYSX: + sim_params.physx.solver_type = 1 + sim_params.physx.num_position_iterations = 4 + sim_params.physx.num_velocity_iterations = 1 + sim_params.physx.num_threads = 4 + sim_params.physx.use_gpu = args.use_gpu + sim_params.physx.num_subscenes = args.subscenes + if flags.test and not flags.im_eval: + sim_params.physx.max_gpu_contact_pairs = 4 * 1024 * 1024 + else: + sim_params.physx.max_gpu_contact_pairs = 16 * 1024 * 1024 + + sim_params.use_gpu_pipeline = args.use_gpu_pipeline + sim_params.physx.use_gpu = args.use_gpu + + # if sim options are provided in cfg, parse them and update/override above: + if "sim" in cfg: + gymutil.parse_sim_config(cfg["sim"], sim_params) + + # Override num_threads if passed on the command line + if args.physics_engine == gymapi.SIM_PHYSX and args.num_threads > 0: + sim_params.physx.num_threads = args.num_threads + + return sim_params + + +def get_args(benchmark=False): + custom_parameters = [ + { + "name": "--test", + "action": "store_true", + "default": False, + "help": "Run trained policy, no training" + }, + { + "name": "--debug", + "action": "store_true", + "default": False, + "help": "Debugging, no training and no logging" + }, + { + "name": "--play", + "action": "store_true", + "default": False, + "help": "Run trained policy, the same as test, can be used only by rl_games RL library" + }, + { + "name": "--epoch", + "type": int, + "default": 0, + "help": "Resume training or start testing from a checkpoint" + }, + { + "name": "--checkpoint", + "type": str, + "default": "Base", + "help": "Path to the saved weights, only for rl_games RL library" + }, + { + "name": "--headless", + "action": "store_true", + "default": False, + "help": "Force display off at all times" + }, + { + "name": "--horovod", + "action": "store_true", + "default": False, + "help": "Use horovod for multi-gpu training, have effect only with rl_games RL library" + }, + { + "name": "--task", + "type": str, + "default": "Humanoid", + "help": "Can be BallBalance, Cartpole, CartpoleYUp, Ant, Humanoid, Anymal, FrankaCabinet, Quadcopter, ShadowHand, Ingenuity" + }, + { + "name": "--task_type", + "type": str, + "default": "Python", + "help": "Choose Python or C++" + }, + { + "name": "--rl_device", + "type": str, + "default": "cuda:0", + "help": "Choose CPU or GPU device for inferencing policy network" + }, + { + "name": "--logdir", + "type": str, + "default": "logs/" + }, + { + "name": "--experiment", + "type": str, + "default": "Base", + "help": "Experiment name. If used with --metadata flag an additional information about physics engine, sim device, pipeline and domain randomization will be added to the name" + }, + { + "name": "--metadata", + "action": "store_true", + "default": False, + "help": "Requires --experiment flag, adds physics engine, sim device, pipeline info and if domain randomization is used to the experiment name provided by user" + }, + { + "name": "--cfg_env", + "type": str, + "default": "Base", + "help": "Environment configuration file (.yaml)" + }, + { + "name": "--cfg_train", + "type": str, + "default": "Base", + "help": "Training configuration file (.yaml)" + }, + { + "name": "--motion_file", + "type": str, + "default": "", + "help": "Specify reference motion file" + }, + { + "name": "--num_envs", + "type": int, + "default": 0, + "help": "Number of environments to create - override config file" + }, + { + "name": "--episode_length", + "type": int, + "default": 0, + "help": "Episode length, by default is read from yaml config" + }, + { + "name": "--seed", + "type": int, + "help": "Random seed" + }, + { + "name": "--max_iterations", + "type": int, + "default": 0, + "help": "Set a maximum number of training iterations" + }, + { + "name": "--horizon_length", + "type": int, + "default": -1, + "help": "Set number of simulation steps per 1 PPO iteration. Supported only by rl_games. If not -1 overrides the config settings." + }, + { + "name": "--minibatch_size", + "type": int, + "default": -1, + "help": "Set batch size for PPO optimization step. Supported only by rl_games. If not -1 overrides the config settings." + }, + { + "name": "--randomize", + "action": "store_true", + "default": False, + "help": "Apply physics domain randomization" + }, + { + "name": "--torch_deterministic", + "action": "store_true", + "default": False, + "help": "Apply additional PyTorch settings for more deterministic behaviour" + }, + { + "name": "--network_path", + "type": str, + "default": "output/", + "help": "Specify network output directory" + }, + { + "name": "--log_path", + "type": str, + "default": "log/", + "help": "Specify log directory" + }, + { + "name": "--llc_checkpoint", + "type": str, + "default": "", + "help": "Path to the saved weights for the low-level controller of an HRL agent." + }, + { + "name": "--no_log", + "action": "store_true", + "default": False, + "help": "No wandb logging" + }, + { + "name": "--resume_str", + "type": str, + "default": None, + "help": "Resuming training from a specific logging instance" + }, + { + "name": "--follow", + "action": "store_true", + "default": False, + "help": "Follow Humanoid" + }, + { + "name": "--real_traj", + "action": "store_true", + "default": False, + "help": "load real_traj" + }, + { + "name": "--show_sensors", + "action": "store_true", + "default": False, + "help": "load real data mesh" + }, + { + "name": "--small_terrain", + "action": "store_true", + "default": False, + "help": "load real data mesh" + }, + { + "name": "--server_mode", + "action": "store_true", + "default": False, + "help": "load real data mesh" + }, + { + "name": "--add_proj", + "action": "store_true", + "default": False, + "help": "adding small projectiiles or not" + }, + { + "name": "--im_eval", + "action": "store_true", + "default": False, + "help": "Eval imitation" + }, + { + "name": "--has_eval", + "action": "store_true", + "default": False, + "help": "Eval during training or not" + }, + { + "name": "--no_virtual_display", + "action": "store_true", + "default": False, + "help": "Disable virtual display" + }, + { + "name": "--render_o3d", + "action": "store_true", + "default": False, + "help": "Disable virtual display" + }, + + { + "name": "--demo", + "action": "store_true", + "default": False, + "help": "No SMPL_robot dependency" + }, + ] + + if benchmark: + custom_parameters += [{ + "name": "--num_proc", + "type": int, + "default": 1, + "help": "Number of child processes to launch" + }, { + "name": "--random_actions", + "action": "store_true", + "help": "Run benchmark with random actions instead of inferencing" + }, { + "name": "--bench_len", + "type": int, + "default": 10, + "help": "Number of timing reports" + }, { + "name": "--bench_file", + "action": "store", + "help": "Filename to store benchmark results" + }] + + # parse arguments + args = gymutil.parse_arguments(description="RL Policy", custom_parameters=custom_parameters) + + # allignment with examples + args.device_id = args.compute_device_id + args.device = args.sim_device_type if args.use_gpu_pipeline else 'cpu' + + if args.test: + args.play = args.test + args.train = False + elif args.play: + args.train = False + else: + args.train = True + + return args diff --git a/phc/utils/data_tree.py b/phc/utils/data_tree.py new file mode 100644 index 0000000..7bd5dc7 --- /dev/null +++ b/phc/utils/data_tree.py @@ -0,0 +1,198 @@ +import numpy as np +import json +import copy +import os +from collections import OrderedDict + +class data_tree(object): + def __init__(self, name): + self._name = name + self._children, self._children_names, self._picked, self._depleted = \ + [], [], [], [] + self._data, self._length = [], [] + self._total_length, self._num_leaf, self._is_leaf = 0, 0, 0 + self._assigned_prob = 0.0 + self._node_weights = [] + + def add_node(self, node_weight, dict_hierachy, mocap_data): + # data_hierachy -> 'behavior' 'direction' 'type' 'style' + # behavior, direction, mocap_type, style = mocap_data[2:] + self._num_leaf += 1 + + if len(dict_hierachy) == 0: + # leaf node + self._data.append(mocap_data[0]) + self._length.append(mocap_data[1]) + self._node_weights.append(node_weight) + self._picked.append(0) + self._depleted.append(0) + self._is_leaf = 1 + else: + children_name = dict_hierachy[0].replace('\n', '') + if children_name not in self._children_names: + self._children_names.append(children_name) + self._children.append(data_tree(children_name)) + self._picked.append(0) + self._depleted.append(0) + + # add the data + index = self._children_names.index(children_name) + self._children[index].add_node(node_weight, dict_hierachy[1:], mocap_data) + + def summarize_length(self): + if self._is_leaf: + self._total_length = np.sum(self._length) + else: + self._total_length = 0 + for i_child in self._children: + self._total_length += i_child.summarize_length() + + return self._total_length + + def to_dict(self, verbose=False): + if self._is_leaf: + self._data_dict = copy.deepcopy(self._data) + else: + self._data_dict = OrderedDict() + for i_child in self._children: + self._data_dict[i_child.name] = i_child.to_dict(verbose) + + if verbose: + if self._is_leaf: + verbose_data_dict = [] + for ii, i_key in enumerate(self._data_dict): + new_key = i_key + ' (picked {} / {})'.format( + str(self._picked[ii]), self._length[ii] + ) + verbose_data_dict.append(new_key) + else: + verbose_data_dict = OrderedDict() + for ii, i_key in enumerate(self._data_dict): + new_key = i_key + ' (picked {} / {})'.format( + str(self._picked[ii]), self._children[ii].total_length + ) + verbose_data_dict[new_key] = self._data_dict[i_key] + + self._data_dict = verbose_data_dict + + return self._data_dict + + @property + def name(self): + return self._name + + @property + def picked(self): + return self._picked + + @property + def total_length(self): + return self._total_length + + def water_floating_algorithm(self): + # find the sub class with the minimum picked + assert not np.all(self._depleted) + for ii in np.where(np.array(self._children_names) == 'mix')[0]: + self._depleted[ii] = np.inf + chosen_child = np.argmin(np.array(self._picked) + + np.array(self._depleted)) + if self._is_leaf: + self._picked[chosen_child] = self._length[chosen_child] + self._depleted[chosen_child] = np.inf + chosen_data = self._data[chosen_child] + data_info = {'name': [self._name], + 'length': self._length[chosen_child], + 'all_depleted': np.all(self._depleted)} + else: + chosen_data, data_info = \ + self._children[chosen_child].water_floating_algorithm() + self._picked[chosen_child] += data_info['length'] + data_info['name'].insert(0, self._name) + if data_info['all_depleted']: + self._depleted[chosen_child] = np.inf + data_info['all_depleted'] = np.all(self._depleted) + + return chosen_data, data_info + + def assign_probability(self, total_prob): + # find the sub class with the minimum picked + leaves, probs = [], [] + weights = [] + if self._is_leaf: + self._assigned_prob = total_prob + leaves.extend(self._data) + per_traj_prob = total_prob / float(len(self._data)) + probs.extend([per_traj_prob] * len(self._data)) + weights.extend(self._node_weights) + else: + per_child_prob = total_prob / float(len(self._children)) + for i_child in self._children: + i_leave, i_prob, i_weights = i_child.assign_probability(per_child_prob) + leaves.extend(i_leave) + probs.extend(i_prob) + weights.extend(i_weights) + + return leaves, probs, weights + + +def parse_dataset(env, args): + """ @brief: get the training set and test set + """ + TRAIN_PERCENTAGE = args.parse_dataset_train + info, motion = env.motion_info, env.motion + lengths = env.get_all_motion_length() + train_size = np.sum(motion.get_all_motion_length()) * TRAIN_PERCENTAGE + + data_structure = data_tree('root') + shuffle_id = list(range(len(info['mocap_data_list']))) + np.random.shuffle(shuffle_id) + info['mocap_data_list'] = [info['mocap_data_list'][ii] for ii in shuffle_id] + for mocap_data, length in zip(info['mocap_data_list'], lengths[shuffle_id]): + node_data = [mocap_data[0]] + [length] + data_structure.add_node(mocap_data[2:], node_data) + + raw_data_dict = data_structure.to_dict() + print(json.dumps(raw_data_dict, indent=4)) + + total_length = 0 + chosen_data = [] + while True: + i_data, i_info = data_structure.water_floating_algorithm() + print('Current length:', total_length, i_data, i_info) + total_length += i_info['length'] + chosen_data.append(i_data) + + if total_length > train_size: + break + data_structure.summarize_length() + data_dict = data_structure.to_dict(verbose=True) + print(json.dumps(data_dict, indent=4)) + + # save the training and test sets + train_data, test_data = [], [] + for i_data in info['mocap_data_list']: + if i_data[0] in chosen_data: + train_data.append(i_data[1:]) + else: + test_data.append(i_data[1:]) + + train_tsv_name = args.mocap_list_file.split('.')[0] + '_' + \ + str(int(args.parse_dataset_train * 100)) + '_train' + '.tsv' + test_tsv_name = train_tsv_name.replace('train', 'test') + info_name = test_tsv_name.replace('test', 'info').replace('.tsv', '.json') + + save_tsv_files(env._base_dir, train_tsv_name, train_data) + save_tsv_files(env._base_dir, test_tsv_name, test_data) + + info_file = open(os.path.join(env._base_dir, 'experiments', 'mocap_files', + info_name), 'w') + json.dump(data_dict, info_file, indent=4) + + +def save_tsv_files(base_dir, name, data_dict): + file_name = os.path.join(base_dir, 'experiments', 'mocap_files', name) + recorder = open(file_name, "w") + for i_data in data_dict: + line = '{}\t{}\t{}\t{}\t{}\n'.format(*i_data) + recorder.write(line) + recorder.close() \ No newline at end of file diff --git a/phc/utils/draw_utils.py b/phc/utils/draw_utils.py new file mode 100644 index 0000000..f135890 --- /dev/null +++ b/phc/utils/draw_utils.py @@ -0,0 +1,77 @@ +import numpy as np +import skimage +from skimage.draw import polygon +from skimage.draw import bezier_curve +from skimage.draw import circle_perimeter +from skimage.draw import disk +from scipy import ndimage +import matplotlib +import matplotlib.pyplot as plt +import matplotlib as mpl + + +def get_color_gradient(percent, color='Blues'): + return mpl.colormaps[color](percent)[:3] + + +def agt_color(aidx): + return matplotlib.colors.to_rgb(plt.rcParams['axes.prop_cycle'].by_key()['color'][aidx % 10]) + + +def draw_disk(img_size=80, max_r=10, iterations=3): + shape = (img_size, img_size) + img = np.zeros(shape, dtype=np.uint8) + x, y = np.random.uniform(max_r, img_size - max_r, size=(2)) + radius = int(np.random.uniform(max_r)) + rr, cc = disk((x, y), radius, shape=shape) + np.clip(rr, 0, img_size - 1, out=rr) + np.clip(cc, 0, img_size - 1, out=cc) + img[rr, cc] = 1 + return img + + +def draw_circle(img_size=80, max_r=10, iterations=3): + img = np.zeros((img_size, img_size), dtype=np.uint8) + r, c = np.random.uniform(max_r, img_size - max_r, size=(2,)).astype(int) + radius = int(np.random.uniform(max_r)) + rr, cc = circle_perimeter(r, c, radius) + np.clip(rr, 0, img_size - 1, out=rr) + np.clip(cc, 0, img_size - 1, out=cc) + img[rr, cc] = 1 + img = ndimage.binary_dilation(img, iterations=1).astype(int) + return img + + +def draw_curve(img_size=80, max_sides=10, iterations=3): + img = np.zeros((img_size, img_size), dtype=np.uint8) + r0, c0, r1, c1, r2, c2 = np.random.uniform(0, img_size, size=(6,)).astype(int) + w = np.random.random() + rr, cc = bezier_curve(r0, c0, r1, c1, r2, c2, w) + np.clip(rr, 0, img_size - 1, out=rr) + np.clip(cc, 0, img_size - 1, out=cc) + img[rr, cc] = 1 + img = ndimage.binary_dilation(img, iterations=iterations).astype(int) + return img + + +def draw_polygon(img_size=80, max_sides=10): + img = np.zeros((img_size, img_size), dtype=np.uint8) + num_coord = int(np.random.uniform(3, max_sides)) + r = np.random.uniform(0, img_size, size=(num_coord,)).astype(int) + c = np.random.uniform(0, img_size, size=(num_coord,)).astype(int) + rr, cc = polygon(r, c) + np.clip(rr, 0, img_size - 1, out=rr) + np.clip(cc, 0, img_size - 1, out=cc) + img[rr, cc] = 1 + return img + + +def draw_ellipse(img_size=80, max_size=10): + img = np.zeros((img_size, img_size), dtype=np.uint8) + r, c, rradius, cradius = np.random.uniform(max_size, img_size - max_size), np.random.uniform(max_size, img_size - max_size),\ + np.random.uniform(1, max_size), np.random.uniform(1, max_size) + rr, cc = skimage.draw.ellipse(r, c, rradius, cradius) + np.clip(rr, 0, img_size - 1, out=rr) + np.clip(cc, 0, img_size - 1, out=cc) + img[rr, cc] = 1 + return img \ No newline at end of file diff --git a/phc/utils/flags.py b/phc/utils/flags.py new file mode 100644 index 0000000..ab21eeb --- /dev/null +++ b/phc/utils/flags.py @@ -0,0 +1,13 @@ +__all__ = ['flags', 'summation'] + +class Flags(object): + def __init__(self, items): + for key, val in items.items(): + setattr(self,key,val) + +flags = Flags({ + 'test': False, + 'debug': False, + "real_traj": False, + "im_eval": False, + }) diff --git a/phc/utils/logger.py b/phc/utils/logger.py new file mode 100644 index 0000000..b5b1041 --- /dev/null +++ b/phc/utils/logger.py @@ -0,0 +1,116 @@ +# ----------------------------------------------------------------------------- +# @brief: +# The logger here will be called all across the project. It is inspired +# by Yuxin Wu (ppwwyyxx@gmail.com) +# +# @author: +# Tingwu Wang, 2017, Feb, 20th +# ----------------------------------------------------------------------------- + +import logging +import sys +import os +import datetime +from termcolor import colored + +__all__ = ['set_file_handler'] # the actual worker is the '_logger' + + +class _MyFormatter(logging.Formatter): + ''' + @brief: + a class to make sure the format could be used + ''' + + def format(self, record): + date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green') + msg = '%(message)s' + + if record.levelno == logging.WARNING: + fmt = date + ' ' + \ + colored('WRN', 'red', attrs=[]) + ' ' + msg + elif record.levelno == logging.ERROR or \ + record.levelno == logging.CRITICAL: + fmt = date + ' ' + \ + colored('ERR', 'red', attrs=['underline']) + ' ' + msg + else: + fmt = date + ' ' + msg + + if hasattr(self, '_style'): + # Python3 compatibilty + self._style._fmt = fmt + self._fmt = fmt + + return super(self.__class__, self).format(record) + + +_logger = logging.getLogger('joint_embedding') +_logger.propagate = False +_logger.setLevel(logging.INFO) + +# set the console output handler +con_handler = logging.StreamHandler(sys.stdout) +con_handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S')) +_logger.addHandler(con_handler) + + +class GLOBAL_PATH(object): + + def __init__(self, path=None): + if path is None: + path = os.getcwd() + self.path = path + + def _set_path(self, path): + self.path = path + + def _get_path(self): + return self.path + + +PATH = GLOBAL_PATH() + + +def set_file_handler(path=None, prefix='', time_str=''): + # set the file output handler + if time_str == '': + file_name = prefix + \ + datetime.datetime.now().strftime("%A_%d_%B_%Y_%I:%M%p") + '.log' + else: + file_name = prefix + time_str + '.log' + + if path is None: + mod = sys.modules['__main__'] + path = os.path.join(os.path.abspath(mod.__file__), '..', '..', 'log') + else: + path = os.path.join(path, 'log') + path = os.path.abspath(path) + + path = os.path.join(path, file_name) + if not os.path.exists(path): + os.makedirs(path) + + PATH._set_path(path) + path = os.path.join(path, file_name) + from tensorboard_logger import configure + configure(path) + + file_handler = logging.FileHandler( + filename=os.path.join(path, 'logger'), encoding='utf-8', mode='w') + file_handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S')) + _logger.addHandler(file_handler) + + _logger.info('Log file set to {}'.format(path)) + return path + + +def _get_path(): + return PATH._get_path() + + +_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', + 'warn', 'exception', 'debug'] + +# export logger functions +for func in _LOGGING_METHOD: + locals()[func] = getattr(_logger, func) diff --git a/phc/utils/motion_lib_base.py b/phc/utils/motion_lib_base.py new file mode 100644 index 0000000..f4b9101 --- /dev/null +++ b/phc/utils/motion_lib_base.py @@ -0,0 +1,564 @@ + +import glob +import os +import sys +import pdb +import os.path as osp +sys.path.append(os.getcwd()) + +import numpy as np +import os +import yaml +from tqdm import tqdm + +from phc.utils import torch_utils +import joblib +import torch +from poselib.poselib.skeleton.skeleton3d import SkeletonMotion, SkeletonState +import torch.multiprocessing as mp +import gc +from scipy.spatial.transform import Rotation as sRot +import random +from phc.utils.flags import flags +from enum import Enum +USE_CACHE = False +print("MOVING MOTION DATA TO GPU, USING CACHE:", USE_CACHE) + + +class FixHeightMode(Enum): + no_fix = 0 + full_fix = 1 + ankle_fix = 2 + +if not USE_CACHE: + old_numpy = torch.Tensor.numpy + + class Patch: + + def numpy(self): + if self.is_cuda: + return self.to("cpu").numpy() + else: + return old_numpy(self) + + torch.Tensor.numpy = Patch.numpy + + +def local_rotation_to_dof_vel(local_rot0, local_rot1, dt): + # Assume each joint is 3dof + diff_quat_data = torch_utils.quat_mul(torch_utils.quat_conjugate(local_rot0), local_rot1) + diff_angle, diff_axis = torch_utils.quat_to_angle_axis(diff_quat_data) + dof_vel = diff_axis * diff_angle.unsqueeze(-1) / dt + + return dof_vel[1:, :].flatten() + + +def compute_motion_dof_vels(motion): + num_frames = motion.tensor.shape[0] + dt = 1.0 / motion.fps + dof_vels = [] + + for f in range(num_frames - 1): + local_rot0 = motion.local_rotation[f] + local_rot1 = motion.local_rotation[f + 1] + frame_dof_vel = local_rotation_to_dof_vel(local_rot0, local_rot1, dt) + dof_vels.append(frame_dof_vel) + + dof_vels.append(dof_vels[-1]) + dof_vels = torch.stack(dof_vels, dim=0).view(num_frames, -1, 3) + + return dof_vels + + +class DeviceCache: + + def __init__(self, obj, device): + self.obj = obj + self.device = device + + keys = dir(obj) + num_added = 0 + for k in keys: + try: + out = getattr(obj, k) + except: + # print("Error for key=", k) + continue + + if isinstance(out, torch.Tensor): + if out.is_floating_point(): + out = out.to(self.device, dtype=torch.float32) + else: + out.to(self.device) + setattr(self, k, out) + num_added += 1 + elif isinstance(out, np.ndarray): + out = torch.tensor(out) + if out.is_floating_point(): + out = out.to(self.device, dtype=torch.float32) + else: + out.to(self.device) + setattr(self, k, out) + num_added += 1 + + # print("Total added", num_added) + + def __getattr__(self, string): + out = getattr(self.obj, string) + return out + +class MotionlibMode(Enum): + file = 1 + directory = 2 + +class MotionLibBase(): + + def __init__(self, motion_lib_cfg): + self.m_cfg = motion_lib_cfg + self._device = self.m_cfg.device + + self.mesh_parsers = None + + self.load_data(self.m_cfg.motion_file, min_length = self.m_cfg.min_length, im_eval = self.m_cfg.im_eval) + self.setup_constants(fix_height = self.m_cfg.fix_height, multi_thread = self.m_cfg.multi_thread) + + if flags.real_traj: + self.track_idx = self._motion_data_load[next(iter(self._motion_data_load))].get("track_idx", [19, 24, 29]) + return + + def load_data(self, motion_file, min_length=-1, im_eval = False): + if osp.isfile(motion_file): + self.mode = MotionlibMode.file + self._motion_data_load = joblib.load(motion_file) + else: + self.mode = MotionlibMode.directory + self._motion_data_load = glob.glob(osp.join(motion_file, "*.pkl")) + + data_list = self._motion_data_load + + if self.mode == MotionlibMode.file: + if min_length != -1: + data_list = {k: v for k, v in list(self._motion_data_load.items()) if len(v['pose_quat_global']) >= min_length} + elif im_eval: + data_list = {item[0]: item[1] for item in sorted(self._motion_data_load.items(), key=lambda entry: len(entry[1]['pose_quat_global']), reverse=True)} + # data_list = self._motion_data + else: + data_list = self._motion_data_load + + self._motion_data_list = np.array(list(data_list.values())) + self._motion_data_keys = np.array(list(data_list.keys())) + else: + self._motion_data_list = np.array(self._motion_data_load) + self._motion_data_keys = np.array(self._motion_data_load) + + self._num_unique_motions = len(self._motion_data_list) + if self.mode == MotionlibMode.directory: + self._motion_data_load = joblib.load(self._motion_data_load[0]) # set self._motion_data_load to a sample of the data + + def setup_constants(self, fix_height = FixHeightMode.full_fix, multi_thread = True): + self.fix_height = fix_height + self.multi_thread = multi_thread + + #### Termination history + self._curr_motion_ids = None + self._termination_history = torch.zeros(self._num_unique_motions).to(self._device) + self._success_rate = torch.zeros(self._num_unique_motions).to(self._device) + self._sampling_history = torch.zeros(self._num_unique_motions).to(self._device) + self._sampling_prob = torch.ones(self._num_unique_motions).to(self._device) / self._num_unique_motions # For use in sampling batches + self._sampling_batch_prob = None # For use in sampling within batches + + + @staticmethod + def load_motion_with_skeleton(ids, motion_data_list, skeleton_trees, shape_params, mesh_parsers, config, queue, pid): + raise NotImplementedError + + @staticmethod + def fix_trans_height(pose_aa, trans, curr_gender_betas, mesh_parsers, fix_height_mode): + raise NotImplementedError + + def load_motions(self, skeleton_trees, gender_betas, limb_weights, random_sample=True, start_idx=0, max_len=-1): + # load motion load the same number of motions as there are skeletons (humanoids) + if "gts" in self.__dict__: + del self.gts, self.grs, self.lrs, self.grvs, self.gravs, self.gavs, self.gvs, self.dvs, + del self._motion_lengths, self._motion_fps, self._motion_dt, self._motion_num_frames, self._motion_bodies, self._motion_aa + if flags.real_traj: + del self.q_gts, self.q_grs, self.q_gavs, self.q_gvs + + motions = [] + self._motion_lengths = [] + self._motion_fps = [] + self._motion_dt = [] + self._motion_num_frames = [] + self._motion_bodies = [] + self._motion_aa = [] + + if flags.real_traj: + self.q_gts, self.q_grs, self.q_gavs, self.q_gvs = [], [], [], [] + + torch.cuda.empty_cache() + gc.collect() + + total_len = 0.0 + self.num_joints = len(skeleton_trees[0].node_names) + num_motion_to_load = len(skeleton_trees) + + if random_sample: + sample_idxes = torch.multinomial(self._sampling_prob, num_samples=num_motion_to_load, replacement=True).to(self._device) + else: + sample_idxes = torch.remainder(torch.arange(len(skeleton_trees)) + start_idx, self._num_unique_motions ).to(self._device) + + # import ipdb; ipdb.set_trace() + self._curr_motion_ids = sample_idxes + self.one_hot_motions = torch.nn.functional.one_hot(self._curr_motion_ids, num_classes = self._num_unique_motions).to(self._device) # Testing for obs_v5 + self.curr_motion_keys = self._motion_data_keys[sample_idxes] + self._sampling_batch_prob = self._sampling_prob[self._curr_motion_ids] / self._sampling_prob[self._curr_motion_ids].sum() + + print("\n****************************** Current motion keys ******************************") + print("Sampling motion:", sample_idxes[:30]) + if len(self.curr_motion_keys) < 100: + print(self.curr_motion_keys) + else: + print(self.curr_motion_keys[:30], ".....") + print("*********************************************************************************\n") + + + motion_data_list = self._motion_data_list[sample_idxes.cpu().numpy()] + mp.set_sharing_strategy('file_descriptor') + + manager = mp.Manager() + queue = manager.Queue() + num_jobs = min(mp.cpu_count(), 64) + + if num_jobs <= 8 or not self.multi_thread: + num_jobs = 1 + if flags.debug: + num_jobs = 1 + + res_acc = {} # using dictionary ensures order of the results. + jobs = motion_data_list + chunk = np.ceil(len(jobs) / num_jobs).astype(int) + ids = np.arange(len(jobs)) + + jobs = [(ids[i:i + chunk], jobs[i:i + chunk], skeleton_trees[i:i + chunk], gender_betas[i:i + chunk], self.mesh_parsers, self.m_cfg) for i in range(0, len(jobs), chunk)] + job_args = [jobs[i] for i in range(len(jobs))] + for i in range(1, len(jobs)): + worker_args = (*job_args[i], queue, i) + worker = mp.Process(target=self.load_motion_with_skeleton, args=worker_args) + worker.start() + res_acc.update(self.load_motion_with_skeleton(*jobs[0], None, 0)) + + for i in tqdm(range(len(jobs) - 1)): + res = queue.get() + res_acc.update(res) + + for f in tqdm(range(len(res_acc))): + motion_file_data, curr_motion = res_acc[f] + if USE_CACHE: + curr_motion = DeviceCache(curr_motion, self._device) + + motion_fps = curr_motion.fps + curr_dt = 1.0 / motion_fps + + num_frames = curr_motion.tensor.shape[0] + curr_len = 1.0 / motion_fps * (num_frames - 1) + + + if "beta" in motion_file_data: + self._motion_aa.append(motion_file_data['pose_aa'].reshape(-1, self.num_joints * 3)) + self._motion_bodies.append(curr_motion.gender_beta) + else: + self._motion_aa.append(np.zeros((num_frames, self.num_joints * 3))) + self._motion_bodies.append(torch.zeros(17)) + + self._motion_fps.append(motion_fps) + self._motion_dt.append(curr_dt) + self._motion_num_frames.append(num_frames) + motions.append(curr_motion) + self._motion_lengths.append(curr_len) + + if flags.real_traj: + self.q_gts.append(curr_motion.quest_motion['quest_trans']) + self.q_grs.append(curr_motion.quest_motion['quest_rot']) + self.q_gavs.append(curr_motion.quest_motion['global_angular_vel']) + self.q_gvs.append(curr_motion.quest_motion['linear_vel']) + + del curr_motion + + self._motion_lengths = torch.tensor(self._motion_lengths, device=self._device, dtype=torch.float32) + self._motion_fps = torch.tensor(self._motion_fps, device=self._device, dtype=torch.float32) + self._motion_bodies = torch.stack(self._motion_bodies).to(self._device).type(torch.float32) + self._motion_aa = torch.tensor(np.concatenate(self._motion_aa), device=self._device, dtype=torch.float32) + + self._motion_dt = torch.tensor(self._motion_dt, device=self._device, dtype=torch.float32) + self._motion_num_frames = torch.tensor(self._motion_num_frames, device=self._device) + self._motion_limb_weights = torch.tensor(np.array(limb_weights), device=self._device, dtype=torch.float32) + self._num_motions = len(motions) + + self.gts = torch.cat([m.global_translation for m in motions], dim=0).float().to(self._device) + self.grs = torch.cat([m.global_rotation for m in motions], dim=0).float().to(self._device) + self.lrs = torch.cat([m.local_rotation for m in motions], dim=0).float().to(self._device) + self.grvs = torch.cat([m.global_root_velocity for m in motions], dim=0).float().to(self._device) + self.gravs = torch.cat([m.global_root_angular_velocity for m in motions], dim=0).float().to(self._device) + self.gavs = torch.cat([m.global_angular_velocity for m in motions], dim=0).float().to(self._device) + self.gvs = torch.cat([m.global_velocity for m in motions], dim=0).float().to(self._device) + self.dvs = torch.cat([m.dof_vels for m in motions], dim=0).float().to(self._device) + + if flags.real_traj: + self.q_gts = torch.cat(self.q_gts, dim=0).float().to(self._device) + self.q_grs = torch.cat(self.q_grs, dim=0).float().to(self._device) + self.q_gavs = torch.cat(self.q_gavs, dim=0).float().to(self._device) + self.q_gvs = torch.cat(self.q_gvs, dim=0).float().to(self._device) + + lengths = self._motion_num_frames + lengths_shifted = lengths.roll(1) + lengths_shifted[0] = 0 + self.length_starts = lengths_shifted.cumsum(0) + self.motion_ids = torch.arange(len(motions), dtype=torch.long, device=self._device) + motion = motions[0] + self.num_bodies = motion.num_joints + + num_motions = self.num_motions() + total_len = self.get_total_length() + print(f"Loaded {num_motions:d} motions with a total length of {total_len:.3f}s and {self.gts.shape[0]} frames.") + return motions + + def num_motions(self): + return self._num_motions + + def get_total_length(self): + return sum(self._motion_lengths) + + # def update_sampling_weight(self): + # ## sampling weight based on success rate. + # # sampling_temp = 0.2 + # sampling_temp = 0.1 + # curr_termination_prob = 0.5 + + # curr_succ_rate = 1 - self._termination_history[self._curr_motion_ids] / self._sampling_history[self._curr_motion_ids] + # self._success_rate[self._curr_motion_ids] = curr_succ_rate + # sample_prob = torch.exp(-self._success_rate / sampling_temp) + + # self._sampling_prob = sample_prob / sample_prob.sum() + # self._termination_history[self._curr_motion_ids] = 0 + # self._sampling_history[self._curr_motion_ids] = 0 + + # topk_sampled = self._sampling_prob.topk(50) + # print("Current most sampled", self._motion_data_keys[topk_sampled.indices.cpu().numpy()]) + + def update_hard_sampling_weight(self, failed_keys): + # sampling weight based on evaluation, only trained on "failed" sequences. Auto PMCP. + if len(failed_keys) > 0: + all_keys = self._motion_data_keys.tolist() + indexes = [all_keys.index(k) for k in failed_keys] + self._sampling_prob[:] = 0 + self._sampling_prob[indexes] = 1/len(indexes) + print("############################################################ Auto PMCP ############################################################") + print(f"Training on only {len(failed_keys)} seqs") + print(failed_keys) + else: + all_keys = self._motion_data_keys.tolist() + self._sampling_prob = torch.ones(self._num_unique_motions).to(self._device) / self._num_unique_motions # For use in sampling batches + + def update_soft_sampling_weight(self, failed_keys): + # sampling weight based on evaluation, only "mostly" trained on "failed" sequences. Auto PMCP. + if len(failed_keys) > 0: + all_keys = self._motion_data_keys.tolist() + indexes = [all_keys.index(k) for k in failed_keys] + self._termination_history[indexes] += 1 + self.update_sampling_prob(self._termination_history) + + print("############################################################ Auto PMCP ############################################################") + print(f"Training mostly on {len(self._sampling_prob.nonzero())} seqs ") + print(self._motion_data_keys[self._sampling_prob.nonzero()].flatten()) + print(f"###############################################################################################################################") + else: + all_keys = self._motion_data_keys.tolist() + self._sampling_prob = torch.ones(self._num_unique_motions).to(self._device) / self._num_unique_motions # For use in sampling batches + + def update_sampling_prob(self, termination_history): + if len(termination_history) == len(self._termination_history) and termination_history.sum() > 0: + self._sampling_prob[:] = termination_history/termination_history.sum() + self._termination_history = termination_history + return True + else: + return False + + + # def update_sampling_history(self, env_ids): + # self._sampling_history[self._curr_motion_ids[env_ids]] += 1 + # # print("sampling history: ", self._sampling_history[self._curr_motion_ids]) + + # def update_termination_history(self, termination): + # self._termination_history[self._curr_motion_ids] += termination + # # print("termination history: ", self._termination_history[self._curr_motion_ids]) + + def sample_motions(self, n): + motion_ids = torch.multinomial(self._sampling_batch_prob, num_samples=n, replacement=True).to(self._device) + + return motion_ids + + def sample_time(self, motion_ids, truncate_time=None): + n = len(motion_ids) + phase = torch.rand(motion_ids.shape, device=self._device) + motion_len = self._motion_lengths[motion_ids] + if (truncate_time is not None): + assert (truncate_time >= 0.0) + motion_len -= truncate_time + + motion_time = phase * motion_len + return motion_time.to(self._device) + + def sample_time_interval(self, motion_ids, truncate_time=None): + phase = torch.rand(motion_ids.shape, device=self._device) + motion_len = self._motion_lengths[motion_ids] + if (truncate_time is not None): + assert (truncate_time >= 0.0) + motion_len -= truncate_time + curr_fps = 1 / 30 + motion_time = ((phase * motion_len) / curr_fps).long() * curr_fps + + return motion_time + + def get_motion_length(self, motion_ids=None): + if motion_ids is None: + return self._motion_lengths + else: + return self._motion_lengths[motion_ids] + + def get_motion_num_steps(self, motion_ids=None): + if motion_ids is None: + return (self._motion_num_frames * 30 / self._motion_fps).int() + else: + return (self._motion_num_frames[motion_ids] * 30 / self._motion_fps).int() + + def get_motion_state(self, motion_ids, motion_times, offset=None): + n = len(motion_ids) + num_bodies = self._get_num_bodies() + + motion_len = self._motion_lengths[motion_ids] + num_frames = self._motion_num_frames[motion_ids] + dt = self._motion_dt[motion_ids] + + frame_idx0, frame_idx1, blend = self._calc_frame_blend(motion_times, motion_len, num_frames, dt) + # print("non_interval", frame_idx0, frame_idx1) + f0l = frame_idx0 + self.length_starts[motion_ids] + f1l = frame_idx1 + self.length_starts[motion_ids] + + local_rot0 = self.lrs[f0l] + local_rot1 = self.lrs[f1l] + + body_vel0 = self.gvs[f0l] + body_vel1 = self.gvs[f1l] + + body_ang_vel0 = self.gavs[f0l] + body_ang_vel1 = self.gavs[f1l] + + rg_pos0 = self.gts[f0l, :] + rg_pos1 = self.gts[f1l, :] + + dof_vel0 = self.dvs[f0l] + dof_vel1 = self.dvs[f1l] + + vals = [local_rot0, local_rot1, body_vel0, body_vel1, body_ang_vel0, body_ang_vel1, rg_pos0, rg_pos1, dof_vel0, dof_vel1] + for v in vals: + assert v.dtype != torch.float64 + + blend = blend.unsqueeze(-1) + + blend_exp = blend.unsqueeze(-1) + + if offset is None: + rg_pos = (1.0 - blend_exp) * rg_pos0 + blend_exp * rg_pos1 # ZL: apply offset + else: + rg_pos = (1.0 - blend_exp) * rg_pos0 + blend_exp * rg_pos1 + offset[..., None, :] # ZL: apply offset + + body_vel = (1.0 - blend_exp) * body_vel0 + blend_exp * body_vel1 + body_ang_vel = (1.0 - blend_exp) * body_ang_vel0 + blend_exp * body_ang_vel1 + dof_vel = (1.0 - blend_exp) * dof_vel0 + blend_exp * dof_vel1 + + + local_rot = torch_utils.slerp(local_rot0, local_rot1, torch.unsqueeze(blend, axis=-1)) + dof_pos = self._local_rotation_to_dof_smpl(local_rot) + + rb_rot0 = self.grs[f0l] + rb_rot1 = self.grs[f1l] + rb_rot = torch_utils.slerp(rb_rot0, rb_rot1, blend_exp) + + if flags.real_traj: + q_body_ang_vel0, q_body_ang_vel1 = self.q_gavs[f0l], self.q_gavs[f1l] + q_rb_rot0, q_rb_rot1 = self.q_grs[f0l], self.q_grs[f1l] + q_rg_pos0, q_rg_pos1 = self.q_gts[f0l, :], self.q_gts[f1l, :] + q_body_vel0, q_body_vel1 = self.q_gvs[f0l], self.q_gvs[f1l] + + q_ang_vel = (1.0 - blend_exp) * q_body_ang_vel0 + blend_exp * q_body_ang_vel1 + q_rb_rot = torch_utils.slerp(q_rb_rot0, q_rb_rot1, blend_exp) + q_rg_pos = (1.0 - blend_exp) * q_rg_pos0 + blend_exp * q_rg_pos1 + q_body_vel = (1.0 - blend_exp) * q_body_vel0 + blend_exp * q_body_vel1 + + rg_pos[:, self.track_idx] = q_rg_pos + rb_rot[:, self.track_idx] = q_rb_rot + body_vel[:, self.track_idx] = q_body_vel + body_ang_vel[:, self.track_idx] = q_ang_vel + + return { + "root_pos": rg_pos[..., 0, :].clone(), + "root_rot": rb_rot[..., 0, :].clone(), + "dof_pos": dof_pos.clone(), + "root_vel": body_vel[..., 0, :].clone(), + "root_ang_vel": body_ang_vel[..., 0, :].clone(), + "dof_vel": dof_vel.view(dof_vel.shape[0], -1), + "motion_aa": self._motion_aa[f0l], + "rg_pos": rg_pos, + "rb_rot": rb_rot, + "body_vel": body_vel, + "body_ang_vel": body_ang_vel, + "motion_bodies": self._motion_bodies[motion_ids], + "motion_limb_weights": self._motion_limb_weights[motion_ids], + } + + def get_root_pos_smpl(self, motion_ids, motion_times): + n = len(motion_ids) + num_bodies = self._get_num_bodies() + + motion_len = self._motion_lengths[motion_ids] + num_frames = self._motion_num_frames[motion_ids] + dt = self._motion_dt[motion_ids] + + frame_idx0, frame_idx1, blend = self._calc_frame_blend(motion_times, motion_len, num_frames, dt) + # print("non_interval", frame_idx0, frame_idx1) + f0l = frame_idx0 + self.length_starts[motion_ids] + f1l = frame_idx1 + self.length_starts[motion_ids] + + rg_pos0 = self.gts[f0l, :] + rg_pos1 = self.gts[f1l, :] + + vals = [rg_pos0, rg_pos1] + for v in vals: + assert v.dtype != torch.float64 + + blend = blend.unsqueeze(-1) + + blend_exp = blend.unsqueeze(-1) + + rg_pos = (1.0 - blend_exp) * rg_pos0 + blend_exp * rg_pos1 # ZL: apply offset + return {"root_pos": rg_pos[..., 0, :].clone()} + + def _calc_frame_blend(self, time, len, num_frames, dt): + time = time.clone() + phase = time / len + phase = torch.clip(phase, 0.0, 1.0) # clip time to be within motion length. + time[time < 0] = 0 + + frame_idx0 = (phase * (num_frames - 1)).long() + frame_idx1 = torch.min(frame_idx0 + 1, num_frames - 1) + blend = torch.clip((time - frame_idx0 * dt) / dt, 0.0, 1.0) # clip blend to be within 0 and 1 + + return frame_idx0, frame_idx1, blend + + def _get_num_bodies(self): + return self.num_bodies + + def _local_rotation_to_dof_smpl(self, local_rot): + B, J, _ = local_rot.shape + dof_pos = torch_utils.quat_to_exp_map(local_rot[:, 1:]) + return dof_pos.reshape(B, -1) \ No newline at end of file diff --git a/phc/utils/motion_lib_smpl.py b/phc/utils/motion_lib_smpl.py new file mode 100644 index 0000000..a2e0eb6 --- /dev/null +++ b/phc/utils/motion_lib_smpl.py @@ -0,0 +1,170 @@ + + +import numpy as np +import os +import yaml +from tqdm import tqdm +import os.path as osp + +from phc.utils import torch_utils +import joblib +import torch +from poselib.poselib.skeleton.skeleton3d import SkeletonMotion, SkeletonState +import torch.multiprocessing as mp +import copy +import gc +from smpl_sim.smpllib.smpl_parser import ( + SMPL_Parser, + SMPLH_Parser, + SMPLX_Parser, +) +from scipy.spatial.transform import Rotation as sRot +import random +from phc.utils.flags import flags +from phc.utils.motion_lib_base import MotionLibBase, DeviceCache, compute_motion_dof_vels, FixHeightMode +from smpl_sim.utils.torch_ext import to_torch + +USE_CACHE = False +print("MOVING MOTION DATA TO GPU, USING CACHE:", USE_CACHE) + +if not USE_CACHE: + old_numpy = torch.Tensor.numpy + + class Patch: + + def numpy(self): + if self.is_cuda: + return self.to("cpu").numpy() + else: + return old_numpy(self) + + torch.Tensor.numpy = Patch.numpy + + + + +class MotionLibSMPL(MotionLibBase): + + def __init__(self, motion_lib_cfg): + super().__init__(motion_lib_cfg = motion_lib_cfg) + + data_dir = "data/smpl" + + if osp.exists(data_dir): + if motion_lib_cfg.smpl_type == "smpl": + smpl_parser_n = SMPL_Parser(model_path=data_dir, gender="neutral") + smpl_parser_m = SMPL_Parser(model_path=data_dir, gender="male") + smpl_parser_f = SMPL_Parser(model_path=data_dir, gender="female") + self.mesh_parsers = {0: smpl_parser_n, 1: smpl_parser_m, 2: smpl_parser_f} + else: + self.mesh_parsers = None + + return + + @staticmethod + def fix_trans_height(pose_aa, trans, curr_gender_betas, mesh_parsers, fix_height_mode): + if fix_height_mode == FixHeightMode.no_fix: + return trans, 0 + + with torch.no_grad(): + frame_check = 30 + gender = curr_gender_betas[0] + betas = curr_gender_betas[1:] + mesh_parser = mesh_parsers[gender.item()] + height_tolorance = 0.0 + vertices_curr, joints_curr = mesh_parser.get_joints_verts(pose_aa[:frame_check], betas[None,], trans[:frame_check]) + + offset = joints_curr[:, 0] - trans[:frame_check] # account for SMPL root offset. since the root trans we pass in has been processed, we have to "add it back". + + if fix_height_mode == FixHeightMode.ankle_fix: + assignment_indexes = mesh_parser.lbs_weights.argmax(axis=1) + pick = (((assignment_indexes != mesh_parser.joint_names.index("L_Toe")).int() + (assignment_indexes != mesh_parser.joint_names.index("R_Toe")).int() + + (assignment_indexes != mesh_parser.joint_names.index("R_Hand")).int() + + (assignment_indexes != mesh_parser.joint_names.index("L_Hand")).int()) == 4).nonzero().squeeze() + diff_fix = ((vertices_curr[:, pick] - offset[:, None])[:frame_check, ..., -1].min(dim=-1).values - height_tolorance).min() # Only acount the first 30 frames, which usually is a calibration phase. + elif fix_height_mode == FixHeightMode.full_fix: + + diff_fix = ((vertices_curr - offset[:, None])[:frame_check, ..., -1].min(dim=-1).values - height_tolorance).min() # Only acount the first 30 frames, which usually is a calibration phase. + + + + trans[..., -1] -= diff_fix + return trans, diff_fix + + @staticmethod + def load_motion_with_skeleton(ids, motion_data_list, skeleton_trees, shape_params, mesh_parsers, config, queue, pid): + # ZL: loading motion with the specified skeleton. Perfoming forward kinematics to get the joint positions + max_len = config.max_length + fix_height = config.fix_height + np.random.seed(np.random.randint(5000)* pid) + res = {} + assert (len(ids) == len(motion_data_list)) + for f in range(len(motion_data_list)): + curr_id = ids[f] # id for this datasample + curr_file = motion_data_list[f] + if not isinstance(curr_file, dict) and osp.isfile(curr_file): + key = motion_data_list[f].split("/")[-1].split(".")[0] + curr_file = joblib.load(curr_file)[key] + curr_gender_beta = shape_params[f] + + seq_len = curr_file['root_trans_offset'].shape[0] + if max_len == -1 or seq_len < max_len: + start, end = 0, seq_len + else: + start = random.randint(0, seq_len - max_len) + end = start + max_len + + trans = curr_file['root_trans_offset'].clone()[start:end] + pose_aa = to_torch(curr_file['pose_aa'][start:end]) + pose_quat_global = curr_file['pose_quat_global'][start:end] + + + B, J, N = pose_quat_global.shape + + ##### ZL: randomize the heading ###### + if (not flags.im_eval) and (not flags.test): + # if True: + random_rot = np.zeros(3) + random_rot[2] = np.pi * (2 * np.random.random() - 1.0) + random_heading_rot = sRot.from_euler("xyz", random_rot) + pose_aa[:, :3] = torch.tensor((random_heading_rot * sRot.from_rotvec(pose_aa[:, :3])).as_rotvec()) + pose_quat_global = (random_heading_rot * sRot.from_quat(pose_quat_global.reshape(-1, 4))).as_quat().reshape(B, J, N) + trans = torch.matmul(trans, torch.from_numpy(random_heading_rot.as_matrix().T)) + ##### ZL: randomize the heading ###### + + if not mesh_parsers is None: + trans, trans_fix = MotionLibSMPL.fix_trans_height(pose_aa, trans, curr_gender_beta, mesh_parsers, fix_height_mode = fix_height) + else: + trans_fix = 0 + + pose_quat_global = to_torch(pose_quat_global) + sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_trees[f], pose_quat_global, trans, is_local=False) + + curr_motion = SkeletonMotion.from_skeleton_state(sk_state, curr_file.get("fps", 30)) + curr_dof_vels = compute_motion_dof_vels(curr_motion) + + if flags.real_traj: + quest_sensor_data = to_torch(curr_file['quest_sensor_data']) + quest_trans = quest_sensor_data[..., :3] + quest_rot = quest_sensor_data[..., 3:] + + quest_trans[..., -1] -= trans_fix # Fix trans + + global_angular_vel = SkeletonMotion._compute_angular_velocity(quest_rot, time_delta=1 / curr_file['fps']) + linear_vel = SkeletonMotion._compute_velocity(quest_trans, time_delta=1 / curr_file['fps']) + quest_motion = {"global_angular_vel": global_angular_vel, "linear_vel": linear_vel, "quest_trans": quest_trans, "quest_rot": quest_rot} + curr_motion.quest_motion = quest_motion + + curr_motion.dof_vels = curr_dof_vels + curr_motion.gender_beta = curr_gender_beta + res[curr_id] = (curr_file, curr_motion) + + + + if not queue is None: + queue.put(res) + else: + return res + + + + \ No newline at end of file diff --git a/phc/utils/o3d_utils.py b/phc/utils/o3d_utils.py new file mode 100644 index 0000000..9e66ad0 --- /dev/null +++ b/phc/utils/o3d_utils.py @@ -0,0 +1,40 @@ +from datetime import datetime +import imageio + +def pause_func(action): + global paused + paused = not paused + print(f"Paused: {paused}") + return True + + +def reset_func(action): + global reset + reset = not reset + print(f"Reset: {reset}") + return True + + +def record_func(action): + global recording, writer + if not recording: + fps = 30 + curr_date_time = datetime.now().strftime('%Y-%m-%d-%H:%M:%S') + curr_video_file_name = f"output/renderings/o3d/{curr_date_time}-test.mp4" + writer = imageio.get_writer(curr_video_file_name, fps=fps, macro_block_size=None) + elif not writer is None: + writer.close() + writer = None + + recording = not recording + + print(f"Recording: {recording}") + return True + + +def zoom_func(action): + global control, curr_zoom + curr_zoom = curr_zoom * 0.9 + control.set_zoom(curr_zoom) + print(f"Reset: {reset}") + return True diff --git a/phc/utils/parse_task.py b/phc/utils/parse_task.py new file mode 100644 index 0000000..9fad008 --- /dev/null +++ b/phc/utils/parse_task.py @@ -0,0 +1,65 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from phc.env.tasks.humanoid import Humanoid +from phc.env.tasks.humanoid_amp import HumanoidAMP +from phc.env.tasks.humanoid_amp_getup import HumanoidAMPGetup +from phc.env.tasks.humanoid_im import HumanoidIm +from phc.env.tasks.humanoid_im_getup import HumanoidImGetup +from phc.env.tasks.humanoid_im_mcp import HumanoidImMCP +from phc.env.tasks.humanoid_im_mcp_getup import HumanoidImMCPGetup +from phc.env.tasks.vec_task_wrappers import VecTaskPythonWrapper +from phc.env.tasks.humanoid_im_demo import HumanoidImDemo +from phc.env.tasks.humanoid_im_mcp_demo import HumanoidImMCPDemo +from phc.env.tasks.humanoid_im_distill import HumanoidImDistill +from phc.env.tasks.humanoid_im_distill_getup import HumanoidImDistillGetup + +from isaacgym import rlgpu + +import json +import numpy as np + + +def warn_task_name(): + raise Exception("Unrecognized task!\nTask should be one of: [BallBalance, Cartpole, CartpoleYUp, Ant, Humanoid, Anymal, FrankaCabinet, Quadcopter, ShadowHand, ShadowHandLSTM, ShadowHandFFOpenAI, ShadowHandFFOpenAITest, ShadowHandOpenAI, ShadowHandOpenAITest, Ingenuity]") + + +def parse_task(args, cfg, cfg_train, sim_params): + + # create native task and pass custom config + device_id = args.device_id + rl_device = args.rl_device + + cfg["seed"] = cfg_train['params'].get("seed", -1) + cfg_task = cfg["env"] + cfg_task["seed"] = cfg["seed"] + + task = eval(args.task)(cfg=cfg, sim_params=sim_params, physics_engine=args.physics_engine, device_type=args.device, device_id=device_id, headless=args.headless) + env = VecTaskPythonWrapper(task, rl_device, cfg_train['params'].get("clip_observations", np.inf)) + + return task, env diff --git a/phc/utils/plot_script.py b/phc/utils/plot_script.py new file mode 100644 index 0000000..4af4391 --- /dev/null +++ b/phc/utils/plot_script.py @@ -0,0 +1,130 @@ +import math +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +from matplotlib.animation import FuncAnimation, FFMpegFileWriter +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +import mpl_toolkits.mplot3d.axes3d as p3 +# import cv2 +from textwrap import wrap + + +def list_cut_average(ll, intervals): + if intervals == 1: + return ll + + bins = math.ceil(len(ll) * 1.0 / intervals) + ll_new = [] + for i in range(bins): + l_low = intervals * i + l_high = l_low + intervals + l_high = l_high if l_high < len(ll) else len(ll) + ll_new.append(np.mean(ll[l_low:l_high])) + return ll_new + + +def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(3, 3), fps=120, radius=3, + vis_mode='default', gt_joints=None): + matplotlib.use('Agg') + + title = '\n'.join(wrap(title, 20)) + + def init(): + ax.set_xlim3d([-radius / 2, radius / 2]) + ax.set_ylim3d([0, radius]) + ax.set_zlim3d([-radius / 3., radius * 2 / 3.]) + # print(title) + fig.suptitle(title, fontsize=10) + ax.grid(b=False) + + def plot_xzPlane(minx, maxx, miny, minz, maxz): + ## Plot a plane XZ + verts = [ + [minx, miny, minz], + [minx, miny, maxz], + [maxx, miny, maxz], + [maxx, miny, minz] + ] + xz_plane = Poly3DCollection([verts]) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + + # return ax + + # (seq_len, joints_num, 3) + data = joints.copy().reshape(len(joints), -1, 3) + if not gt_joints is None: + data_gt = gt_joints.copy().reshape(len(gt_joints), -1, 3) + + fig = plt.figure(figsize=figsize) + plt.tight_layout() + ax = p3.Axes3D(fig) + init() + MINS = data.min(axis=0).min(axis=0) + MAXS = data.max(axis=0).max(axis=0) + colors_blue = ["#4D84AA", "#5B9965", "#61CEB9", "#34C1E2", "#80B79A"] # GT color + colors_orange = ["#DD5A37", "#D69E00", "#B75A39", "#FF6D00", "#DDB50E"] # Generation color + colors = colors_orange + if vis_mode == 'upper_body': # lower body taken fixed to input motion + colors[0] = colors_blue[0] + colors[1] = colors_blue[1] + elif vis_mode == 'gt': + colors = colors_blue + + frame_number = data.shape[0] + # print(dataset.shape) + + height_offset = MINS[1] + + data[:, :, 1] -= height_offset + trajec = data[:, 0, [0, 2]] + data[..., 0] -= data[:, 0:1, 0] + data[..., 2] -= data[:, 0:1, 2] + + if not gt_joints is None: + data_gt[:, :, 1] -= height_offset + data_gt[..., 0] -= data_gt[:, 0:1, 0] + data_gt[..., 2] -= data_gt[:, 0:1, 2] + + # print(trajec.shape) + + def update(index): + # print(index) + # ax.lines = [] + # ax.collections = [] + ax.lines.clear() + ax.collections.clear() + + ax.view_init(elev=120, azim=-90) + ax.dist = 5 + # ax = + plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], + MAXS[2] - trajec[index, 1]) + # ax.scatter(dataset[index, :22, 0], dataset[index, :22, 1], dataset[index, :22, 2], color='black', s=3) + + used_colors = colors + for i, (chain, color) in enumerate(zip(kinematic_tree, used_colors)): + linewidth = 2 + ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, color=color) + ax.scatter(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], color=color, s=50) + + if not gt_joints is None: + ax.plot3D(data_gt[index, chain, 0], data_gt[index, chain, 1], data_gt[index, chain, 2], linewidth=linewidth, color=colors_blue[i]) + ax.scatter(data_gt[index, chain, 0], data_gt[index, chain, 1], data_gt[index, chain, 2], color=colors_blue[i], s=50) + + + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + + ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False) + + # writer = FFMpegFileWriter(fps=fps) + ani.save(save_path, fps=fps) + # ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False, init_func=init) + # ani.save(save_path, writer='pillow', fps=1000 / fps) + + plt.close() \ No newline at end of file diff --git a/phc/utils/pytorch3d_transforms.py b/phc/utils/pytorch3d_transforms.py new file mode 100644 index 0000000..7b5b00f --- /dev/null +++ b/phc/utils/pytorch3d_transforms.py @@ -0,0 +1,676 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Union + +import torch +import torch.nn.functional as F + +Device = Union[str, torch.device] +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def quaternion_to_matrix_ijkr(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part(torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + )) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0]**2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1]**2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2]**2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3]**2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + + return quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) # pyre-ignore[16] + + +@torch.jit.script +def matrix_to_quaternion_ijkr(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part(torch.stack( + [ + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + 1.0 + m00 + m11 + m22, + ], + dim=-1, + )) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_ijkr = torch.stack( + [ + torch.stack([q_abs[..., 0]**2, m10 + m01, m02 + m20, m21 - m12], dim=-1), + torch.stack([m10 + m01, q_abs[..., 1]**2, m21 + m12, m02 - m20], dim=-1), + torch.stack([m02 + m20, m12 + m21, q_abs[..., 2]**2, m10 - m01], dim=-1), + torch.stack([m21 - m12, m02 - m20, m10 - m01, q_abs[..., 3]**2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_ijkr / (2.0 * q_abs[..., None].max(flr)) + + return quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) # pyre-ignore[16] + + +@torch.jit.script +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part(torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + )) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0]**2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1]**2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2]**2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3]**2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + + return quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) # pyre-ignore[16] + + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [_axis_angle_rotation(c, e) for c, e in zip(convention, torch.unbind(euler_angles, -1))] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +def _angle_from_tan(axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str) -> int: + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter must be either X, Y or Z.") + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin(matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan(convention[0], convention[1], matrix[..., i2], False, tait_bryan), + central_angle, + _angle_from_tan(convention[2], convention[1], matrix[..., i0, :], True, tait_bryan), + ) + return torch.stack(o, -1) + + +def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None) -> torch.Tensor: + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + if isinstance(device, str): + device = torch.device(device) + o = torch.randn((n, 4), dtype=dtype, device=device) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None) -> torch.Tensor: + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions(n, dtype=dtype, device=device) + return quaternion_to_matrix(quaternions) + + +def random_rotation(dtype: Optional[torch.dtype] = None, device: Optional[Device] = None) -> torch.Tensor: + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device)[0] + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor: + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device) + return quaternion * scaling + + +def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor: + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, {point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = (torch.sin(half_angles[~small_angles]) / angles[~small_angles]) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = (0.5 - (angles[small_angles] * angles[small_angles]) / 48) + quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1) + return quaternions + + +def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = (torch.sin(half_angles[~small_angles]) / angles[~small_angles]) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = (0.5 - (angles[small_angles] * angles[small_angles]) / 48) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) diff --git a/phc/utils/running_mean_std.py b/phc/utils/running_mean_std.py new file mode 100644 index 0000000..16e5cbd --- /dev/null +++ b/phc/utils/running_mean_std.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import numpy as np +''' +updates statistic from a full data +''' + + +class RunningMeanStd(nn.Module): + + def __init__(self, + insize, + epsilon=1e-05, + per_channel=False, + norm_only=False): + super(RunningMeanStd, self).__init__() + print('RunningMeanStd: ', insize) + self.insize = insize + self.mean_size = insize[0] + self.epsilon = epsilon + + self.norm_only = norm_only + self.per_channel = per_channel + if per_channel: + if len(self.insize) == 3: + self.axis = [0, 2, 3] + if len(self.insize) == 2: + self.axis = [0, 2] + if len(self.insize) == 1: + self.axis = [0] + in_size = self.insize[0] + else: + self.axis = [0] + in_size = insize + + self.register_buffer("running_mean", + torch.zeros(in_size, dtype=torch.float64)) + self.register_buffer("running_var", + torch.ones(in_size, dtype=torch.float64)) + self.register_buffer("count", torch.ones((), dtype=torch.float64)) + + self.forzen = False + self.forzen_partial = False + + def freeze(self): + self.forzen = True + + def unfreeze(self): + self.forzen = False + + def freeze_partial(self, diff): + self.forzen_partial = True + self.diff = diff + + + def _update_mean_var_count_from_moments(self, mean, var, count, batch_mean, + batch_var, batch_count): + delta = batch_mean - mean + tot_count = count + batch_count + + new_mean = mean + delta * batch_count / tot_count + m_a = var * count + m_b = batch_var * batch_count + M2 = m_a + m_b + delta**2 * count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + return new_mean, new_var, new_count + + def forward(self, input, unnorm=False): + # change shape + if self.per_channel: + if len(self.insize) == 3: + current_mean = self.running_mean.view( + [1, self.insize[0], 1, 1]).expand_as(input) + current_var = self.running_var.view([1, self.insize[0], 1,1]).expand_as(input) + if len(self.insize) == 2: + current_mean = self.running_mean.view([1, self.insize[0],1]).expand_as(input) + current_var = self.running_var.view([1, self.insize[0],1]).expand_as(input) + if len(self.insize) == 1: + current_mean = self.running_mean.view([1, self.insize[0]]).expand_as(input) + current_var = self.running_var.view([1, self.insize[0]]).expand_as(input) + else: + current_mean = self.running_mean + current_var = self.running_var + # get output + + if unnorm: + y = torch.clamp(input, min=-5.0, max=5.0) + y = torch.sqrt(current_var.float() + + self.epsilon) * y + current_mean.float() + else: + if self.norm_only: + y = input / torch.sqrt(current_var.float() + self.epsilon) + else: + y = (input - current_mean.float()) / torch.sqrt(current_var.float() + self.epsilon) + y = torch.clamp(y, min=-5.0, max=5.0) + + # update After normalization, so that the values used for training and testing are the same. + if self.training and not self.forzen: + mean = input.mean(self.axis) # along channel axis + var = input.var(self.axis) + new_mean, new_var, new_count = self._update_mean_var_count_from_moments(self.running_mean, self.running_var, self.count, mean, var, input.size()[0]) + if self.forzen_partial: + # Only update the last bit (futures) + self.running_mean[-self.diff:], self.running_var[-self.diff:], self.count = new_mean[-self.diff:], new_var[-self.diff:], new_count + else: + self.running_mean, self.running_var, self.count = new_mean, new_var, new_count + + return y + + +class RunningMeanStdObs(nn.Module): + + def __init__(self, + insize, + epsilon=1e-05, + per_channel=False, + norm_only=False): + assert (insize is dict) + super(RunningMeanStdObs, self).__init__() + self.running_mean_std = nn.ModuleDict({ + k: RunningMeanStd(v, epsilon, per_channel, norm_only) + for k, v in insize.items() + }) + + def forward(self, input, unnorm=False): + res = {k: self.running_mean_std(v, unnorm) for k, v in input.items()} + return res \ No newline at end of file diff --git a/phc/utils/torch_utils.py b/phc/utils/torch_utils.py new file mode 100644 index 0000000..3523660 --- /dev/null +++ b/phc/utils/torch_utils.py @@ -0,0 +1,260 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np + +from isaacgym.torch_utils import * +import torch +from torch import nn +import phc.utils.pytorch3d_transforms as ptr +import torch.nn.functional as F + + +def project_to_norm(x, norm=5, z_type = "sphere"): + if z_type == "sphere": + x = x / (torch.norm(x, dim=-1, keepdim=True) / norm + 1e-8) + elif z_type == "uniform": + x = torch.clamp(x, -norm, norm) + return x + +@torch.jit.script +def my_quat_rotate(q, v): + shape = q.shape + q_w = q[:, -1] + q_vec = q[:, :3] + a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) + b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 + c = q_vec * \ + torch.bmm(q_vec.view(shape[0], 1, 3), v.view( + shape[0], 3, 1)).squeeze(-1) * 2.0 + return a + b + c + +@torch.jit.script +def quat_to_angle_axis(q): + # type: (Tensor) -> Tuple[Tensor, Tensor] + # computes axis-angle representation from quaternion q + # q must be normalized + min_theta = 1e-5 + qx, qy, qz, qw = 0, 1, 2, 3 + + sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw]) + angle = 2 * torch.acos(q[..., qw]) + angle = normalize_angle(angle) + sin_theta_expand = sin_theta.unsqueeze(-1) + axis = q[..., qx:qw] / sin_theta_expand + + mask = torch.abs(sin_theta) > min_theta + default_axis = torch.zeros_like(axis) + default_axis[..., -1] = 1 + + angle = torch.where(mask, angle, torch.zeros_like(angle)) + mask_expand = mask.unsqueeze(-1) + axis = torch.where(mask_expand, axis, default_axis) + return angle, axis + + +@torch.jit.script +def angle_axis_to_exp_map(angle, axis): + # type: (Tensor, Tensor) -> Tensor + # compute exponential map from axis-angle + angle_expand = angle.unsqueeze(-1) + exp_map = angle_expand * axis + return exp_map + + +@torch.jit.script +def quat_to_exp_map(q): + # type: (Tensor) -> Tensor + # compute exponential map from quaternion + # q must be normalized + angle, axis = quat_to_angle_axis(q) + exp_map = angle_axis_to_exp_map(angle, axis) + return exp_map + + +@torch.jit.script +def quat_to_tan_norm(q): + # type: (Tensor) -> Tensor + # represents a rotation using the tangent and normal vectors + ref_tan = torch.zeros_like(q[..., 0:3]) + ref_tan[..., 0] = 1 + tan = my_quat_rotate(q, ref_tan) + + ref_norm = torch.zeros_like(q[..., 0:3]) + ref_norm[..., -1] = 1 + norm = my_quat_rotate(q, ref_norm) + + norm_tan = torch.cat([tan, norm], dim=len(tan.shape) - 1) + return norm_tan + + +@torch.jit.script +def tan_norm_to_mat(tan_norm): + B = tan_norm.shape[0] + tan = tan_norm.view(-1, 2, 3)[:, 0] + norm = tan_norm.view(-1, 2, 3)[:, 1] + tan_n = F.normalize(tan, dim=-1) + + norm_n = norm - (tan_n * norm).sum(-1, keepdim=True) * tan_n + norm_n = F.normalize(norm_n, dim=-1) + + cross = torch.cross(norm_n, tan_n) + + rot_mat = torch.stack([tan_n, cross, norm_n], dim=-1).reshape(B, -1, 3, 3) + return rot_mat + + +@torch.jit.script +def tan_norm_to_quat(tan_norm): + B = tan_norm.shape[0] + rot_mat = tan_norm_to_mat(tan_norm) + quat_new = ptr.matrix_to_quaternion_ijkr(rot_mat).view(B, -1, 4) + return quat_new + + +@torch.jit.script +def euler_xyz_to_exp_map(roll, pitch, yaw): + # type: (Tensor, Tensor, Tensor) -> Tensor + q = quat_from_euler_xyz(roll, pitch, yaw) + exp_map = quat_to_exp_map(q) + return exp_map + + +@torch.jit.script +def exp_map_to_angle_axis(exp_map): + min_theta = 1e-5 + + angle = torch.norm(exp_map, dim=-1) + angle_exp = torch.unsqueeze(angle, dim=-1) + axis = exp_map / angle_exp + angle = normalize_angle(angle) + + default_axis = torch.zeros_like(exp_map) + default_axis[..., -1] = 1 + + mask = torch.abs(angle) > min_theta + angle = torch.where(mask, angle, torch.zeros_like(angle)) + mask_expand = mask.unsqueeze(-1) + axis = torch.where(mask_expand, axis, default_axis) + + return angle, axis + + +@torch.jit.script +def exp_map_to_quat(exp_map): + angle, axis = exp_map_to_angle_axis(exp_map) + q = quat_from_angle_axis(angle, axis) + return q + + +@torch.jit.script +def slerp(q0, q1, t): + # type: (Tensor, Tensor, Tensor) -> Tensor + cos_half_theta = torch.sum(q0 * q1, dim=-1) + + neg_mask = cos_half_theta < 0 + q1 = q1.clone() + q1[neg_mask] = -q1[neg_mask] + cos_half_theta = torch.abs(cos_half_theta) + cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1) + + half_theta = torch.acos(cos_half_theta) + sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta) + + ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta + ratioB = torch.sin(t * half_theta) / sin_half_theta + + new_q = ratioA * q0 + ratioB * q1 + + new_q = torch.where(torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q) + new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q) + + return new_q + + +@torch.jit.script +def calc_heading(q): + # type: (Tensor) -> Tensor + # calculate heading direction from quaternion + # the heading is the direction on the xy plane + # q must be normalized + # this is the x axis heading + ref_dir = torch.zeros_like(q[..., 0:3]) + ref_dir[..., 0] = 1 + rot_dir = my_quat_rotate(q, ref_dir) + + heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0]) + return heading + + +@torch.jit.script +def calc_heading_quat(q): + # type: (Tensor) -> Tensor + # calculate heading rotation from quaternion + # the heading is the direction on the xy plane + # q must be normalized + heading = calc_heading(q) + axis = torch.zeros_like(q[..., 0:3]) + axis[..., 2] = 1 + + heading_q = quat_from_angle_axis(heading, axis) + return heading_q + + +@torch.jit.script +def calc_heading_quat_inv(q): + # type: (Tensor) -> Tensor + # calculate heading rotation from quaternion + # the heading is the direction on the xy plane + # q must be normalized + heading = calc_heading(q) + axis = torch.zeros_like(q[..., 0:3]) + axis[..., 2] = 1 + + heading_q = quat_from_angle_axis(-heading, axis) + return heading_q + +def activation_facotry(act_name): + if act_name == 'relu': + return nn.ReLU + elif act_name == 'tanh': + return nn.Tanh + elif act_name == 'sigmoid': + return nn.Sigmoid + elif act_name == "elu": + return nn.ELU + elif act_name == "selu": + return nn.SELU + elif act_name == "silu": + return nn.SiLU + elif act_name == "gelu": + return nn.GELU + elif act_name == "softplus": + nn.Softplus + elif act_name == "None": + return nn.Identity \ No newline at end of file diff --git a/phc/utils/traj_generator.py b/phc/utils/traj_generator.py new file mode 100644 index 0000000..4d358ba --- /dev/null +++ b/phc/utils/traj_generator.py @@ -0,0 +1,193 @@ +# Copyright (c) 2018-2023, NVIDIA Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np +import torch +import joblib +import random +from phc.utils.flags import flags + + +class TrajGenerator(): + + def __init__(self, num_envs, episode_dur, num_verts, device, dtheta_max, speed_min, speed_max, accel_max, sharp_turn_prob): + + self._device = device + self._dt = episode_dur / (num_verts - 1) + self._dtheta_max = dtheta_max + self._speed_min = speed_min + self._speed_max = speed_max + self._accel_max = accel_max + self._sharp_turn_prob = sharp_turn_prob + + self._verts_flat = torch.zeros((num_envs * num_verts, 3), dtype=torch.float32, device=self._device) + self._verts = self._verts_flat.view((num_envs, num_verts, 3)) + + env_ids = torch.arange(self.get_num_envs(), dtype=np.int) + + # self.traj_data = joblib.load("data/traj/traj_data.pkl") + self.heading = torch.zeros(num_envs, 1) + return + + def reset(self, env_ids, init_pos): + n = len(env_ids) + if (n > 0): + num_verts = self.get_num_verts() + dtheta = 2 * torch.rand([n, num_verts - 1], device=self._device) - 1.0 # Sample the angles at each waypoint + dtheta *= self._dtheta_max * self._dt + + dtheta_sharp = np.pi * (2 * torch.rand([n, num_verts - 1], device=self._device) - 1.0) # Sharp Angles Angle + sharp_probs = self._sharp_turn_prob * torch.ones_like(dtheta) + sharp_mask = torch.bernoulli(sharp_probs) == 1.0 + dtheta[sharp_mask] = dtheta_sharp[sharp_mask] + + dtheta[:, 0] = np.pi * (2 * torch.rand([n], device=self._device) - 1.0) # Heading + + dspeed = 2 * torch.rand([n, num_verts - 1], device=self._device) - 1.0 + dspeed *= self._accel_max * self._dt + dspeed[:, 0] = (self._speed_max - self._speed_min) * torch.rand([n], device=self._device) + self._speed_min # Speed + + speed = torch.zeros_like(dspeed) + speed[:, 0] = dspeed[:, 0] + for i in range(1, dspeed.shape[-1]): + speed[:, i] = torch.clip(speed[:, i - 1] + dspeed[:, i], self._speed_min, self._speed_max) + + ################################################ + if flags.fixed_path: + dtheta[:, :] = 0 # ZL: Hacking to make everything 0 + dtheta[0, 0] = 0 # ZL: Hacking to create collision + if len(dtheta) > 1: + dtheta[1, 0] = -np.pi # ZL: Hacking to create collision + speed[:] = (self._speed_min + self._speed_max) / 2 + ################################################ + + if flags.slow: + speed[:] = speed / 4 + + dtheta = torch.cumsum(dtheta, dim=-1) + + seg_len = speed * self._dt + + dpos = torch.stack([torch.cos(dtheta), -torch.sin(dtheta), torch.zeros_like(dtheta)], dim=-1) + dpos *= seg_len.unsqueeze(-1) + dpos[..., 0, 0:2] += init_pos[..., 0:2] + vert_pos = torch.cumsum(dpos, dim=-2) + + self._verts[env_ids, 0, 0:2] = init_pos[..., 0:2] + self._verts[env_ids, 1:] = vert_pos + + ####### ZL: Loading random real-world trajectories ####### + if flags.real_path: + rids = random.sample(self.traj_data.keys(), n) + traj = torch.stack([torch.from_numpy(self.traj_data[id]['coord_dense'])[:num_verts] for id in rids], dim=0).to(self._device).float() + + traj[..., 0:2] = traj[..., 0:2] - (traj[..., 0, 0:2] - init_pos[..., 0:2])[:, None] + self._verts[env_ids] = traj + + return + + def input_new_trajs(self, env_ids): + import json + import requests + from scipy.interpolate import interp1d + x = requests.get(f'http://{SERVER}:{PORT}/path?num_envs={len(env_ids)}') + + data_lists = [value for idx, value in x.json().items()] + coord = np.array(data_lists) + x = np.linspace(0, coord.shape[1] - 1, num=coord.shape[1]) + fx = interp1d(x, coord[..., 0], kind='linear') + fy = interp1d(x, coord[..., 1], kind='linear') + x4 = np.linspace(0, coord.shape[1] - 1, num=coord.shape[1] * 10) + coord_dense = np.stack([fx(x4), fy(x4), np.zeros([len(env_ids), x4.shape[0]])], axis=-1) + coord_dense = np.concatenate([coord_dense, coord_dense[..., -1:, :]], axis=-2) + self._verts[env_ids] = torch.from_numpy(coord_dense).float().to(env_ids.device) + return self._verts[env_ids] + + def get_num_verts(self): + return self._verts.shape[1] + + def get_num_segs(self): + return self.get_num_verts() - 1 + + def get_num_envs(self): + return self._verts.shape[0] + + def get_traj_duration(self): + num_verts = self.get_num_verts() + dur = num_verts * self._dt + return dur + + def get_traj_verts(self, traj_id): + return self._verts[traj_id] + + def calc_pos(self, traj_ids, times): + traj_dur = self.get_traj_duration() + num_verts = self.get_num_verts() + num_segs = self.get_num_segs() + + traj_phase = torch.clip(times / traj_dur, 0.0, 1.0) + seg_idx = traj_phase * num_segs + seg_id0 = torch.floor(seg_idx).long() + seg_id1 = torch.ceil(seg_idx).long() + lerp = seg_idx - seg_id0 + + pos0 = self._verts_flat[traj_ids * num_verts + seg_id0] + pos1 = self._verts_flat[traj_ids * num_verts + seg_id1] + + lerp = lerp.unsqueeze(-1) + pos = (1.0 - lerp) * pos0 + lerp * pos1 + + return pos + + def mock_calc_pos(self, env_ids, traj_ids, times, query_value_gradient): + traj_dur = self.get_traj_duration() + num_verts = self.get_num_verts() + num_segs = self.get_num_segs() + + traj_phase = torch.clip(times / traj_dur, 0.0, 1.0) + seg_idx = traj_phase * num_segs + seg_id0 = torch.floor(seg_idx).long() + seg_id1 = torch.ceil(seg_idx).long() + lerp = seg_idx - seg_id0 + + pos0 = self._verts_flat[traj_ids * num_verts + seg_id0] + pos1 = self._verts_flat[traj_ids * num_verts + seg_id1] + + lerp = lerp.unsqueeze(-1) + pos = (1.0 - lerp) * pos0 + lerp * pos1 + + new_obs, func = query_value_gradient(env_ids, pos) + if not new_obs is None: + # ZL: computes grad + with torch.enable_grad(): + new_obs.requires_grad_(True) + new_val = func(new_obs) + + disc_grad = torch.autograd.grad(new_val, new_obs, grad_outputs=torch.ones_like(new_val), create_graph=False, retain_graph=True, only_inputs=True) + + return pos diff --git a/poselib/.gitignore b/poselib/.gitignore new file mode 100644 index 0000000..a127971 --- /dev/null +++ b/poselib/.gitignore @@ -0,0 +1,107 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# vscode +.vscode/ diff --git a/poselib/README.md b/poselib/README.md new file mode 100644 index 0000000..d8fc8cb --- /dev/null +++ b/poselib/README.md @@ -0,0 +1,20 @@ +# poselib + +`poselib` is a library for loading, manipulating, and retargeting skeleton poses and motions. It is separated into three modules: `poselib.poselib.core` for basic data loading and tensor operations, `poselib.poselib.skeleton` for higher-level skeleton operations, and `poselib.poselib.visualization` for displaying skeleton poses. + +## poselib.poselib.core +- `poselib.poselib.core.rotation3d`: A set of Torch JIT functions for dealing with quaternions, transforms, and rotation/transformation matrices. + - `quat_*` manipulate and create quaternions in [x, y, z, w] format (where w is the real component). + - `transform_*` handle 7D transforms in [quat, pos] format. + - `rot_matrix_*` handle 3x3 rotation matrices. + - `euclidean_*` handle 4x4 Euclidean transformation matrices. +- `poselib.poselib.core.tensor_utils`: Provides loading and saving functions for PyTorch tensors. + +## poselib.poselib.skeleton +- `poselib.poselib.skeleton.skeleton3d`: Utilities for loading and manipulating skeleton poses, and retargeting poses to different skeletons. + - `SkeletonTree` is a class that stores a skeleton as a tree structure. + - `SkeletonState` describes the static state of a skeleton, and provides both global and local joint angles. + - `SkeletonMotion` describes a time-series of skeleton states and provides utilities for computing joint velocities. + +## poselib.poselib.visualization +- `poselib.poselib.visualization.common`: Functions used for visualizing skeletons interactively in `matplotlib`. diff --git a/poselib/__init__.py b/poselib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/poselib/poselib/__init__.py b/poselib/poselib/__init__.py new file mode 100644 index 0000000..fd13ae7 --- /dev/null +++ b/poselib/poselib/__init__.py @@ -0,0 +1,3 @@ +__version__ = "0.0.1" + +from .core import * diff --git a/poselib/poselib/core/__init__.py b/poselib/poselib/core/__init__.py new file mode 100644 index 0000000..e3c0f9d --- /dev/null +++ b/poselib/poselib/core/__init__.py @@ -0,0 +1,3 @@ +from .tensor_utils import * +from .rotation3d import * +from .backend import Serializable, logger diff --git a/poselib/poselib/core/backend/__init__.py b/poselib/poselib/core/backend/__init__.py new file mode 100644 index 0000000..49705b2 --- /dev/null +++ b/poselib/poselib/core/backend/__init__.py @@ -0,0 +1,3 @@ +from .abstract import Serializable + +from .logger import logger diff --git a/poselib/poselib/core/backend/abstract.py b/poselib/poselib/core/backend/abstract.py new file mode 100644 index 0000000..caef630 --- /dev/null +++ b/poselib/poselib/core/backend/abstract.py @@ -0,0 +1,128 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from abc import ABCMeta, abstractmethod, abstractclassmethod +from collections import OrderedDict +import json + +import numpy as np +import os + +TENSOR_CLASS = {} + + +def register(name): + global TENSOR_CLASS + + def core(tensor_cls): + TENSOR_CLASS[name] = tensor_cls + return tensor_cls + + return core + + +def _get_cls(name): + global TENSOR_CLASS + return TENSOR_CLASS[name] + + +class NumpyEncoder(json.JSONEncoder): + """ Special json encoder for numpy types """ + + def default(self, obj): + if isinstance( + obj, + ( + np.int_, + np.intc, + np.intp, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ), + ): + return int(obj) + elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): + return float(obj) + elif isinstance(obj, (np.ndarray,)): + return dict(__ndarray__=obj.tolist(), dtype=str(obj.dtype), shape=obj.shape) + return json.JSONEncoder.default(self, obj) + + +def json_numpy_obj_hook(dct): + if isinstance(dct, dict) and "__ndarray__" in dct: + data = np.asarray(dct["__ndarray__"], dtype=dct["dtype"]) + return data.reshape(dct["shape"]) + return dct + + +class Serializable: + """ Implementation to read/write to file. + All class the is inherited from this class needs to implement to_dict() and + from_dict() + """ + + @abstractclassmethod + def from_dict(cls, dict_repr, *args, **kwargs): + """ Read the object from an ordered dictionary + + :param dict_repr: the ordered dictionary that is used to construct the object + :type dict_repr: OrderedDict + :param args, kwargs: the arguments that need to be passed into from_dict() + :type args, kwargs: additional arguments + """ + pass + + @abstractmethod + def to_dict(self): + """ Construct an ordered dictionary from the object + + :rtype: OrderedDict + """ + pass + + @classmethod + def from_file(cls, path, *args, **kwargs): + """ Read the object from a file (either .npy or .json) + + :param path: path of the file + :type path: string + :param args, kwargs: the arguments that need to be passed into from_dict() + :type args, kwargs: additional arguments + """ + if path.endswith(".json"): + with open(path, "r") as f: + d = json.load(f, object_hook=json_numpy_obj_hook) + elif path.endswith(".npy"): + d = np.load(path, allow_pickle=True).item() + else: + assert False, "failed to load {} from {}".format(cls.__name__, path) + assert d["__name__"] == cls.__name__, "the file belongs to {}, not {}".format( + d["__name__"], cls.__name__ + ) + return cls.from_dict(d, *args, **kwargs) + + def to_file(self, path: str) -> None: + """ Write the object to a file (either .npy or .json) + + :param path: path of the file + :type path: string + """ + if os.path.dirname(path) != "" and not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) + d = self.to_dict() + d["__name__"] = self.__class__.__name__ + if path.endswith(".json"): + with open(path, "w") as f: + json.dump(d, f, cls=NumpyEncoder, indent=4) + elif path.endswith(".npy"): + np.save(path, d) diff --git a/poselib/poselib/core/backend/logger.py b/poselib/poselib/core/backend/logger.py new file mode 100644 index 0000000..369cae9 --- /dev/null +++ b/poselib/poselib/core/backend/logger.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import logging + +logger = logging.getLogger("poselib") +logger.setLevel(logging.INFO) + +if not len(logger.handlers): + formatter = logging.Formatter( + fmt="%(asctime)-15s - %(levelname)s - %(module)s - %(message)s" + ) + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.info("logger initialized") diff --git a/poselib/poselib/core/rotation3d.py b/poselib/poselib/core/rotation3d.py new file mode 100644 index 0000000..afb7e3a --- /dev/null +++ b/poselib/poselib/core/rotation3d.py @@ -0,0 +1,473 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from typing import List, Optional + +import math +import torch + + +@torch.jit.script +def quat_mul(a, b): + """ + quaternion multiplication + """ + x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3] + x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3] + + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + + return torch.stack([x, y, z, w], dim=-1) + + +@torch.jit.script +def quat_pos(x): + """ + make all the real part of the quaternion positive + """ + q = x + z = (q[..., 3:] < 0).float() + q = (1 - 2 * z) * q + return q + + +@torch.jit.script +def quat_abs(x): + """ + quaternion norm (unit quaternion represents a 3D rotation, which has norm of 1) + """ + x = x.norm(p=2, dim=-1) + return x + + +@torch.jit.script +def quat_unit(x): + """ + normalized quaternion with norm of 1 + """ + norm = quat_abs(x).unsqueeze(-1) + return x / (norm.clamp(min=1e-9)) + + +@torch.jit.script +def quat_conjugate(x): + """ + quaternion with its imaginary part negated + """ + return torch.cat([-x[..., :3], x[..., 3:]], dim=-1) + + +@torch.jit.script +def quat_real(x): + """ + real component of the quaternion + """ + return x[..., 3] + + +@torch.jit.script +def quat_imaginary(x): + """ + imaginary components of the quaternion + """ + return x[..., :3] + + +@torch.jit.script +def quat_norm_check(x): + """ + verify that a quaternion has norm 1 + """ + assert bool((abs(x.norm(p=2, dim=-1) - 1) < 1e-3).all()), "the quaternion is has non-1 norm: {}".format(abs(x.norm(p=2, dim=-1) - 1)) + assert bool((x[..., 3] >= 0).all()), "the quaternion has negative real part" + + +@torch.jit.script +def quat_normalize(q): + """ + Construct 3D rotation from quaternion (the quaternion needs not to be normalized). + """ + q = quat_unit(quat_pos(q)) # normalized to positive and unit quaternion + return q + + +@torch.jit.script +def quat_from_xyz(xyz): + """ + Construct 3D rotation from the imaginary component + """ + w = (1.0 - xyz.norm()).unsqueeze(-1) + assert bool((w >= 0).all()), "xyz has its norm greater than 1" + return torch.cat([xyz, w], dim=-1) + + +@torch.jit.script +def quat_identity(shape: List[int]): + """ + Construct 3D identity rotation given shape + """ + w = torch.ones(shape + [1]) + xyz = torch.zeros(shape + [3]) + q = torch.cat([xyz, w], dim=-1) + return quat_normalize(q) + + +@torch.jit.script +def quat_from_angle_axis(angle, axis, degree: bool = False): + """ Create a 3D rotation from angle and axis of rotation. The rotation is counter-clockwise + along the axis. + + The rotation can be interpreted as a_R_b where frame "b" is the new frame that + gets rotated counter-clockwise along the axis from frame "a" + + :param angle: angle of rotation + :type angle: Tensor + :param axis: axis of rotation + :type axis: Tensor + :param degree: put True here if the angle is given by degree + :type degree: bool, optional, default=False + """ + if degree: + angle = angle / 180.0 * math.pi + theta = (angle / 2).unsqueeze(-1) + axis = axis / (axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-9)) + xyz = axis * theta.sin() + w = theta.cos() + return quat_normalize(torch.cat([xyz, w], dim=-1)) + + +@torch.jit.script +def quat_from_rotation_matrix(m): + """ + Construct a 3D rotation from a valid 3x3 rotation matrices. + Reference can be found here: + http://www.cg.info.hiroshima-cu.ac.jp/~miyazaki/knowledge/teche52.html + + :param m: 3x3 orthogonal rotation matrices. + :type m: Tensor + + :rtype: Tensor + """ + m = m.unsqueeze(0) + diag0 = m[..., 0, 0] + diag1 = m[..., 1, 1] + diag2 = m[..., 2, 2] + + # Math stuff. + w = (((diag0 + diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None))**0.5 + x = (((diag0 - diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None))**0.5 + y = (((-diag0 + diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None))**0.5 + z = (((-diag0 - diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None))**0.5 + + # Only modify quaternions where w > x, y, z. + c0 = (w >= x) & (w >= y) & (w >= z) + x[c0] *= (m[..., 2, 1][c0] - m[..., 1, 2][c0]).sign() + y[c0] *= (m[..., 0, 2][c0] - m[..., 2, 0][c0]).sign() + z[c0] *= (m[..., 1, 0][c0] - m[..., 0, 1][c0]).sign() + + # Only modify quaternions where x > w, y, z + c1 = (x >= w) & (x >= y) & (x >= z) + w[c1] *= (m[..., 2, 1][c1] - m[..., 1, 2][c1]).sign() + y[c1] *= (m[..., 1, 0][c1] + m[..., 0, 1][c1]).sign() + z[c1] *= (m[..., 0, 2][c1] + m[..., 2, 0][c1]).sign() + + # Only modify quaternions where y > w, x, z. + c2 = (y >= w) & (y >= x) & (y >= z) + w[c2] *= (m[..., 0, 2][c2] - m[..., 2, 0][c2]).sign() + x[c2] *= (m[..., 1, 0][c2] + m[..., 0, 1][c2]).sign() + z[c2] *= (m[..., 2, 1][c2] + m[..., 1, 2][c2]).sign() + + # Only modify quaternions where z > w, x, y. + c3 = (z >= w) & (z >= x) & (z >= y) + w[c3] *= (m[..., 1, 0][c3] - m[..., 0, 1][c3]).sign() + x[c3] *= (m[..., 2, 0][c3] + m[..., 0, 2][c3]).sign() + y[c3] *= (m[..., 2, 1][c3] + m[..., 1, 2][c3]).sign() + + return quat_normalize(torch.stack([x, y, z, w], dim=-1)).squeeze(0) + + +@torch.jit.script +def quat_mul_norm(x, y): + """ + Combine two set of 3D rotations together using \**\* operator. The shape needs to be + broadcastable + """ + return quat_normalize(quat_mul(x, y)) + + +@torch.jit.script +def quat_rotate(rot, vec): + """ + Rotate a 3D vector with the 3D rotation + """ + other_q = torch.cat([vec, torch.zeros_like(vec[..., :1])], dim=-1) + return quat_imaginary(quat_mul(quat_mul(rot, other_q), quat_conjugate(rot))) + + +@torch.jit.script +def quat_inverse(x): + """ + The inverse of the rotation + """ + return quat_conjugate(x) + + +@torch.jit.script +def quat_identity_like(x): + """ + Construct identity 3D rotation with the same shape + """ + return quat_identity(x.shape[:-1]) + + +@torch.jit.script +def quat_angle_axis(x): + """ + The (angle, axis) representation of the rotation. The axis is normalized to unit length. + The angle is guaranteed to be between [0, pi]. + """ + s = 2 * (x[..., 3]**2) - 1 + angle = s.clamp(-1, 1).arccos() # just to be safe + axis = x[..., :3] + axis /= axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-9) + return angle, axis + + +@torch.jit.script +def quat_yaw_rotation(x, z_up: bool = True): + """ + Yaw rotation (rotation along z-axis) + """ + q = x + if z_up: + q = torch.cat([torch.zeros_like(q[..., 0:2]), q[..., 2:3], q[..., 3:]], dim=-1) + else: + q = torch.cat( + [ + torch.zeros_like(q[..., 0:1]), + q[..., 1:2], + torch.zeros_like(q[..., 2:3]), + q[..., 3:4], + ], + dim=-1, + ) + return quat_normalize(q) + + +@torch.jit.script +def transform_from_rotation_translation(r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None): + """ + Construct a transform from a quaternion and 3D translation. Only one of them can be None. + """ + assert r is not None or t is not None, "rotation and translation can't be all None" + if r is None: + assert t is not None + r = quat_identity(list(t.shape)) + if t is None: + t = torch.zeros(list(r.shape) + [3]) + return torch.cat([r, t], dim=-1) + + +@torch.jit.script +def transform_identity(shape: List[int]): + """ + Identity transformation with given shape + """ + r = quat_identity(shape) + t = torch.zeros(shape + [3]) + return transform_from_rotation_translation(r, t) + + +@torch.jit.script +def transform_rotation(x): + """Get rotation from transform""" + return x[..., :4] + + +@torch.jit.script +def transform_translation(x): + """Get translation from transform""" + return x[..., 4:] + + +@torch.jit.script +def transform_inverse(x): + """ + Inverse transformation + """ + inv_so3 = quat_inverse(transform_rotation(x)) + return transform_from_rotation_translation(r=inv_so3, t=quat_rotate(inv_so3, -transform_translation(x))) + + +@torch.jit.script +def transform_identity_like(x): + """ + identity transformation with the same shape + """ + return transform_identity(x.shape) + + +@torch.jit.script +def transform_mul(x, y): + """ + Combine two transformation together + """ + z = transform_from_rotation_translation( + r=quat_mul_norm(transform_rotation(x), transform_rotation(y)), + t=quat_rotate(transform_rotation(x), transform_translation(y)) + transform_translation(x), + ) + return z + + +@torch.jit.script +def transform_apply(rot, vec): + """ + Transform a 3D vector + """ + assert isinstance(vec, torch.Tensor) + return quat_rotate(transform_rotation(rot), vec) + transform_translation(rot) + + +@torch.jit.script +def rot_matrix_det(x): + """ + Return the determinant of the 3x3 matrix. The shape of the tensor will be as same as the + shape of the matrix + """ + a, b, c = x[..., 0, 0], x[..., 0, 1], x[..., 0, 2] + d, e, f = x[..., 1, 0], x[..., 1, 1], x[..., 1, 2] + g, h, i = x[..., 2, 0], x[..., 2, 1], x[..., 2, 2] + t1 = a * (e * i - f * h) + t2 = b * (d * i - f * g) + t3 = c * (d * h - e * g) + return t1 - t2 + t3 + + +@torch.jit.script +def rot_matrix_integrity_check(x): + """ + Verify that a rotation matrix has a determinant of one and is orthogonal + """ + det = rot_matrix_det(x) + assert bool((abs(det - 1) < 1e-3).all()), "the matrix has non-one determinant" + rtr = x @ x.permute(torch.arange(x.dim() - 2), -1, -2) + rtr_gt = rtr.zeros_like() + rtr_gt[..., 0, 0] = 1 + rtr_gt[..., 1, 1] = 1 + rtr_gt[..., 2, 2] = 1 + assert bool(((rtr - rtr_gt) < 1e-3).all()), "the matrix is not orthogonal" + + +# @torch.jit.script +# def rot_matrix_from_quaternion(q): +# """ +# Construct rotation matrix from quaternion +# x, y, z, w convension +# """ +# print("!!!!!!! This function does well-formed rotation matrices!!!") +# # Shortcuts for individual elements (using wikipedia's convention) +# qi, qj, qk, qr = q[..., 0], q[..., 1], q[..., 2], q[..., 3] + +# # Set individual elements +# R00 = 1.0 - 2.0 * (qj**2 + qk**2) +# R01 = 2 * (qi * qj - qk * qr) +# R02 = 2 * (qi * qk + qj * qr) +# R10 = 2 * (qi * qj + qk * qr) +# R11 = 1.0 - 2.0 * (qi**2 + qk**2) +# R12 = 2 * (qj * qk - qi * qr) +# R20 = 2 * (qi * qk - qj * qr) +# R21 = 2 * (qj * qk + qi * qr) +# R22 = 1.0 - 2.0 * (qi**2 + qj**2) + +# R0 = torch.stack([R00, R01, R02], dim=-1) +# R1 = torch.stack([R10, R11, R12], dim=-1) +# R2 = torch.stack([R10, R21, R22], dim=-1) + +# R = torch.stack([R0, R1, R2], dim=-2) + +# return R + + +@torch.jit.script +def rot_matrix_from_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +@torch.jit.script +def euclidean_to_rotation_matrix(x): + """ + Get the rotation matrix on the top-left corner of a Euclidean transformation matrix + """ + return x[..., :3, :3] + + +@torch.jit.script +def euclidean_integrity_check(x): + euclidean_to_rotation_matrix(x) # check 3d-rotation matrix + assert bool((x[..., 3, :3] == 0).all()), "the last row is illegal" + assert bool((x[..., 3, 3] == 1).all()), "the last row is illegal" + + +@torch.jit.script +def euclidean_translation(x): + """ + Get the translation vector located at the last column of the matrix + """ + return x[..., :3, 3] + + +@torch.jit.script +def euclidean_inverse(x): + """ + Compute the matrix that represents the inverse rotation + """ + s = x.zeros_like() + irot = quat_inverse(quat_from_rotation_matrix(x)) + s[..., :3, :3] = irot + s[..., :3, 4] = quat_rotate(irot, -euclidean_translation(x)) + return s + + +@torch.jit.script +def euclidean_to_transform(transformation_matrix): + """ + Construct a transform from a Euclidean transformation matrix + """ + return transform_from_rotation_translation( + r=quat_from_rotation_matrix(m=euclidean_to_rotation_matrix(transformation_matrix)), + t=euclidean_translation(transformation_matrix), + ) diff --git a/poselib/poselib/core/tensor_utils.py b/poselib/poselib/core/tensor_utils.py new file mode 100644 index 0000000..2646556 --- /dev/null +++ b/poselib/poselib/core/tensor_utils.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from collections import OrderedDict +from .backend import Serializable +import torch + + +class TensorUtils(Serializable): + @classmethod + def from_dict(cls, dict_repr, *args, **kwargs): + """ Read the object from an ordered dictionary + + :param dict_repr: the ordered dictionary that is used to construct the object + :type dict_repr: OrderedDict + :param kwargs: the arguments that need to be passed into from_dict() + :type kwargs: additional arguments + """ + return torch.from_numpy(dict_repr["arr"].astype(dict_repr["context"]["dtype"])) + + def to_dict(self): + """ Construct an ordered dictionary from the object + + :rtype: OrderedDict + """ + return NotImplemented + +def tensor_to_dict(x): + """ Construct an ordered dictionary from the object + + :rtype: OrderedDict + """ + x_np = x.numpy() + return { + "arr": x_np, + "context": { + "dtype": x_np.dtype.name + } + } diff --git a/poselib/poselib/core/tests/__init__.py b/poselib/poselib/core/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/poselib/poselib/core/tests/test_rotation.py b/poselib/poselib/core/tests/test_rotation.py new file mode 100644 index 0000000..c5b6802 --- /dev/null +++ b/poselib/poselib/core/tests/test_rotation.py @@ -0,0 +1,56 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from ..rotation3d import * +import numpy as np +import torch + +q = torch.from_numpy(np.array([[0, 1, 2, 3], [-2, 3, -1, 5]], dtype=np.float32)) +print("q", q) +r = quat_normalize(q) +x = torch.from_numpy(np.array([[1, 0, 0], [0, -1, 0]], dtype=np.float32)) +print(r) +print(quat_rotate(r, x)) + +angle = torch.from_numpy(np.array(np.random.rand() * 10.0, dtype=np.float32)) +axis = torch.from_numpy(np.array([1, np.random.rand() * 10.0, np.random.rand() * 10.0], dtype=np.float32),) + +print(repr(angle)) +print(repr(axis)) + +rot = quat_from_angle_axis(angle, axis) +x = torch.from_numpy(np.random.rand(5, 6, 3)) +y = quat_rotate(quat_inverse(rot), quat_rotate(rot, x)) +print(x.numpy()) +print(y.numpy()) +assert np.allclose(x.numpy(), y.numpy()) + +m = torch.from_numpy(np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32)) +r = quat_from_rotation_matrix(m) +t = torch.from_numpy(np.array([0, 1, 0], dtype=np.float32)) +se3 = transform_from_rotation_translation(r=r, t=t) +print(se3) +print(transform_apply(se3, t)) + +rot = quat_from_angle_axis( + torch.from_numpy(np.array([45, -54], dtype=np.float32)), + torch.from_numpy(np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)), + degree=True, +) +trans = torch.from_numpy(np.array([[1, 1, 0], [1, 1, 0]], dtype=np.float32)) +transform = transform_from_rotation_translation(r=rot, t=trans) + +t = transform_mul(transform, transform_inverse(transform)) +gt = np.zeros((2, 7)) +gt[:, 0] = 1.0 +print(t.numpy()) +print(gt) +# assert np.allclose(t.numpy(), gt) + +transform2 = torch.from_numpy(np.array([[1, 0, 0, 1], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32),) +transform2 = euclidean_to_transform(transform2) +print(transform2) diff --git a/poselib/poselib/skeleton/__init__.py b/poselib/poselib/skeleton/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/poselib/poselib/skeleton/backend/__init__.py b/poselib/poselib/skeleton/backend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/poselib/poselib/skeleton/backend/fbx/__init__.py b/poselib/poselib/skeleton/backend/fbx/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/poselib/poselib/skeleton/backend/fbx/fbx_py27_backend.py b/poselib/poselib/skeleton/backend/fbx/fbx_py27_backend.py new file mode 100644 index 0000000..bee7f9f --- /dev/null +++ b/poselib/poselib/skeleton/backend/fbx/fbx_py27_backend.py @@ -0,0 +1,308 @@ +""" +Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + +NVIDIA CORPORATION and its licensors retain all intellectual property and proprietary +rights in and to this software, related documentation and any modifications thereto. Any +use, reproduction, disclosure or distribution of this software and related documentation +without an express license agreement from NVIDIA CORPORATION is strictly prohibited. +""" + +""" +This script reads an fbx file and saves the joint names, parents, and transforms to a +numpy array. + +NOTE: It must be run from python 2.7 with the fbx SDK installed. To use this script, +please use the read_fbx file +""" + +import sys + +import numpy as np + +try: + import fbx + import FbxCommon +except ImportError as e: + print("Error: FBX Import Failed. Message: {}".format(e)) + if sys.version_info[0] >= 3: + print( + "WARNING: you are using python 3 when this script should only be run from " + "python 2" + ) + else: + print( + "You are using python 2 but importing fbx failed. You must install it from " + "http://help.autodesk.com/view/FBX/2018/ENU/?guid=FBX_Developer_Help_" + "scripting_with_python_fbx_html" + ) + print("Exiting") + exit() + + +def fbx_to_npy(file_name_in, file_name_out, root_joint_name, fps): + """ + This function reads in an fbx file, and saves the relevant info to a numpy array + + Fbx files have a series of animation curves, each of which has animations at different + times. This script assumes that for mocap data, there is only one animation curve that + contains all the joints. Otherwise it is unclear how to read in the data. + + If this condition isn't met, then the method throws an error + + :param file_name_in: str, file path in. Should be .fbx file + :param file_name_out: str, file path out. Should be .npz file + :return: nothing, it just writes a file. + """ + + # Create the fbx scene object and load the .fbx file + fbx_sdk_manager, fbx_scene = FbxCommon.InitializeSdkObjects() + FbxCommon.LoadScene(fbx_sdk_manager, fbx_scene, file_name_in) + + """ + To read in the animation, we must find the root node of the skeleton. + + Unfortunately fbx files can have "scene parents" and other parts of the tree that are + not joints + + As a crude fix, this reader just takes and finds the first thing which has an + animation curve attached + """ + + search_root = (root_joint_name is None or root_joint_name == "") + + # Get the root node of the skeleton, which is the child of the scene's root node + possible_root_nodes = [fbx_scene.GetRootNode()] + found_root_node = False + max_key_count = 0 + root_joint = None + while len(possible_root_nodes) > 0: + joint = possible_root_nodes.pop(0) + if not search_root: + if joint.GetName() == root_joint_name: + root_joint = joint + try: + curve, anim_layer = _get_animation_curve(joint, fbx_scene) + except RuntimeError: + curve = None + if curve is not None: + key_count = curve.KeyGetCount() + if key_count > max_key_count: + found_root_node = True + max_key_count = key_count + root_curve = curve + if search_root and not root_joint: + root_joint = joint + + if not search_root and curve is not None and root_joint is not None: + break + + for child_index in range(joint.GetChildCount()): + possible_root_nodes.append(joint.GetChild(child_index)) + + if not found_root_node: + raise RuntimeError("No root joint found!! Exiting") + + joint_list, joint_names, parents = _get_skeleton(root_joint) + + """ + Read in the transformation matrices of the animation, taking the scaling into account + """ + + anim_range, frame_count, frame_rate = _get_frame_count(fbx_scene) + + local_transforms = [] + #for frame in range(frame_count): + time_sec = anim_range.GetStart().GetSecondDouble() + time_range_sec = anim_range.GetStop().GetSecondDouble() - time_sec + fbx_fps = frame_count / time_range_sec + if fps != 120: + fbx_fps = fps + print("FPS: ", fbx_fps) + while time_sec < anim_range.GetStop().GetSecondDouble(): + fbx_time = fbx.FbxTime() + fbx_time.SetSecondDouble(time_sec) + fbx_time = fbx_time.GetFramedTime() + transforms_current_frame = [] + + # Fbx has a unique time object which you need + #fbx_time = root_curve.KeyGetTime(frame) + for joint in joint_list: + arr = np.array(_recursive_to_list(joint.EvaluateLocalTransform(fbx_time))) + scales = np.array(_recursive_to_list(joint.EvaluateLocalScaling(fbx_time))) + + lcl_trans = joint.LclTranslation.Get() + lcl_rot = joint.LclRotation.Get() + lcl_matrix = fbx.FbxAMatrix() + # lcl_matrix.SetR(fbx.FbxVector4(lcl_rot[0], lcl_rot[1], lcl_rot[2], 1.0)) + # lcl_matrix.SetT(fbx.FbxVector4(lcl_trans[0], lcl_trans[1], lcl_trans[2], 1.0)) + # lcl_matrix = np.array(_recursive_to_list(lcl_matrix)) + curve = joint.LclTranslation.GetCurve(anim_layer, "X") + transX = curve.Evaluate(fbx_time)[0] if curve else lcl_trans[0] + curve = joint.LclTranslation.GetCurve(anim_layer, "Y") + transY = curve.Evaluate(fbx_time)[0] if curve else lcl_trans[1] + curve = joint.LclTranslation.GetCurve(anim_layer, "Z") + transZ = curve.Evaluate(fbx_time)[0] if curve else lcl_trans[2] + + curve = joint.LclRotation.GetCurve(anim_layer, "X") + rotX = curve.Evaluate(fbx_time)[0] if curve else lcl_rot[0] + curve = joint.LclRotation.GetCurve(anim_layer, "Y") + rotY = curve.Evaluate(fbx_time)[0] if curve else lcl_rot[1] + curve = joint.LclRotation.GetCurve(anim_layer, "Z") + rotZ = curve.Evaluate(fbx_time)[0] if curve else lcl_rot[2] + + lcl_matrix.SetR(fbx.FbxVector4(rotX, rotY, rotZ, 1.0)) + lcl_matrix.SetT(fbx.FbxVector4(transX, transY, transZ, 1.0)) + lcl_matrix = np.array(_recursive_to_list(lcl_matrix)) + # if not np.allclose(scales[0:3], scales[0]): + # raise ValueError( + # "Different X, Y and Z scaling. Unsure how this should be handled. " + # "To solve this, look at this link and try to upgrade the script " + # "http://help.autodesk.com/view/FBX/2017/ENU/?guid=__files_GUID_10CDD" + # "63C_79C1_4F2D_BB28_AD2BE65A02ED_htm" + # ) + # Adjust the array for scaling + arr /= scales[0] + arr[3, 3] = 1.0 + lcl_matrix[3, 3] = 1.0 + transforms_current_frame.append(lcl_matrix) + local_transforms.append(transforms_current_frame) + + time_sec += (1.0/fbx_fps) + + local_transforms = np.array(local_transforms) + print("Frame Count: ", len(local_transforms)) + + # Write to numpy array + np.savez_compressed( + file_name_out, names=joint_names, parents=parents, transforms=local_transforms, fps=fbx_fps + ) + +def _get_frame_count(fbx_scene): + # Get the animation stacks and layers, in order to pull off animation curves later + num_anim_stacks = fbx_scene.GetSrcObjectCount( + FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimStack.ClassId) + ) + # if num_anim_stacks != 1: + # raise RuntimeError( + # "More than one animation stack was found. " + # "This script must be modified to handle this case. Exiting" + # ) + if num_anim_stacks > 1: + index = 1 + else: + index = 0 + anim_stack = fbx_scene.GetSrcObject( + FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimStack.ClassId), index + ) + + anim_range = anim_stack.GetLocalTimeSpan() + duration = anim_range.GetDuration() + fps = duration.GetFrameRate(duration.GetGlobalTimeMode()) + frame_count = duration.GetFrameCount(True) + + return anim_range, frame_count, fps + +def _get_animation_curve(joint, fbx_scene): + # Get the animation stacks and layers, in order to pull off animation curves later + num_anim_stacks = fbx_scene.GetSrcObjectCount( + FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimStack.ClassId) + ) + # if num_anim_stacks != 1: + # raise RuntimeError( + # "More than one animation stack was found. " + # "This script must be modified to handle this case. Exiting" + # ) + if num_anim_stacks > 1: + index = 1 + else: + index = 0 + anim_stack = fbx_scene.GetSrcObject( + FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimStack.ClassId), index + ) + + num_anim_layers = anim_stack.GetSrcObjectCount( + FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimLayer.ClassId) + ) + if num_anim_layers != 1: + raise RuntimeError( + "More than one animation layer was found. " + "This script must be modified to handle this case. Exiting" + ) + animation_layer = anim_stack.GetSrcObject( + FbxCommon.FbxCriteria.ObjectType(FbxCommon.FbxAnimLayer.ClassId), 0 + ) + + def _check_longest_curve(curve, max_curve_key_count): + longest_curve = None + if curve and curve.KeyGetCount() > max_curve_key_count[0]: + max_curve_key_count[0] = curve.KeyGetCount() + return True + + return False + + max_curve_key_count = [0] + longest_curve = None + for c in ["X", "Y", "Z"]: + curve = joint.LclTranslation.GetCurve( + animation_layer, c + ) # sample curve for translation + if _check_longest_curve(curve, max_curve_key_count): + longest_curve = curve + + curve = joint.LclRotation.GetCurve( + animation_layer, "X" + ) + if _check_longest_curve(curve, max_curve_key_count): + longest_curve = curve + + return longest_curve, animation_layer + + +def _get_skeleton(root_joint): + + # Do a depth first search of the skeleton to extract all the joints + joint_list = [root_joint] + joint_names = [root_joint.GetName()] + parents = [-1] # -1 means no parent + + def append_children(joint, pos): + """ + Depth first search function + :param joint: joint item in the fbx + :param pos: position of current element (for parenting) + :return: Nothing + """ + for child_index in range(joint.GetChildCount()): + child = joint.GetChild(child_index) + joint_list.append(child) + joint_names.append(child.GetName()) + parents.append(pos) + append_children(child, len(parents) - 1) + + append_children(root_joint, 0) + return joint_list, joint_names, parents + + +def _recursive_to_list(array): + """ + Takes some iterable that might contain iterables and converts it to a list of lists + [of lists... etc] + + Mainly used for converting the strange fbx wrappers for c++ arrays into python lists + :param array: array to be converted + :return: array converted to lists + """ + try: + return float(array) + except TypeError: + return [_recursive_to_list(a) for a in array] + + +if __name__ == "__main__": + + # Read in the input and output files, then read the fbx + file_name_in, file_name_out = sys.argv[1:3] + root_joint_name = sys.argv[3] + fps = int(sys.argv[4]) + + fbx_to_npy(file_name_in, file_name_out, root_joint_name, fps) diff --git a/poselib/poselib/skeleton/backend/fbx/fbx_read_wrapper.py b/poselib/poselib/skeleton/backend/fbx/fbx_read_wrapper.py new file mode 100644 index 0000000..1dbd5d4 --- /dev/null +++ b/poselib/poselib/skeleton/backend/fbx/fbx_read_wrapper.py @@ -0,0 +1,75 @@ +""" +Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + +NVIDIA CORPORATION and its licensors retain all intellectual property and proprietary +rights in and to this software, related documentation and any modifications thereto. Any +use, reproduction, disclosure or distribution of this software and related documentation +without an express license agreement from NVIDIA CORPORATION is strictly prohibited. +""" + +""" +Script that reads in fbx files from python 2 + +This requires a configs file, which contains the command necessary to switch conda +environments to run the fbx reading script from python 2 +""" + +from ....core import logger + +import inspect +import os + +import numpy as np + +# Get the current folder to import the config file +current_folder = os.path.realpath( + os.path.abspath(os.path.split(inspect.getfile(inspect.currentframe()))[0]) +) + + +def fbx_to_array(fbx_file_path, fbx_configs, root_joint, fps): + """ + Reads an fbx file to an array. + + Currently reading of the frame time is not supported. 120 fps is hard coded TODO + + :param fbx_file_path: str, file path to fbx + :return: tuple with joint_names, parents, transforms, frame time + """ + + # Ensure the file path is valid + fbx_file_path = os.path.abspath(fbx_file_path) + assert os.path.exists(fbx_file_path) + + # Switch directories to the utils folder to ensure the reading works + previous_cwd = os.getcwd() + os.chdir(current_folder) + + # Call the python 2.7 script + temp_file_path = os.path.join(current_folder, fbx_configs["tmp_path"]) + python_path = fbx_configs["fbx_py27_path"] + logger.info("executing python script to read fbx data using Autodesk FBX SDK...") + command = '{} fbx_py27_backend.py "{}" "{}" "{}" "{}"'.format( + python_path, fbx_file_path, temp_file_path, root_joint, fps + ) + logger.debug("executing command: {}".format(command)) + os.system(command) + logger.info( + "executing python script to read fbx data using Autodesk FBX SDK... done" + ) + + with open(temp_file_path, "rb") as f: + data = np.load(f) + output = ( + data["names"], + data["parents"], + data["transforms"], + data["fps"], + ) + + # Remove the temporary file + os.remove(temp_file_path) + + # Return the os to its previous cwd, otherwise reading multiple files might fail + os.chdir(previous_cwd) + return output diff --git a/poselib/poselib/skeleton/skeleton3d.py b/poselib/poselib/skeleton/skeleton3d.py new file mode 100644 index 0000000..5c56ee3 --- /dev/null +++ b/poselib/poselib/skeleton/skeleton3d.py @@ -0,0 +1,1264 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import xml.etree.ElementTree as ET +from collections import OrderedDict +from typing import List, Optional, Type, Dict + +import numpy as np +import torch + +from ..core import * +from .backend.fbx.fbx_read_wrapper import fbx_to_array +import scipy.ndimage.filters as filters + + +class SkeletonTree(Serializable): + """ + A skeleton tree gives a complete description of a rigid skeleton. It describes a tree structure + over a list of nodes with their names indicated by strings. Each edge in the tree has a local + translation associated with it which describes the distance between the two nodes that it + connects. + + Basic Usage: + >>> t = SkeletonTree.from_mjcf(SkeletonTree.__example_mjcf_path__) + >>> t + SkeletonTree( + node_names=['torso', 'front_left_leg', 'aux_1', 'front_left_foot', 'front_right_leg', 'aux_2', 'front_right_foot', 'left_back_leg', 'aux_3', 'left_back_foot', 'right_back_leg', 'aux_4', 'right_back_foot'], + parent_indices=tensor([-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 0, 10, 11]), + local_translation=tensor([[ 0.0000, 0.0000, 0.7500], + [ 0.0000, 0.0000, 0.0000], + [ 0.2000, 0.2000, 0.0000], + [ 0.2000, 0.2000, 0.0000], + [ 0.0000, 0.0000, 0.0000], + [-0.2000, 0.2000, 0.0000], + [-0.2000, 0.2000, 0.0000], + [ 0.0000, 0.0000, 0.0000], + [-0.2000, -0.2000, 0.0000], + [-0.2000, -0.2000, 0.0000], + [ 0.0000, 0.0000, 0.0000], + [ 0.2000, -0.2000, 0.0000], + [ 0.2000, -0.2000, 0.0000]]) + ) + >>> t.node_names + ['torso', 'front_left_leg', 'aux_1', 'front_left_foot', 'front_right_leg', 'aux_2', 'front_right_foot', 'left_back_leg', 'aux_3', 'left_back_foot', 'right_back_leg', 'aux_4', 'right_back_foot'] + >>> t.parent_indices + tensor([-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 0, 10, 11]) + >>> t.local_translation + tensor([[ 0.0000, 0.0000, 0.7500], + [ 0.0000, 0.0000, 0.0000], + [ 0.2000, 0.2000, 0.0000], + [ 0.2000, 0.2000, 0.0000], + [ 0.0000, 0.0000, 0.0000], + [-0.2000, 0.2000, 0.0000], + [-0.2000, 0.2000, 0.0000], + [ 0.0000, 0.0000, 0.0000], + [-0.2000, -0.2000, 0.0000], + [-0.2000, -0.2000, 0.0000], + [ 0.0000, 0.0000, 0.0000], + [ 0.2000, -0.2000, 0.0000], + [ 0.2000, -0.2000, 0.0000]]) + >>> t.parent_of('front_left_leg') + 'torso' + >>> t.index('front_right_foot') + 6 + >>> t[2] + 'aux_1' + """ + + __example_mjcf_path__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tests/ant.xml") + + def __init__(self, node_names, parent_indices, local_translation): + """ + :param node_names: a list of names for each tree node + :type node_names: List[str] + :param parent_indices: an int32-typed tensor that represents the edge to its parent.\ + -1 represents the root node + :type parent_indices: Tensor + :param local_translation: a 3d vector that gives local translation information + :type local_translation: Tensor + """ + ln, lp, ll = len(node_names), len(parent_indices), len(local_translation) + assert len(set((ln, lp, ll))) == 1 + self._node_names = node_names + self._parent_indices = parent_indices.long() + self._local_translation = local_translation + self._node_indices = {self.node_names[i]: i for i in range(len(self))} + + def __len__(self): + """ number of nodes in the skeleton tree """ + return len(self.node_names) + + def __iter__(self): + """ iterator that iterate through the name of each node """ + yield from self.node_names + + def __getitem__(self, item): + """ get the name of the node given the index """ + return self.node_names[item] + + def __repr__(self): + return ("SkeletonTree(\n node_names={},\n parent_indices={}," + "\n local_translation={}\n)".format( + self._indent(repr(self.node_names)), + self._indent(repr(self.parent_indices)), + self._indent(repr(self.local_translation)), + )) + + def _indent(self, s): + return "\n ".join(s.split("\n")) + + @property + def node_names(self): + return self._node_names + + @property + def parent_indices(self): + return self._parent_indices + + @property + def local_translation(self): + return self._local_translation + + @property + def num_joints(self): + """ number of nodes in the skeleton tree """ + return len(self) + + @classmethod + def from_dict(cls, dict_repr, *args, **kwargs): + return cls( + list(map(str, dict_repr["node_names"])), + TensorUtils.from_dict(dict_repr["parent_indices"], *args, **kwargs), + TensorUtils.from_dict(dict_repr["local_translation"], *args, **kwargs), + ) + + def to_dict(self): + return OrderedDict([ + ("node_names", self.node_names), + ("parent_indices", tensor_to_dict(self.parent_indices)), + ("local_translation", tensor_to_dict(self.local_translation)), + ]) + + @classmethod + def from_mjcf(cls, path: str) -> "SkeletonTree": + """ + Parses a mujoco xml scene description file and returns a Skeleton Tree. + We use the model attribute at the root as the name of the tree. + + :param path: + :type path: string + :return: The skeleton tree constructed from the mjcf file + :rtype: SkeletonTree + """ + tree = ET.parse(path) + xml_doc_root = tree.getroot() + xml_world_body = xml_doc_root.find("worldbody") + if xml_world_body is None: + raise ValueError("MJCF parsed incorrectly please verify it.") + # assume this is the root + xml_body_root = xml_world_body.find("body") + if xml_body_root is None: + raise ValueError("MJCF parsed incorrectly please verify it.") + + node_names = [] + parent_indices = [] + local_translation = [] + + # recursively adding all nodes into the skel_tree + def _add_xml_node(xml_node, parent_index, node_index): + node_name = xml_node.attrib.get("name") + # parse the local translation into float list + pos = np.fromstring(xml_node.attrib.get("pos"), dtype=float, sep=" ") + node_names.append(node_name) + parent_indices.append(parent_index) + local_translation.append(pos) + curr_index = node_index + node_index += 1 + for next_node in xml_node.findall("body"): + node_index = _add_xml_node(next_node, curr_index, node_index) + return node_index + + _add_xml_node(xml_body_root, -1, 0) + + return cls( + node_names, + torch.from_numpy(np.array(parent_indices, dtype=np.int32)), + torch.from_numpy(np.array(local_translation, dtype=np.float32)), + ) + + def parent_of(self, node_name): + """ get the name of the parent of the given node + + :param node_name: the name of the node + :type node_name: string + :rtype: string + """ + return self[int(self.parent_indices[self.index(node_name)].item())] + + def index(self, node_name): + """ get the index of the node + + :param node_name: the name of the node + :type node_name: string + :rtype: int + """ + return self._node_indices[node_name] + + def drop_nodes_by_names(self, node_names: List[str], pairwise_translation=None) -> "SkeletonTree": + new_length = len(self) - len(node_names) + new_node_names = [] + new_local_translation = torch.zeros(new_length, 3, dtype=self.local_translation.dtype) + new_parent_indices = torch.zeros(new_length, dtype=self.parent_indices.dtype) + parent_indices = self.parent_indices.numpy() + new_node_indices: dict = {} + new_node_index = 0 + for node_index in range(len(self)): + if self[node_index] in node_names: + continue + tb_node_index = parent_indices[node_index] + if tb_node_index != -1: + local_translation = self.local_translation[node_index, :] + while tb_node_index != -1 and self[tb_node_index] in node_names: + local_translation += self.local_translation[tb_node_index, :] + tb_node_index = parent_indices[tb_node_index] + assert tb_node_index != -1, "the root node cannot be dropped" + + if pairwise_translation is not None: + local_translation = pairwise_translation[tb_node_index, node_index, :] + else: + local_translation = self.local_translation[node_index, :] + + new_node_names.append(self[node_index]) + new_local_translation[new_node_index, :] = local_translation + if tb_node_index == -1: + new_parent_indices[new_node_index] = -1 + else: + new_parent_indices[new_node_index] = new_node_indices[self[tb_node_index]] + new_node_indices[self[node_index]] = new_node_index + new_node_index += 1 + + return SkeletonTree(new_node_names, new_parent_indices, new_local_translation) + + def keep_nodes_by_names(self, node_names: List[str], pairwise_translation=None) -> "SkeletonTree": + nodes_to_drop = list(filter(lambda x: x not in node_names, self)) + return self.drop_nodes_by_names(nodes_to_drop, pairwise_translation) + + +class SkeletonState(Serializable): + """ + A skeleton state contains all the information needed to describe a static state of a skeleton. + It requires a skeleton tree, local/global rotation at each joint and the root translation. + + Example: + >>> t = SkeletonTree.from_mjcf(SkeletonTree.__example_mjcf_path__) + >>> zero_pose = SkeletonState.zero_pose(t) + >>> plot_skeleton_state(zero_pose) # can be imported from `.visualization.common` + [plot of the ant at zero pose + >>> local_rotation = zero_pose.local_rotation.clone() + >>> local_rotation[2] = torch.tensor([0, 0, 1, 0]) + >>> new_pose = SkeletonState.from_rotation_and_root_translation( + ... skeleton_tree=t, + ... r=local_rotation, + ... t=zero_pose.root_translation, + ... is_local=True + ... ) + >>> new_pose.local_rotation + tensor([[0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 1., 0., 0.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.]]) + >>> plot_skeleton_state(new_pose) # you should be able to see one of ant's leg is bent + [plot of the ant with the new pose + >>> new_pose.global_rotation # the local rotation is propagated to the global rotation at joint #3 + tensor([[0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 1., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.], + [0., 0., 0., 1.]]) + + Global/Local Representation (cont. from the previous example) + >>> new_pose.is_local + True + >>> new_pose.tensor # this will return the local rotation followed by the root translation + tensor([0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., + 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., + 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., + 0.]) + >>> new_pose.tensor.shape # 4 * 13 (joint rotation) + 3 (root translatio + torch.Size([55]) + >>> new_pose.global_repr().is_local + False + >>> new_pose.global_repr().tensor # this will return the global rotation followed by the root translation instead + tensor([0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., + 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., + 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., + 0.]) + >>> new_pose.global_repr().tensor.shape # 4 * 13 (joint rotation) + 3 (root translation + torch.Size([55]) + """ + + def __init__(self, tensor_backend, skeleton_tree, is_local): + self._skeleton_tree = skeleton_tree + self._is_local = is_local + self.tensor = tensor_backend.clone() + + def __len__(self): + return self.tensor.shape[0] + + @property + def rotation(self): + if not hasattr(self, "_rotation"): + self._rotation = self.tensor[..., :self.num_joints * 4].reshape(*(self.tensor.shape[:-1] + (self.num_joints, 4))) + return self._rotation + + @property + def _local_rotation(self): + if self._is_local: + return self.rotation + else: + return None + + @property + def _global_rotation(self): + if not self._is_local: + return self.rotation + else: + return None + + @property + def is_local(self): + """ is the rotation represented in local frame? + + :rtype: bool + """ + return self._is_local + + @property + def invariant_property(self): + return {"skeleton_tree": self.skeleton_tree, "is_local": self.is_local} + + @property + def num_joints(self): + """ number of joints in the skeleton tree + + :rtype: int + """ + return self.skeleton_tree.num_joints + + @property + def skeleton_tree(self): + """ skeleton tree + + :rtype: SkeletonTree + """ + return self._skeleton_tree + + @property + def root_translation(self): + """ root translation + + :rtype: Tensor + """ + if not hasattr(self, "_root_translation"): + self._root_translation = self.tensor[..., self.num_joints * 4:self.num_joints * 4 + 3] + return self._root_translation + + @property + def global_transformation(self): + """ global transformation of each joint (transform from joint frame to global frame) """ + # Forward Kinematics + if not hasattr(self, "_global_transformation"): + local_transformation = self.local_transformation + global_transformation = [] + parent_indices = self.skeleton_tree.parent_indices.numpy() + # global_transformation = local_transformation.identity_like() + for node_index in range(len(self.skeleton_tree)): + parent_index = parent_indices[node_index] + if parent_index == -1: + global_transformation.append(local_transformation[..., node_index, :]) + else: + global_transformation.append(transform_mul( + global_transformation[parent_index], + local_transformation[..., node_index, :], + )) + self._global_transformation = torch.stack(global_transformation, axis=-2) + return self._global_transformation + + @property + def global_rotation(self): + """ global rotation of each joint (rotation matrix to rotate from joint's F.O.R to global + F.O.R) """ + if self._global_rotation is None: + if not hasattr(self, "_comp_global_rotation"): + self._comp_global_rotation = transform_rotation(self.global_transformation) + return self._comp_global_rotation + else: + return self._global_rotation + + @property + def global_translation(self): + """ global translation of each joint """ + if not hasattr(self, "_global_translation"): + self._global_translation = transform_translation(self.global_transformation) + return self._global_translation + + @property + def global_translation_xy(self): + """ global translation in xy """ + trans_xy_data = self.global_translation.zeros_like() + trans_xy_data[..., 0:2] = self.global_translation[..., 0:2] + return trans_xy_data + + @property + def global_translation_xz(self): + """ global translation in xz """ + trans_xz_data = self.global_translation.zeros_like() + trans_xz_data[..., 0:1] = self.global_translation[..., 0:1] + trans_xz_data[..., 2:3] = self.global_translation[..., 2:3] + return trans_xz_data + + @property + def local_rotation(self): + """ the rotation from child frame to parent frame given in the order of child nodes appeared + in `.skeleton_tree.node_names` """ + if self._local_rotation is None: + if not hasattr(self, "_comp_local_rotation"): + local_rotation = quat_identity_like(self.global_rotation) + for node_index in range(len(self.skeleton_tree)): + parent_index = self.skeleton_tree.parent_indices[node_index] + if parent_index == -1: + local_rotation[..., node_index, :] = self.global_rotation[..., node_index, :] + else: + local_rotation[..., node_index, :] = quat_mul_norm( + quat_inverse(self.global_rotation[..., parent_index, :]), + self.global_rotation[..., node_index, :], + ) + self._comp_local_rotation = local_rotation + return self._comp_local_rotation + else: + return self._local_rotation + + @property + def local_transformation(self): + """ local translation + local rotation. It describes the transformation from child frame to + parent frame given in the order of child nodes appeared in `.skeleton_tree.node_names` """ + if not hasattr(self, "_local_transformation"): + self._local_transformation = transform_from_rotation_translation(r=self.local_rotation, t=self.local_translation) + return self._local_transformation + + @property + def local_translation(self): + """ local translation of the skeleton state. It is identical to the local translation in + `.skeleton_tree.local_translation` except the root translation. The root translation is + identical to `.root_translation` """ + if not hasattr(self, "_local_translation"): + broadcast_shape = (tuple(self.tensor.shape[:-1]) + (len(self.skeleton_tree),) + tuple(self.skeleton_tree.local_translation.shape[-1:])) + local_translation = self.skeleton_tree.local_translation.broadcast_to(*broadcast_shape).clone() + local_translation[..., 0, :] = self.root_translation + self._local_translation = local_translation + return self._local_translation + + # Root Properties + @property + def root_translation_xy(self): + """ root translation on xy """ + if not hasattr(self, "_root_translation_xy"): + self._root_translation_xy = self.global_translation_xy[..., 0, :] + return self._root_translation_xy + + @property + def global_root_rotation(self): + """ root rotation """ + if not hasattr(self, "_global_root_rotation"): + self._global_root_rotation = self.global_rotation[..., 0, :] + return self._global_root_rotation + + @property + def global_root_yaw_rotation(self): + """ root yaw rotation """ + if not hasattr(self, "_global_root_yaw_rotation"): + self._global_root_yaw_rotation = self.global_root_rotation.yaw_rotation() + return self._global_root_yaw_rotation + + # Properties relative to root + @property + def local_translation_to_root(self): + """ The 3D translation from joint frame to the root frame. """ + if not hasattr(self, "_local_translation_to_root"): + self._local_translation_to_root = (self.global_translation - self.root_translation.unsqueeze(-1)) + return self._local_translation_to_root + + @property + def local_rotation_to_root(self): + """ The 3D rotation from joint frame to the root frame. It is equivalent to + The root_R_world * world_R_node """ + return (quat_inverse(self.global_root_rotation).unsqueeze(-1) * self.global_rotation) + + def compute_forward_vector( + self, + left_shoulder_index, + right_shoulder_index, + left_hip_index, + right_hip_index, + gaussian_filter_width=20, + ): + """ Computes forward vector based on cross product of the up vector with + average of the right->left shoulder and hip vectors """ + global_positions = self.global_translation + # Perpendicular to the forward direction. + # Uses the shoulders and hips to find this. + side_direction = (global_positions[:, left_shoulder_index].numpy() - global_positions[:, right_shoulder_index].numpy() + global_positions[:, left_hip_index].numpy() - global_positions[:, right_hip_index].numpy()) + side_direction = (side_direction / np.sqrt((side_direction**2).sum(axis=-1))[..., np.newaxis]) + + # Forward direction obtained by crossing with the up direction. + forward_direction = np.cross(side_direction, np.array([[0, 1, 0]])) + + # Smooth the forward direction with a Gaussian. + # Axis 0 is the time/frame axis. + forward_direction = filters.gaussian_filter1d(forward_direction, gaussian_filter_width, axis=0, mode="nearest") + forward_direction = (forward_direction / np.sqrt((forward_direction**2).sum(axis=-1))[..., np.newaxis]) + + return torch.from_numpy(forward_direction) + + @staticmethod + def _to_state_vector(rot, rt): + state_shape = rot.shape[:-2] + vr = rot.reshape(*(state_shape + (-1,))) + vt = rt.broadcast_to(*state_shape + rt.shape[-1:]).reshape(*(state_shape + (-1,))) + v = torch.cat([vr, vt], axis=-1) + return v + + @classmethod + def from_dict(cls: Type["SkeletonState"], dict_repr: OrderedDict, *args, **kwargs) -> "SkeletonState": + rot = TensorUtils.from_dict(dict_repr["rotation"], *args, **kwargs) + rt = TensorUtils.from_dict(dict_repr["root_translation"], *args, **kwargs) + return cls( + SkeletonState._to_state_vector(rot, rt), + SkeletonTree.from_dict(dict_repr["skeleton_tree"], *args, **kwargs), + dict_repr["is_local"], + ) + + def to_dict(self) -> OrderedDict: + return OrderedDict([ + ("rotation", tensor_to_dict(self.rotation)), + ("root_translation", tensor_to_dict(self.root_translation)), + ("skeleton_tree", self.skeleton_tree.to_dict()), + ("is_local", self.is_local), + ]) + + @classmethod + def from_rotation_and_root_translation(cls, skeleton_tree, r, t, is_local=True): + """ + Construct a skeleton state from rotation and root translation + + :param skeleton_tree: the skeleton tree + :type skeleton_tree: SkeletonTree + :param r: rotation (either global or local) + :type r: Tensor + :param t: root translation + :type t: Tensor + :param is_local: to indicate that whether the rotation is local or global + :type is_local: bool, optional, default=True + """ + assert (r.dim() > 0), "the rotation needs to have at least 1 dimension (dim = {})".format(r.dim) + state_vec = SkeletonState._to_state_vector(r, t) + + return cls( + state_vec, + skeleton_tree=skeleton_tree, + is_local=is_local, + ) + + @classmethod + def zero_pose(cls, skeleton_tree): + """ + Construct a zero-pose skeleton state from the skeleton tree by assuming that all the local + rotation is 0 and root translation is also 0. + + :param skeleton_tree: the skeleton tree as the rigid body + :type skeleton_tree: SkeletonTree + """ + return cls.from_rotation_and_root_translation( + skeleton_tree=skeleton_tree, + r=quat_identity([skeleton_tree.num_joints]), + t=torch.zeros(3, dtype=skeleton_tree.local_translation.dtype), + is_local=True, + ) + + def local_repr(self): + """ + Convert the skeleton state into local representation. This will only affects the values of + .tensor. If the skeleton state already has `is_local=True`. This method will do nothing. + + :rtype: SkeletonState + """ + if self.is_local: + return self + return SkeletonState.from_rotation_and_root_translation( + self.skeleton_tree, + r=self.local_rotation, + t=self.root_translation, + is_local=True, + ) + + def global_repr(self): + """ + Convert the skeleton state into global representation. This will only affects the values of + .tensor. If the skeleton state already has `is_local=False`. This method will do nothing. + + :rtype: SkeletonState + """ + if not self.is_local: + return self + return SkeletonState.from_rotation_and_root_translation( + self.skeleton_tree, + r=self.global_rotation, + t=self.root_translation, + is_local=False, + ) + + def _get_pairwise_average_translation(self): + global_transform_inv = transform_inverse(self.global_transformation) + p1 = global_transform_inv.unsqueeze(-2) + p2 = self.global_transformation.unsqueeze(-3) + + pairwise_translation = (transform_translation(transform_mul(p1, p2)).reshape(-1, len(self.skeleton_tree), len(self.skeleton_tree), 3).mean(axis=0)) + return pairwise_translation + + def _transfer_to(self, new_skeleton_tree: SkeletonTree): + old_indices = list(map(self.skeleton_tree.index, new_skeleton_tree)) + return SkeletonState.from_rotation_and_root_translation( + new_skeleton_tree, + r=self.global_rotation[..., old_indices, :], + t=self.root_translation, + is_local=False, + ) + + def drop_nodes_by_names(self, node_names: List[str], estimate_local_translation_from_states: bool = True) -> "SkeletonState": + """ + Drop a list of nodes from the skeleton and re-compute the local rotation to match the + original joint position as much as possible. + + :param node_names: a list node names that specifies the nodes need to be dropped + :type node_names: List of strings + :param estimate_local_translation_from_states: the boolean indicator that specifies whether\ + or not to re-estimate the local translation from the states (avg.) + :type estimate_local_translation_from_states: boolean + :rtype: SkeletonState + """ + if estimate_local_translation_from_states: + pairwise_translation = self._get_pairwise_average_translation() + else: + pairwise_translation = None + new_skeleton_tree = self.skeleton_tree.drop_nodes_by_names(node_names, pairwise_translation) + return self._transfer_to(new_skeleton_tree) + + def keep_nodes_by_names(self, node_names: List[str], estimate_local_translation_from_states: bool = True) -> "SkeletonState": + """ + Keep a list of nodes and drop all other nodes from the skeleton and re-compute the local + rotation to match the original joint position as much as possible. + + :param node_names: a list node names that specifies the nodes need to be dropped + :type node_names: List of strings + :param estimate_local_translation_from_states: the boolean indicator that specifies whether\ + or not to re-estimate the local translation from the states (avg.) + :type estimate_local_translation_from_states: boolean + :rtype: SkeletonState + """ + return self.drop_nodes_by_names( + list(filter(lambda x: (x not in node_names), self)), + estimate_local_translation_from_states, + ) + + def _remapped_to(self, joint_mapping: Dict[str, str], target_skeleton_tree: SkeletonTree): + joint_mapping_inv = {target: source for source, target in joint_mapping.items()} + reduced_target_skeleton_tree = target_skeleton_tree.keep_nodes_by_names(list(joint_mapping_inv)) + n_joints = ( + len(joint_mapping), + len(self.skeleton_tree), + len(reduced_target_skeleton_tree), + ) + assert (len(set(n_joints)) == 1), "the joint mapping is not consistent with the skeleton trees" + source_indices = list(map( + lambda x: self.skeleton_tree.index(joint_mapping_inv[x]), + reduced_target_skeleton_tree, + )) + target_local_rotation = self.local_rotation[..., source_indices, :] + return SkeletonState.from_rotation_and_root_translation( + skeleton_tree=reduced_target_skeleton_tree, + r=target_local_rotation, + t=self.root_translation, + is_local=True, + ) + + def retarget_to( + self, + joint_mapping: Dict[str, str], + source_tpose_local_rotation, + source_tpose_root_translation: np.ndarray, + target_skeleton_tree: SkeletonTree, + target_tpose_local_rotation, + target_tpose_root_translation: np.ndarray, + rotation_to_target_skeleton, + scale_to_target_skeleton: float, + z_up: bool = True, + ) -> "SkeletonState": + """ + Retarget the skeleton state to a target skeleton tree. This is a naive retarget + implementation with rough approximations. The function follows the procedures below. + + Steps: + 1. Drop the joints from the source (self) that do not belong to the joint mapping\ + with an implementation that is similar to "keep_nodes_by_names()" - take a\ + look at the function doc for more details (same for source_tpose) + + 2. Rotate the source state and the source tpose by "rotation_to_target_skeleton"\ + to align the source with the target orientation + + 3. Extract the root translation and normalize it to match the scale of the target\ + skeleton + + 4. Extract the global rotation from source state relative to source tpose and\ + re-apply the relative rotation to the target tpose to construct the global\ + rotation after retargetting + + 5. Combine the computed global rotation and the root translation from 3 and 4 to\ + complete the retargeting. + + 6. Make feet on the ground (global translation z) + + :param joint_mapping: a dictionary of that maps the joint node from the source skeleton to \ + the target skeleton + :type joint_mapping: Dict[str, str] + + :param source_tpose_local_rotation: the local rotation of the source skeleton + :type source_tpose_local_rotation: Tensor + + :param source_tpose_root_translation: the root translation of the source tpose + :type source_tpose_root_translation: np.ndarray + + :param target_skeleton_tree: the target skeleton tree + :type target_skeleton_tree: SkeletonTree + + :param target_tpose_local_rotation: the local rotation of the target skeleton + :type target_tpose_local_rotation: Tensor + + :param target_tpose_root_translation: the root translation of the target tpose + :type target_tpose_root_translation: Tensor + + :param rotation_to_target_skeleton: the rotation that needs to be applied to the source\ + skeleton to align with the target skeleton. Essentially the rotation is t_R_s, where t is\ + the frame of reference of the target skeleton and s is the frame of reference of the source\ + skeleton + :type rotation_to_target_skeleton: Tensor + :param scale_to_target_skeleton: the factor that needs to be multiplied from source\ + skeleton to target skeleton (unit in distance). For example, to go from `cm` to `m`, the \ + factor needs to be 0.01. + :type scale_to_target_skeleton: float + :rtype: SkeletonState + """ + + # STEP 0: Preprocess + source_tpose = SkeletonState.from_rotation_and_root_translation( + skeleton_tree=self.skeleton_tree, + r=source_tpose_local_rotation, + t=source_tpose_root_translation, + is_local=True, + ) + target_tpose = SkeletonState.from_rotation_and_root_translation( + skeleton_tree=target_skeleton_tree, + r=target_tpose_local_rotation, + t=target_tpose_root_translation, + is_local=True, + ) + + # STEP 1: Drop the irrelevant joints + pairwise_translation = self._get_pairwise_average_translation() + node_names = list(joint_mapping) + new_skeleton_tree = self.skeleton_tree.keep_nodes_by_names(node_names, pairwise_translation) + + # TODO: combine the following steps before STEP 3 + source_tpose = source_tpose._transfer_to(new_skeleton_tree) + source_state = self._transfer_to(new_skeleton_tree) + + source_tpose = source_tpose._remapped_to(joint_mapping, target_skeleton_tree) + source_state = source_state._remapped_to(joint_mapping, target_skeleton_tree) + + # STEP 2: Rotate the source to align with the target + new_local_rotation = source_tpose.local_rotation.clone() + new_local_rotation[..., 0, :] = quat_mul_norm(rotation_to_target_skeleton, source_tpose.local_rotation[..., 0, :]) + + source_tpose = SkeletonState.from_rotation_and_root_translation( + skeleton_tree=source_tpose.skeleton_tree, + r=new_local_rotation, + t=quat_rotate(rotation_to_target_skeleton, source_tpose.root_translation), + is_local=True, + ) + + new_local_rotation = source_state.local_rotation.clone() + new_local_rotation[..., 0, :] = quat_mul_norm(rotation_to_target_skeleton, source_state.local_rotation[..., 0, :]) + source_state = SkeletonState.from_rotation_and_root_translation( + skeleton_tree=source_state.skeleton_tree, + r=new_local_rotation, + t=quat_rotate(rotation_to_target_skeleton, source_state.root_translation), + is_local=True, + ) + + # STEP 3: Normalize to match the target scale + root_translation_diff = (source_state.root_translation - source_tpose.root_translation) * scale_to_target_skeleton + + # STEP 4: the global rotation from source state relative to source tpose and + # re-apply to the target + current_skeleton_tree = source_state.skeleton_tree + target_tpose_global_rotation = source_state.global_rotation[0, :].clone() + for current_index, name in enumerate(current_skeleton_tree): + if name in target_tpose.skeleton_tree: + target_tpose_global_rotation[current_index, :] = target_tpose.global_rotation[target_tpose.skeleton_tree.index(name), :] + + global_rotation_diff = quat_mul_norm(source_state.global_rotation, quat_inverse(source_tpose.global_rotation)) + new_global_rotation = quat_mul_norm(global_rotation_diff, target_tpose_global_rotation) + + # STEP 5: Putting 3 and 4 together + current_skeleton_tree = source_state.skeleton_tree + shape = source_state.global_rotation.shape[:-1] + shape = shape[:-1] + target_tpose.global_rotation.shape[-2:-1] + new_global_rotation_output = quat_identity(shape) + for current_index, name in enumerate(target_skeleton_tree): + while name not in current_skeleton_tree: + name = target_skeleton_tree.parent_of(name) + parent_index = current_skeleton_tree.index(name) + new_global_rotation_output[:, current_index, :] = new_global_rotation[:, parent_index, :] + + source_state = SkeletonState.from_rotation_and_root_translation( + skeleton_tree=target_skeleton_tree, + r=new_global_rotation_output, + t=target_tpose.root_translation + root_translation_diff, + is_local=False, + ).local_repr() + + return source_state + + def retarget_to_by_tpose( + self, + joint_mapping: Dict[str, str], + source_tpose: "SkeletonState", + target_tpose: "SkeletonState", + rotation_to_target_skeleton, + scale_to_target_skeleton: float, + ) -> "SkeletonState": + """ + Retarget the skeleton state to a target skeleton tree. This is a naive retarget + implementation with rough approximations. See the method `retarget_to()` for more information + + :param joint_mapping: a dictionary of that maps the joint node from the source skeleton to \ + the target skeleton + :type joint_mapping: Dict[str, str] + + :param source_tpose: t-pose of the source skeleton + :type source_tpose: SkeletonState + + :param target_tpose: t-pose of the target skeleton + :type target_tpose: SkeletonState + + :param rotation_to_target_skeleton: the rotation that needs to be applied to the source\ + skeleton to align with the target skeleton. Essentially the rotation is t_R_s, where t is\ + the frame of reference of the target skeleton and s is the frame of reference of the source\ + skeleton + :type rotation_to_target_skeleton: Tensor + :param scale_to_target_skeleton: the factor that needs to be multiplied from source\ + skeleton to target skeleton (unit in distance). For example, to go from `cm` to `m`, the \ + factor needs to be 0.01. + :type scale_to_target_skeleton: float + :rtype: SkeletonState + """ + assert (len(source_tpose.shape) == 0 and len(target_tpose.shape) == 0), "the retargeting script currently doesn't support vectorized operations" + return self.retarget_to( + joint_mapping, + source_tpose.local_rotation, + source_tpose.root_translation, + target_tpose.skeleton_tree, + target_tpose.local_rotation, + target_tpose.root_translation, + rotation_to_target_skeleton, + scale_to_target_skeleton, + ) + + +class SkeletonMotion(SkeletonState): + + def __init__(self, tensor_backend, skeleton_tree, is_local, fps, *args, **kwargs): + self._fps = fps + super().__init__(tensor_backend, skeleton_tree, is_local, *args, **kwargs) + + def clone(self): + return SkeletonMotion(self.tensor.clone(), self.skeleton_tree, self._is_local, self._fps) + + @property + def invariant_property(self): + return { + "skeleton_tree": self.skeleton_tree, + "is_local": self.is_local, + "fps": self.fps, + } + + @property + def global_velocity(self): + """ global velocity """ + curr_index = self.num_joints * 4 + 3 + return self.tensor[..., curr_index:curr_index + self.num_joints * 3].reshape(*(self.tensor.shape[:-1] + (self.num_joints, 3))) + + @property + def global_angular_velocity(self): + """ global angular velocity """ + curr_index = self.num_joints * 7 + 3 + return self.tensor[..., curr_index:curr_index + self.num_joints * 3].reshape(*(self.tensor.shape[:-1] + (self.num_joints, 3))) + + @property + def fps(self): + """ number of frames per second """ + return self._fps + + @property + def time_delta(self): + """ time between two adjacent frames """ + return 1.0 / self.fps + + @property + def global_root_velocity(self): + """ global root velocity """ + return self.global_velocity[..., 0, :] + + @property + def global_root_angular_velocity(self): + """ global root angular velocity """ + return self.global_angular_velocity[..., 0, :] + + @classmethod + def from_state_vector_and_velocity( + cls, + skeleton_tree, + state_vector, + global_velocity, + global_angular_velocity, + is_local, + fps, + ): + """ + Construct a skeleton motion from a skeleton state vector, global velocity and angular + velocity at each joint. + + :param skeleton_tree: the skeleton tree that the motion is based on + :type skeleton_tree: SkeletonTree + :param state_vector: the state vector from the skeleton state by `.tensor` + :type state_vector: Tensor + :param global_velocity: the global velocity at each joint + :type global_velocity: Tensor + :param global_angular_velocity: the global angular velocity at each joint + :type global_angular_velocity: Tensor + :param is_local: if the rotation ins the state vector is given in local frame + :type is_local: boolean + :param fps: number of frames per second + :type fps: int + + :rtype: SkeletonMotion + """ + state_shape = state_vector.shape[:-1] + v = global_velocity.reshape(*(state_shape + (-1,))) + av = global_angular_velocity.reshape(*(state_shape + (-1,))) + new_state_vector = torch.cat([state_vector, v, av], axis=-1) + return cls( + new_state_vector, + skeleton_tree=skeleton_tree, + is_local=is_local, + fps=fps, + ) + + @classmethod + def from_skeleton_state(cls: Type["SkeletonMotion"], skeleton_state: SkeletonState, fps: int): + """ + Construct a skeleton motion from a skeleton state. The velocities are estimated using second + order guassian filter along the last axis. The skeleton state must have at least .dim >= 1 + + :param skeleton_state: the skeleton state that the motion is based on + :type skeleton_state: SkeletonState + :param fps: number of frames per second + :type fps: int + + :rtype: SkeletonMotion + """ + assert (type(skeleton_state) == SkeletonState), "expected type of {}, got {}".format(SkeletonState, type(skeleton_state)) + global_velocity = SkeletonMotion._compute_velocity(p=skeleton_state.global_translation, time_delta=1 / fps) + global_angular_velocity = SkeletonMotion._compute_angular_velocity(r=skeleton_state.global_rotation, time_delta=1 / fps) + return cls.from_state_vector_and_velocity( + skeleton_tree=skeleton_state.skeleton_tree, + state_vector=skeleton_state.tensor, + global_velocity=global_velocity, + global_angular_velocity=global_angular_velocity, + is_local=skeleton_state.is_local, + fps=fps, + ) + + @staticmethod + def _to_state_vector(rot, rt, vel, avel): + state_shape = rot.shape[:-2] + skeleton_state_v = SkeletonState._to_state_vector(rot, rt) + v = vel.reshape(*(state_shape + (-1,))) + av = avel.reshape(*(state_shape + (-1,))) + skeleton_motion_v = torch.cat([skeleton_state_v, v, av], axis=-1) + return skeleton_motion_v + + @classmethod + def from_dict(cls: Type["SkeletonMotion"], dict_repr: OrderedDict, *args, **kwargs) -> "SkeletonMotion": + rot = TensorUtils.from_dict(dict_repr["rotation"], *args, **kwargs) + rt = TensorUtils.from_dict(dict_repr["root_translation"], *args, **kwargs) + vel = TensorUtils.from_dict(dict_repr["global_velocity"], *args, **kwargs) + avel = TensorUtils.from_dict(dict_repr["global_angular_velocity"], *args, **kwargs) + return cls( + SkeletonMotion._to_state_vector(rot, rt, vel, avel), + skeleton_tree=SkeletonTree.from_dict(dict_repr["skeleton_tree"], *args, **kwargs), + is_local=dict_repr["is_local"], + fps=dict_repr["fps"], + ) + + def to_dict(self) -> OrderedDict: + return OrderedDict([ + ("rotation", tensor_to_dict(self.rotation)), + ("root_translation", tensor_to_dict(self.root_translation)), + ("global_velocity", tensor_to_dict(self.global_velocity)), + ("global_angular_velocity", tensor_to_dict(self.global_angular_velocity)), + ("skeleton_tree", self.skeleton_tree.to_dict()), + ("is_local", self.is_local), + ("fps", self.fps), + ]) + + @classmethod + def from_fbx( + cls: Type["SkeletonMotion"], + fbx_file_path, + fbx_configs, + skeleton_tree=None, + is_local=True, + fps=120, + root_joint="", + root_trans_index=0, + *args, + **kwargs, + ) -> "SkeletonMotion": + """ + Construct a skeleton motion from a fbx file (TODO - generalize this). If the skeleton tree + is not given, it will use the first frame of the mocap to construct the skeleton tree. + + :param fbx_file_path: the path of the fbx file + :type fbx_file_path: string + :param fbx_configs: the configuration in terms of {"tmp_path": ..., "fbx_py27_path": ...} + :type fbx_configs: dict + :param skeleton_tree: the optional skeleton tree that the rotation will be applied to + :type skeleton_tree: SkeletonTree, optional + :param is_local: the state vector uses local or global rotation as the representation + :type is_local: bool, optional, default=True + :rtype: SkeletonMotion + """ + joint_names, joint_parents, transforms, fps = fbx_to_array(fbx_file_path, fbx_configs, root_joint, fps) + # swap the last two axis to match the convention + local_transform = euclidean_to_transform(transformation_matrix=torch.from_numpy(np.swapaxes(np.array(transforms), -1, -2),).float()) + local_rotation = transform_rotation(local_transform) + root_translation = transform_translation(local_transform)[..., root_trans_index, :] + joint_parents = torch.from_numpy(np.array(joint_parents)).int() + + if skeleton_tree is None: + local_translation = transform_translation(local_transform).reshape(-1, len(joint_parents), 3)[0] + skeleton_tree = SkeletonTree(joint_names, joint_parents, local_translation) + skeleton_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, r=local_rotation, t=root_translation, is_local=True) + if not is_local: + skeleton_state = skeleton_state.global_repr() + return cls.from_skeleton_state(skeleton_state=skeleton_state, fps=fps) + + @staticmethod + def _compute_velocity(p, time_delta, guassian_filter=True): + velocity = np.gradient(p.numpy(), axis=-3) / time_delta + if guassian_filter: + velocity = torch.from_numpy(filters.gaussian_filter1d(velocity, 2, axis=-3, mode="nearest")).to(p) + else: + velocity = torch.from_numpy(velocity).to(p) + + return velocity + + @staticmethod + def _compute_angular_velocity(r, time_delta: float, guassian_filter=True): + # assume the second last dimension is the time axis + diff_quat_data = quat_identity_like(r).to(r) + diff_quat_data[..., :-1, :, :] = quat_mul_norm(r[..., 1:, :, :], quat_inverse(r[..., :-1, :, :])) + diff_angle, diff_axis = quat_angle_axis(diff_quat_data) + angular_velocity = diff_axis * diff_angle.unsqueeze(-1) / time_delta + if guassian_filter: + angular_velocity = torch.from_numpy(filters.gaussian_filter1d(angular_velocity.numpy(), 2, axis=-3, mode="nearest"),) + return angular_velocity + + def crop(self, start: int, end: int, fps: Optional[int] = None): + """ + Crop the motion along its last axis. This is equivalent to performing a slicing on the + object with [..., start: end: skip_every] where skip_every = old_fps / fps. Note that the + new fps provided must be a factor of the original fps. + + :param start: the beginning frame index + :type start: int + :param end: the ending frame index + :type end: int + :param fps: number of frames per second in the output (if not given the original fps will be used) + :type fps: int, optional + :rtype: SkeletonMotion + """ + if fps is None: + new_fps = int(self.fps) + old_fps = int(self.fps) + else: + new_fps = int(fps) + old_fps = int(self.fps) + assert old_fps % fps == 0, ("the resampling doesn't support fps with non-integer division " + "from the original fps: {} => {}".format(old_fps, fps)) + skip_every = old_fps // new_fps + s = slice(start, end, skip_every) + z = self[..., s] + + rot = z.local_rotation if z.is_local else z.global_rotation + rt = z.root_translation + vel = z.global_velocity + avel = z.global_angular_velocity + return SkeletonMotion( + SkeletonMotion._to_state_vector(rot, rt, vel, avel), + skeleton_tree=z.skeleton_tree, + is_local=z.is_local, + fps=new_fps, + ) + + def retarget_to( + self, + joint_mapping: Dict[str, str], + source_tpose_local_rotation, + source_tpose_root_translation: np.ndarray, + target_skeleton_tree: "SkeletonTree", + target_tpose_local_rotation, + target_tpose_root_translation: np.ndarray, + rotation_to_target_skeleton, + scale_to_target_skeleton: float, + z_up: bool = True, + ) -> "SkeletonMotion": + """ + Same as the one in :class:`SkeletonState`. This method discards all velocity information before + retargeting and re-estimate the velocity after the retargeting. The same fps is used in the + new retargetted motion. + + :param joint_mapping: a dictionary of that maps the joint node from the source skeleton to \ + the target skeleton + :type joint_mapping: Dict[str, str] + + :param source_tpose_local_rotation: the local rotation of the source skeleton + :type source_tpose_local_rotation: Tensor + + :param source_tpose_root_translation: the root translation of the source tpose + :type source_tpose_root_translation: np.ndarray + + :param target_skeleton_tree: the target skeleton tree + :type target_skeleton_tree: SkeletonTree + + :param target_tpose_local_rotation: the local rotation of the target skeleton + :type target_tpose_local_rotation: Tensor + + :param target_tpose_root_translation: the root translation of the target tpose + :type target_tpose_root_translation: Tensor + + :param rotation_to_target_skeleton: the rotation that needs to be applied to the source\ + skeleton to align with the target skeleton. Essentially the rotation is t_R_s, where t is\ + the frame of reference of the target skeleton and s is the frame of reference of the source\ + skeleton + :type rotation_to_target_skeleton: Tensor + :param scale_to_target_skeleton: the factor that needs to be multiplied from source\ + skeleton to target skeleton (unit in distance). For example, to go from `cm` to `m`, the \ + factor needs to be 0.01. + :type scale_to_target_skeleton: float + :rtype: SkeletonMotion + """ + return SkeletonMotion.from_skeleton_state( + super().retarget_to( + joint_mapping, + source_tpose_local_rotation, + source_tpose_root_translation, + target_skeleton_tree, + target_tpose_local_rotation, + target_tpose_root_translation, + rotation_to_target_skeleton, + scale_to_target_skeleton, + z_up, + ), + self.fps, + ) + + def retarget_to_by_tpose( + self, + joint_mapping: Dict[str, str], + source_tpose: "SkeletonState", + target_tpose: "SkeletonState", + rotation_to_target_skeleton, + scale_to_target_skeleton: float, + z_up: bool = True, + ) -> "SkeletonMotion": + """ + Same as the one in :class:`SkeletonState`. This method discards all velocity information before + retargeting and re-estimate the velocity after the retargeting. The same fps is used in the + new retargetted motion. + + :param joint_mapping: a dictionary of that maps the joint node from the source skeleton to \ + the target skeleton + :type joint_mapping: Dict[str, str] + + :param source_tpose: t-pose of the source skeleton + :type source_tpose: SkeletonState + + :param target_tpose: t-pose of the target skeleton + :type target_tpose: SkeletonState + + :param rotation_to_target_skeleton: the rotation that needs to be applied to the source\ + skeleton to align with the target skeleton. Essentially the rotation is t_R_s, where t is\ + the frame of reference of the target skeleton and s is the frame of reference of the source\ + skeleton + :type rotation_to_target_skeleton: Tensor + :param scale_to_target_skeleton: the factor that needs to be multiplied from source\ + skeleton to target skeleton (unit in distance). For example, to go from `cm` to `m`, the \ + factor needs to be 0.01. + :type scale_to_target_skeleton: float + :rtype: SkeletonMotion + """ + return self.retarget_to( + joint_mapping, + source_tpose.local_rotation, + source_tpose.root_translation, + target_tpose.skeleton_tree, + target_tpose.local_rotation, + target_tpose.root_translation, + rotation_to_target_skeleton, + scale_to_target_skeleton, + z_up, + ) diff --git a/poselib/poselib/skeleton/tests/__init__.py b/poselib/poselib/skeleton/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/poselib/poselib/skeleton/tests/ant.xml b/poselib/poselib/skeleton/tests/ant.xml new file mode 100644 index 0000000..311d96f --- /dev/null +++ b/poselib/poselib/skeleton/tests/ant.xml @@ -0,0 +1,71 @@ + + + diff --git a/poselib/poselib/skeleton/tests/test_skeleton.py b/poselib/poselib/skeleton/tests/test_skeleton.py new file mode 100644 index 0000000..aa9edc3 --- /dev/null +++ b/poselib/poselib/skeleton/tests/test_skeleton.py @@ -0,0 +1,132 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from ...core import * +from ..skeleton3d import SkeletonTree, SkeletonState, SkeletonMotion + +import numpy as np +import torch + +from ...visualization.common import ( + plot_skeleton_state, + plot_skeleton_motion_interactive, +) + +from ...visualization.plt_plotter import Matplotlib3DPlotter +from ...visualization.skeleton_plotter_tasks import ( + Draw3DSkeletonMotion, + Draw3DSkeletonState, +) + + +def test_skel_tree(): + skel_tree = SkeletonTree.from_mjcf( + "/home/serfcx/DL_Animation/rl_mimic/data/skeletons/humanoid_mimic_mod_2_noind.xml", + backend="pytorch", + ) + skel_tree_rec = SkeletonTree.from_dict(skel_tree.to_dict(), backend="pytorch") + # assert skel_tree.to_str() == skel_tree_rec.to_str() + print(skel_tree.node_names) + print(skel_tree.local_translation) + print(skel_tree.parent_indices) + skel_state = SkeletonState.zero_pose(skeleton_tree=skel_tree) + plot_skeleton_state(task_name="draw_skeleton", skeleton_state=skel_state) + skel_state = skel_state.drop_nodes_by_names(["right_hip", "left_hip"]) + plot_skeleton_state(task_name="draw_skeleton", skeleton_state=skel_state) + + +def test_skel_motion(): + skel_motion = SkeletonMotion.from_file( + "/tmp/tmp.npy", backend="pytorch", load_context=True + ) + + plot_skeleton_motion_interactive(skel_motion) + + +def test_grad(): + source_motion = SkeletonMotion.from_file( + "c:\\Users\\bmatusch\\carbmimic\\data\\motions\\JogFlatTerrain_01_ase.npy", + backend="pytorch", + device="cuda:0", + ) + source_tpose = SkeletonState.from_file( + "c:\\Users\\bmatusch\\carbmimic\\data\\skeletons\\fox_tpose.npy", + backend="pytorch", + device="cuda:0", + ) + + target_tpose = SkeletonState.from_file( + "c:\\Users\\bmatusch\\carbmimic\\data\\skeletons\\flex_tpose.npy", + backend="pytorch", + device="cuda:0", + ) + target_skeleton_tree = target_tpose.skeleton_tree + + joint_mapping = { + "upArm_r": "right_shoulder", + "upArm_l": "left_shoulder", + "loArm_r": "right_elbow", + "loArm_l": "left_elbow", + "upLeg_r": "right_hip", + "upLeg_l": "left_hip", + "loLeg_r": "right_knee", + "loLeg_l": "left_knee", + "foot_r": "right_ankle", + "foot_l": "left_ankle", + "hips": "pelvis", + "neckA": "neck", + "spineA": "abdomen", + } + + rotation_to_target_skeleton = quat_from_angle_axis( + angle=torch.tensor(90.0).float(), + axis=torch.tensor([1, 0, 0]).float(), + degree=True, + ) + + target_motion = source_motion.retarget_to( + joint_mapping=joint_mapping, + source_tpose_local_rotation=source_tpose.local_rotation, + source_tpose_root_translation=source_tpose.root_translation, + target_skeleton_tree=target_skeleton_tree, + target_tpose_local_rotation=target_tpose.local_rotation, + target_tpose_root_translation=target_tpose.root_translation, + rotation_to_target_skeleton=rotation_to_target_skeleton, + scale_to_target_skeleton=0.01, + ) + + target_state = SkeletonState( + target_motion.tensor[800, :], + target_motion.skeleton_tree, + target_motion.is_local, + ) + + skeleton_tree = target_state.skeleton_tree + root_translation = target_state.root_translation + global_translation = target_state.global_translation + + q = np.zeros((len(skeleton_tree), 4), dtype=np.float32) + q[..., 3] = 1.0 + q = torch.from_numpy(q) + max_its = 10000 + + task = Draw3DSkeletonState(task_name="", skeleton_state=target_state) + plotter = Matplotlib3DPlotter(task) + + for i in range(max_its): + r = quat_normalize(q) + s = SkeletonState.from_rotation_and_root_translation( + skeleton_tree, r=r, t=root_translation, is_local=True + ) + print(" quat norm: {}".format(q.norm(p=2, dim=-1).mean().numpy())) + + task.update(s) + plotter.update() + plotter.show() + + +test_grad() \ No newline at end of file diff --git a/poselib/poselib/skeleton/tests/transfer_npy.py b/poselib/poselib/skeleton/tests/transfer_npy.py new file mode 100644 index 0000000..dfcd96b --- /dev/null +++ b/poselib/poselib/skeleton/tests/transfer_npy.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import numpy as np +from ...core import Tensor, SO3, Quaternion, Vector3D +from ..skeleton3d import SkeletonTree, SkeletonState, SkeletonMotion + +tpose = np.load( + "/home/serfcx/DL_Animation/rl_mimic/data/skeletons/flex_tpose.npy" +).item() + +local_rotation = SO3.from_numpy(tpose["local_rotation"], dtype="float32") +root_translation = Vector3D.from_numpy(tpose["root_translation"], dtype="float32") +skeleton_tree = tpose["skeleton_tree"] +parent_indices = Tensor.from_numpy(skeleton_tree["parent_indices"], dtype="int32") +local_translation = Vector3D.from_numpy( + skeleton_tree["local_translation"], dtype="float32" +) +node_names = skeleton_tree["node_names"] +skeleton_tree = SkeletonTree(node_names, parent_indices, local_translation) +skeleton_state = SkeletonState.from_rotation_and_root_translation( + skeleton_tree=skeleton_tree, r=local_rotation, t=root_translation, is_local=True +) + +skeleton_state.to_file( + "/home/serfcx/DL_Animation/rl_mimic/data/skeletons/flex_tpose_new.npy" +) diff --git a/poselib/poselib/visualization/__init__.py b/poselib/poselib/visualization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/poselib/poselib/visualization/common.py b/poselib/poselib/visualization/common.py new file mode 100644 index 0000000..6b6f9ae --- /dev/null +++ b/poselib/poselib/visualization/common.py @@ -0,0 +1,189 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os + +from ..core import logger +from .plt_plotter import Matplotlib3DPlotter +from .skeleton_plotter_tasks import Draw3DSkeletonMotion, Draw3DSkeletonState + + +def plot_skeleton_state(skeleton_state, task_name=""): + """ + Visualize a skeleton state + + :param skeleton_state: + :param task_name: + :type skeleton_state: SkeletonState + :type task_name: string, optional + """ + logger.info("plotting {}".format(task_name)) + task = Draw3DSkeletonState(task_name=task_name, skeleton_state=skeleton_state) + plotter = Matplotlib3DPlotter(task) + plotter.show() + + +def plot_skeleton_states(skeleton_state, skip_n=1, task_name=""): + """ + Visualize a sequence of skeleton state. The dimension of the skeleton state must be 1 + + :param skeleton_state: + :param task_name: + :type skeleton_state: SkeletonState + :type task_name: string, optional + """ + logger.info("plotting {} motion".format(task_name)) + assert len(skeleton_state.shape) == 1, "the state must have only one dimension" + task = Draw3DSkeletonState(task_name=task_name, skeleton_state=skeleton_state[0]) + plotter = Matplotlib3DPlotter(task) + for frame_id in range(skeleton_state.shape[0]): + if frame_id % skip_n != 0: + continue + task.update(skeleton_state[frame_id]) + plotter.update() + plotter.show() + + +def plot_skeleton_motion(skeleton_motion, skip_n=1, task_name=""): + """ + Visualize a skeleton motion along its first dimension. + + :param skeleton_motion: + :param task_name: + :type skeleton_motion: SkeletonMotion + :type task_name: string, optional + """ + logger.info("plotting {} motion".format(task_name)) + task = Draw3DSkeletonMotion( + task_name=task_name, skeleton_motion=skeleton_motion, frame_index=0 + ) + plotter = Matplotlib3DPlotter(task) + for frame_id in range(len(skeleton_motion)): + if frame_id % skip_n != 0: + continue + task.update(frame_id) + plotter.update() + plotter.show() + + +def plot_skeleton_motion_interactive_base(skeleton_motion, task_name=""): + class PlotParams: + def __init__(self, total_num_frames): + self.current_frame = 0 + self.playing = False + self.looping = False + self.confirmed = False + self.playback_speed = 4 + self.total_num_frames = total_num_frames + + def sync(self, other): + self.current_frame = other.current_frame + self.playing = other.playing + self.looping = other.current_frame + self.confirmed = other.confirmed + self.playback_speed = other.playback_speed + self.total_num_frames = other.total_num_frames + + task = Draw3DSkeletonMotion( + task_name=task_name, skeleton_motion=skeleton_motion, frame_index=0 + ) + plotter = Matplotlib3DPlotter(task) + + plot_params = PlotParams(total_num_frames=len(skeleton_motion)) + print("Entered interactive plot - press 'n' to quit, 'h' for a list of commands") + + def press(event): + if event.key == "x": + plot_params.playing = not plot_params.playing + elif event.key == "z": + plot_params.current_frame = plot_params.current_frame - 1 + elif event.key == "c": + plot_params.current_frame = plot_params.current_frame + 1 + elif event.key == "a": + plot_params.current_frame = plot_params.current_frame - 20 + elif event.key == "d": + plot_params.current_frame = plot_params.current_frame + 20 + elif event.key == "w": + plot_params.looping = not plot_params.looping + print("Looping: {}".format(plot_params.looping)) + elif event.key == "v": + plot_params.playback_speed *= 2 + print("playback speed: {}".format(plot_params.playback_speed)) + elif event.key == "b": + if plot_params.playback_speed != 1: + plot_params.playback_speed //= 2 + print("playback speed: {}".format(plot_params.playback_speed)) + elif event.key == "n": + plot_params.confirmed = True + elif event.key == "h": + rows, columns = os.popen("stty size", "r").read().split() + columns = int(columns) + print("=" * columns) + print("x: play/pause") + print("z: previous frame") + print("c: next frame") + print("a: jump 10 frames back") + print("d: jump 10 frames forward") + print("w: looping/non-looping") + print("v: double speed (this can be applied multiple times)") + print("b: half speed (this can be applied multiple times)") + print("n: quit") + print("h: help") + print("=" * columns) + + print( + 'current frame index: {}/{} (press "n" to quit)'.format( + plot_params.current_frame, plot_params.total_num_frames - 1 + ) + ) + + plotter.fig.canvas.mpl_connect("key_press_event", press) + while True: + reset_trail = False + if plot_params.confirmed: + break + if plot_params.playing: + plot_params.current_frame += plot_params.playback_speed + if plot_params.current_frame >= plot_params.total_num_frames: + if plot_params.looping: + plot_params.current_frame %= plot_params.total_num_frames + reset_trail = True + else: + plot_params.current_frame = plot_params.total_num_frames - 1 + if plot_params.current_frame < 0: + if plot_params.looping: + plot_params.current_frame %= plot_params.total_num_frames + reset_trail = True + else: + plot_params.current_frame = 0 + yield plot_params + task.update(plot_params.current_frame, reset_trail) + plotter.update() + + +def plot_skeleton_motion_interactive(skeleton_motion, task_name=""): + """ + Visualize a skeleton motion along its first dimension interactively. + + :param skeleton_motion: + :param task_name: + :type skeleton_motion: SkeletonMotion + :type task_name: string, optional + """ + for _ in plot_skeleton_motion_interactive_base(skeleton_motion, task_name): + pass + + +def plot_skeleton_motion_interactive_multiple(*callables, sync=True): + for _ in zip(*callables): + if sync: + for p1, p2 in zip(_[:-1], _[1:]): + p2.sync(p1) + + +# def plot_skeleton_motion_interactive_multiple_same(skeleton_motions, task_name=""): + diff --git a/poselib/poselib/visualization/core.py b/poselib/poselib/visualization/core.py new file mode 100644 index 0000000..3c7a176 --- /dev/null +++ b/poselib/poselib/visualization/core.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +""" +The base abstract classes for plotter and the plotting tasks. It describes how the plotter +deals with the tasks in the general cases +""" +from typing import List + + +class BasePlotterTask(object): + _task_name: str # unique name of the task + _task_type: str # type of the task is used to identify which callable + + def __init__(self, task_name: str, task_type: str) -> None: + self._task_name = task_name + self._task_type = task_type + + @property + def task_name(self): + return self._task_name + + @property + def task_type(self): + return self._task_type + + def get_scoped_name(self, name): + return self._task_name + "/" + name + + def __iter__(self): + """Should override this function to return a list of task primitives + """ + raise NotImplementedError + + +class BasePlotterTasks(object): + def __init__(self, tasks) -> None: + self._tasks = tasks + + def __iter__(self): + for task in self._tasks: + yield from task + + +class BasePlotter(object): + """An abstract plotter which deals with a plotting task. The children class needs to implement + the functions to create/update the objects according to the task given + """ + + _task_primitives: List[BasePlotterTask] + + def __init__(self, task: BasePlotterTask) -> None: + self._task_primitives = [] + self.create(task) + + @property + def task_primitives(self): + return self._task_primitives + + def create(self, task: BasePlotterTask) -> None: + """Create more task primitives from a task for the plotter""" + new_task_primitives = list(task) # get all task primitives + self._task_primitives += new_task_primitives # append them + self._create_impl(new_task_primitives) + + def update(self) -> None: + """Update the plotter for any updates in the task primitives""" + self._update_impl(self._task_primitives) + + def _update_impl(self, task_list: List[BasePlotterTask]) -> None: + raise NotImplementedError + + def _create_impl(self, task_list: List[BasePlotterTask]) -> None: + raise NotImplementedError diff --git a/poselib/poselib/visualization/plt_plotter.py b/poselib/poselib/visualization/plt_plotter.py new file mode 100644 index 0000000..0984020 --- /dev/null +++ b/poselib/poselib/visualization/plt_plotter.py @@ -0,0 +1,402 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +""" +The matplotlib plotter implementation for all the primitive tasks (in our case: lines and +dots) +""" +from typing import Any, Callable, Dict, List + +import matplotlib.pyplot as plt +import mpl_toolkits.mplot3d.axes3d as p3 + +import numpy as np + +from .core import BasePlotter, BasePlotterTask + + +class Matplotlib2DPlotter(BasePlotter): + _fig: plt.figure # plt figure + _ax: plt.axis # plt axis + # stores artist objects for each task (task name as the key) + _artist_cache: Dict[str, Any] + # callables for each task primitives + _create_impl_callables: Dict[str, Callable] + _update_impl_callables: Dict[str, Callable] + + def __init__(self, task: "BasePlotterTask") -> None: + fig, ax = plt.subplots() + self._fig = fig + self._ax = ax + self._artist_cache = {} + + self._create_impl_callables = { + "Draw2DLines": self._lines_create_impl, + "Draw2DDots": self._dots_create_impl, + "Draw2DTrail": self._trail_create_impl, + } + self._update_impl_callables = { + "Draw2DLines": self._lines_update_impl, + "Draw2DDots": self._dots_update_impl, + "Draw2DTrail": self._trail_update_impl, + } + self._init_lim() + super().__init__(task) + + @property + def ax(self): + return self._ax + + @property + def fig(self): + return self._fig + + def show(self): + plt.show() + + def _min(self, x, y): + if x is None: + return y + if y is None: + return x + return min(x, y) + + def _max(self, x, y): + if x is None: + return y + if y is None: + return x + return max(x, y) + + def _init_lim(self): + self._curr_x_min = None + self._curr_y_min = None + self._curr_x_max = None + self._curr_y_max = None + + def _update_lim(self, xs, ys): + self._curr_x_min = self._min(np.min(xs), self._curr_x_min) + self._curr_y_min = self._min(np.min(ys), self._curr_y_min) + self._curr_x_max = self._max(np.max(xs), self._curr_x_max) + self._curr_y_max = self._max(np.max(ys), self._curr_y_max) + + def _set_lim(self): + if not ( + self._curr_x_min is None + or self._curr_x_max is None + or self._curr_y_min is None + or self._curr_y_max is None + ): + self._ax.set_xlim(self._curr_x_min, self._curr_x_max) + self._ax.set_ylim(self._curr_y_min, self._curr_y_max) + self._init_lim() + + @staticmethod + def _lines_extract_xy_impl(index, lines_task): + return lines_task[index, :, 0], lines_task[index, :, 1] + + @staticmethod + def _trail_extract_xy_impl(index, trail_task): + return (trail_task[index : index + 2, 0], trail_task[index : index + 2, 1]) + + def _lines_create_impl(self, lines_task): + color = lines_task.color + self._artist_cache[lines_task.task_name] = [ + self._ax.plot( + *Matplotlib2DPlotter._lines_extract_xy_impl(i, lines_task), + color=color, + linewidth=lines_task.line_width, + alpha=lines_task.alpha + )[0] + for i in range(len(lines_task)) + ] + + def _lines_update_impl(self, lines_task): + lines_artists = self._artist_cache[lines_task.task_name] + for i in range(len(lines_task)): + artist = lines_artists[i] + xs, ys = Matplotlib2DPlotter._lines_extract_xy_impl(i, lines_task) + artist.set_data(xs, ys) + if lines_task.influence_lim: + self._update_lim(xs, ys) + + def _dots_create_impl(self, dots_task): + color = dots_task.color + self._artist_cache[dots_task.task_name] = self._ax.plot( + dots_task[:, 0], + dots_task[:, 1], + c=color, + linestyle="", + marker=".", + markersize=dots_task.marker_size, + alpha=dots_task.alpha, + )[0] + + def _dots_update_impl(self, dots_task): + dots_artist = self._artist_cache[dots_task.task_name] + dots_artist.set_data(dots_task[:, 0], dots_task[:, 1]) + if dots_task.influence_lim: + self._update_lim(dots_task[:, 0], dots_task[:, 1]) + + def _trail_create_impl(self, trail_task): + color = trail_task.color + trail_length = len(trail_task) - 1 + self._artist_cache[trail_task.task_name] = [ + self._ax.plot( + *Matplotlib2DPlotter._trail_extract_xy_impl(i, trail_task), + color=trail_task.color, + linewidth=trail_task.line_width, + alpha=trail_task.alpha * (1.0 - i / (trail_length - 1)) + )[0] + for i in range(trail_length) + ] + + def _trail_update_impl(self, trail_task): + trails_artists = self._artist_cache[trail_task.task_name] + for i in range(len(trail_task) - 1): + artist = trails_artists[i] + xs, ys = Matplotlib2DPlotter._trail_extract_xy_impl(i, trail_task) + artist.set_data(xs, ys) + if trail_task.influence_lim: + self._update_lim(xs, ys) + + def _create_impl(self, task_list): + for task in task_list: + self._create_impl_callables[task.task_type](task) + self._draw() + + def _update_impl(self, task_list): + for task in task_list: + self._update_impl_callables[task.task_type](task) + self._draw() + + def _set_aspect_equal_2d(self, zero_centered=True): + xlim = self._ax.get_xlim() + ylim = self._ax.get_ylim() + + if not zero_centered: + xmean = np.mean(xlim) + ymean = np.mean(ylim) + else: + xmean = 0 + ymean = 0 + + plot_radius = max( + [ + abs(lim - mean_) + for lims, mean_ in ((xlim, xmean), (ylim, ymean)) + for lim in lims + ] + ) + + self._ax.set_xlim([xmean - plot_radius, xmean + plot_radius]) + self._ax.set_ylim([ymean - plot_radius, ymean + plot_radius]) + + def _draw(self): + self._set_lim() + self._set_aspect_equal_2d() + self._fig.canvas.draw() + self._fig.canvas.flush_events() + plt.pause(0.00001) + + +class Matplotlib3DPlotter(BasePlotter): + _fig: plt.figure # plt figure + _ax: p3.Axes3D # plt 3d axis + # stores artist objects for each task (task name as the key) + _artist_cache: Dict[str, Any] + # callables for each task primitives + _create_impl_callables: Dict[str, Callable] + _update_impl_callables: Dict[str, Callable] + + def __init__(self, task: "BasePlotterTask") -> None: + self._fig = plt.figure() + self._ax = p3.Axes3D(self._fig) + self._artist_cache = {} + + self._create_impl_callables = { + "Draw3DLines": self._lines_create_impl, + "Draw3DDots": self._dots_create_impl, + "Draw3DTrail": self._trail_create_impl, + } + self._update_impl_callables = { + "Draw3DLines": self._lines_update_impl, + "Draw3DDots": self._dots_update_impl, + "Draw3DTrail": self._trail_update_impl, + } + self._init_lim() + super().__init__(task) + + @property + def ax(self): + return self._ax + + @property + def fig(self): + return self._fig + + def show(self): + plt.show() + + def _min(self, x, y): + if x is None: + return y + if y is None: + return x + return min(x, y) + + def _max(self, x, y): + if x is None: + return y + if y is None: + return x + return max(x, y) + + def _init_lim(self): + self._curr_x_min = None + self._curr_y_min = None + self._curr_z_min = None + self._curr_x_max = None + self._curr_y_max = None + self._curr_z_max = None + + def _update_lim(self, xs, ys, zs): + self._curr_x_min = self._min(np.min(xs), self._curr_x_min) + self._curr_y_min = self._min(np.min(ys), self._curr_y_min) + self._curr_z_min = self._min(np.min(zs), self._curr_z_min) + self._curr_x_max = self._max(np.max(xs), self._curr_x_max) + self._curr_y_max = self._max(np.max(ys), self._curr_y_max) + self._curr_z_max = self._max(np.max(zs), self._curr_z_max) + + def _set_lim(self): + if not ( + self._curr_x_min is None + or self._curr_x_max is None + or self._curr_y_min is None + or self._curr_y_max is None + or self._curr_z_min is None + or self._curr_z_max is None + ): + self._ax.set_xlim3d(self._curr_x_min, self._curr_x_max) + self._ax.set_ylim3d(self._curr_y_min, self._curr_y_max) + self._ax.set_zlim3d(self._curr_z_min, self._curr_z_max) + self._init_lim() + + @staticmethod + def _lines_extract_xyz_impl(index, lines_task): + return lines_task[index, :, 0], lines_task[index, :, 1], lines_task[index, :, 2] + + @staticmethod + def _trail_extract_xyz_impl(index, trail_task): + return ( + trail_task[index : index + 2, 0], + trail_task[index : index + 2, 1], + trail_task[index : index + 2, 2], + ) + + def _lines_create_impl(self, lines_task): + color = lines_task.color + self._artist_cache[lines_task.task_name] = [ + self._ax.plot( + *Matplotlib3DPlotter._lines_extract_xyz_impl(i, lines_task), + color=color, + linewidth=lines_task.line_width, + alpha=lines_task.alpha + )[0] + for i in range(len(lines_task)) + ] + + def _lines_update_impl(self, lines_task): + lines_artists = self._artist_cache[lines_task.task_name] + for i in range(len(lines_task)): + artist = lines_artists[i] + xs, ys, zs = Matplotlib3DPlotter._lines_extract_xyz_impl(i, lines_task) + artist.set_data(xs, ys) + artist.set_3d_properties(zs) + if lines_task.influence_lim: + self._update_lim(xs, ys, zs) + + def _dots_create_impl(self, dots_task): + color = dots_task.color + self._artist_cache[dots_task.task_name] = self._ax.plot( + dots_task[:, 0], + dots_task[:, 1], + dots_task[:, 2], + c=color, + linestyle="", + marker=".", + markersize=dots_task.marker_size, + alpha=dots_task.alpha, + )[0] + + def _dots_update_impl(self, dots_task): + dots_artist = self._artist_cache[dots_task.task_name] + dots_artist.set_data(dots_task[:, 0], dots_task[:, 1]) + dots_artist.set_3d_properties(dots_task[:, 2]) + if dots_task.influence_lim: + self._update_lim(dots_task[:, 0], dots_task[:, 1], dots_task[:, 2]) + + def _trail_create_impl(self, trail_task): + color = trail_task.color + trail_length = len(trail_task) - 1 + self._artist_cache[trail_task.task_name] = [ + self._ax.plot( + *Matplotlib3DPlotter._trail_extract_xyz_impl(i, trail_task), + color=trail_task.color, + linewidth=trail_task.line_width, + alpha=trail_task.alpha * (1.0 - i / (trail_length - 1)) + )[0] + for i in range(trail_length) + ] + + def _trail_update_impl(self, trail_task): + trails_artists = self._artist_cache[trail_task.task_name] + for i in range(len(trail_task) - 1): + artist = trails_artists[i] + xs, ys, zs = Matplotlib3DPlotter._trail_extract_xyz_impl(i, trail_task) + artist.set_data(xs, ys) + artist.set_3d_properties(zs) + if trail_task.influence_lim: + self._update_lim(xs, ys, zs) + + def _create_impl(self, task_list): + for task in task_list: + self._create_impl_callables[task.task_type](task) + self._draw() + + def _update_impl(self, task_list): + for task in task_list: + self._update_impl_callables[task.task_type](task) + self._draw() + + def _set_aspect_equal_3d(self): + xlim = self._ax.get_xlim3d() + ylim = self._ax.get_ylim3d() + zlim = self._ax.get_zlim3d() + + xmean = np.mean(xlim) + ymean = np.mean(ylim) + zmean = np.mean(zlim) + + plot_radius = max( + [ + abs(lim - mean_) + for lims, mean_ in ((xlim, xmean), (ylim, ymean), (zlim, zmean)) + for lim in lims + ] + ) + + self._ax.set_xlim3d([xmean - plot_radius, xmean + plot_radius]) + self._ax.set_ylim3d([ymean - plot_radius, ymean + plot_radius]) + self._ax.set_zlim3d([zmean - plot_radius, zmean + plot_radius]) + + def _draw(self): + self._set_lim() + self._set_aspect_equal_3d() + self._fig.canvas.draw() + self._fig.canvas.flush_events() + plt.pause(0.00001) diff --git a/poselib/poselib/visualization/simple_plotter_tasks.py b/poselib/poselib/visualization/simple_plotter_tasks.py new file mode 100644 index 0000000..fec4c88 --- /dev/null +++ b/poselib/poselib/visualization/simple_plotter_tasks.py @@ -0,0 +1,192 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +""" +This is where all the task primitives are defined +""" +import numpy as np + +from .core import BasePlotterTask + + +class DrawXDLines(BasePlotterTask): + _lines: np.ndarray + _color: str + _line_width: int + _alpha: float + _influence_lim: bool + + def __init__( + self, + task_name: str, + lines: np.ndarray, + color: str = "blue", + line_width: int = 2, + alpha: float = 1.0, + influence_lim: bool = True, + ) -> None: + super().__init__(task_name=task_name, task_type=self.__class__.__name__) + self._color = color + self._line_width = line_width + self._alpha = alpha + self._influence_lim = influence_lim + self.update(lines) + + @property + def influence_lim(self) -> bool: + return self._influence_lim + + @property + def raw_data(self): + return self._lines + + @property + def color(self): + return self._color + + @property + def line_width(self): + return self._line_width + + @property + def alpha(self): + return self._alpha + + @property + def dim(self): + raise NotImplementedError + + @property + def name(self): + return "{}DLines".format(self.dim) + + def update(self, lines): + self._lines = np.array(lines) + shape = self._lines.shape + assert shape[-1] == self.dim and shape[-2] == 2 and len(shape) == 3 + + def __getitem__(self, index): + return self._lines[index] + + def __len__(self): + return self._lines.shape[0] + + def __iter__(self): + yield self + + +class DrawXDDots(BasePlotterTask): + _dots: np.ndarray + _color: str + _marker_size: int + _alpha: float + _influence_lim: bool + + def __init__( + self, + task_name: str, + dots: np.ndarray, + color: str = "blue", + marker_size: int = 10, + alpha: float = 1.0, + influence_lim: bool = True, + ) -> None: + super().__init__(task_name=task_name, task_type=self.__class__.__name__) + self._color = color + self._marker_size = marker_size + self._alpha = alpha + self._influence_lim = influence_lim + self.update(dots) + + def update(self, dots): + self._dots = np.array(dots) + shape = self._dots.shape + assert shape[-1] == self.dim and len(shape) == 2 + + def __getitem__(self, index): + return self._dots[index] + + def __len__(self): + return self._dots.shape[0] + + def __iter__(self): + yield self + + @property + def influence_lim(self) -> bool: + return self._influence_lim + + @property + def raw_data(self): + return self._dots + + @property + def color(self): + return self._color + + @property + def marker_size(self): + return self._marker_size + + @property + def alpha(self): + return self._alpha + + @property + def dim(self): + raise NotImplementedError + + @property + def name(self): + return "{}DDots".format(self.dim) + + +class DrawXDTrail(DrawXDDots): + @property + def line_width(self): + return self.marker_size + + @property + def name(self): + return "{}DTrail".format(self.dim) + + +class Draw2DLines(DrawXDLines): + @property + def dim(self): + return 2 + + +class Draw3DLines(DrawXDLines): + @property + def dim(self): + return 3 + + +class Draw2DDots(DrawXDDots): + @property + def dim(self): + return 2 + + +class Draw3DDots(DrawXDDots): + @property + def dim(self): + return 3 + + +class Draw2DTrail(DrawXDTrail): + @property + def dim(self): + return 2 + + +class Draw3DTrail(DrawXDTrail): + @property + def dim(self): + return 3 + diff --git a/poselib/poselib/visualization/skeleton_plotter_tasks.py b/poselib/poselib/visualization/skeleton_plotter_tasks.py new file mode 100644 index 0000000..637497b --- /dev/null +++ b/poselib/poselib/visualization/skeleton_plotter_tasks.py @@ -0,0 +1,194 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +""" +This is where all skeleton related complex tasks are defined (skeleton state and skeleton +motion) +""" +import numpy as np + +from .core import BasePlotterTask +from .simple_plotter_tasks import Draw3DDots, Draw3DLines, Draw3DTrail + + +class Draw3DSkeletonState(BasePlotterTask): + _lines_task: Draw3DLines # sub-task for drawing lines + _dots_task: Draw3DDots # sub-task for drawing dots + + def __init__( + self, + task_name: str, + skeleton_state, + joints_color: str = "red", + lines_color: str = "blue", + alpha=1.0, + ) -> None: + super().__init__(task_name=task_name, task_type="3DSkeletonState") + lines, dots = Draw3DSkeletonState._get_lines_and_dots(skeleton_state) + self._lines_task = Draw3DLines( + self.get_scoped_name("bodies"), lines, joints_color, alpha=alpha + ) + self._dots_task = Draw3DDots( + self.get_scoped_name("joints"), dots, lines_color, alpha=alpha + ) + + @property + def name(self): + return "3DSkeleton" + + def update(self, skeleton_state) -> None: + self._update(*Draw3DSkeletonState._get_lines_and_dots(skeleton_state)) + + @staticmethod + def _get_lines_and_dots(skeleton_state): + """Get all the lines and dots needed to draw the skeleton state + """ + assert ( + len(skeleton_state.tensor.shape) == 1 + ), "the state has to be zero dimensional" + dots = skeleton_state.global_translation.numpy() + skeleton_tree = skeleton_state.skeleton_tree + parent_indices = skeleton_tree.parent_indices.numpy() + lines = [] + for node_index in range(len(skeleton_tree)): + parent_index = parent_indices[node_index] + if parent_index != -1: + lines.append([dots[node_index], dots[parent_index]]) + lines = np.array(lines) + return lines, dots + + def _update(self, lines, dots) -> None: + self._lines_task.update(lines) + self._dots_task.update(dots) + + def __iter__(self): + yield from self._lines_task + yield from self._dots_task + + +class Draw3DSkeletonMotion(BasePlotterTask): + def __init__( + self, + task_name: str, + skeleton_motion, + frame_index=None, + joints_color="red", + lines_color="blue", + velocity_color="green", + angular_velocity_color="purple", + trail_color="black", + trail_length=10, + alpha=1.0, + ) -> None: + super().__init__(task_name=task_name, task_type="3DSkeletonMotion") + self._trail_length = trail_length + self._skeleton_motion = skeleton_motion + # if frame_index is None: + curr_skeleton_motion = self._skeleton_motion.clone() + if frame_index is not None: + curr_skeleton_motion.tensor = self._skeleton_motion.tensor[frame_index, :] + # else: + # curr_skeleton_motion = self._skeleton_motion[frame_index, :] + self._skeleton_state_task = Draw3DSkeletonState( + self.get_scoped_name("skeleton_state"), + curr_skeleton_motion, + joints_color=joints_color, + lines_color=lines_color, + alpha=alpha, + ) + vel_lines, avel_lines = Draw3DSkeletonMotion._get_vel_and_avel( + curr_skeleton_motion + ) + self._com_pos = curr_skeleton_motion.root_translation.numpy()[ + np.newaxis, ... + ].repeat(trail_length, axis=0) + self._vel_task = Draw3DLines( + self.get_scoped_name("velocity"), + vel_lines, + velocity_color, + influence_lim=False, + alpha=alpha, + ) + self._avel_task = Draw3DLines( + self.get_scoped_name("angular_velocity"), + avel_lines, + angular_velocity_color, + influence_lim=False, + alpha=alpha, + ) + self._com_trail_task = Draw3DTrail( + self.get_scoped_name("com_trail"), + self._com_pos, + trail_color, + marker_size=2, + influence_lim=True, + alpha=alpha, + ) + + @property + def name(self): + return "3DSkeletonMotion" + + def update(self, frame_index=None, reset_trail=False, skeleton_motion=None) -> None: + if skeleton_motion is not None: + self._skeleton_motion = skeleton_motion + + curr_skeleton_motion = self._skeleton_motion.clone() + if frame_index is not None: + curr_skeleton_motion.tensor = curr_skeleton_motion.tensor[frame_index, :] + if reset_trail: + self._com_pos = curr_skeleton_motion.root_translation.numpy()[ + np.newaxis, ... + ].repeat(self._trail_length, axis=0) + else: + self._com_pos = np.concatenate( + ( + curr_skeleton_motion.root_translation.numpy()[np.newaxis, ...], + self._com_pos[:-1], + ), + axis=0, + ) + self._skeleton_state_task.update(curr_skeleton_motion) + self._com_trail_task.update(self._com_pos) + self._update(*Draw3DSkeletonMotion._get_vel_and_avel(curr_skeleton_motion)) + + @staticmethod + def _get_vel_and_avel(skeleton_motion): + """Get all the velocity and angular velocity lines + """ + pos = skeleton_motion.global_translation.numpy() + vel = skeleton_motion.global_velocity.numpy() + avel = skeleton_motion.global_angular_velocity.numpy() + + vel_lines = np.stack((pos, pos + vel * 0.02), axis=1) + avel_lines = np.stack((pos, pos + avel * 0.01), axis=1) + return vel_lines, avel_lines + + def _update(self, vel_lines, avel_lines) -> None: + self._vel_task.update(vel_lines) + self._avel_task.update(avel_lines) + + def __iter__(self): + yield from self._skeleton_state_task + yield from self._vel_task + yield from self._avel_task + yield from self._com_trail_task + + +class Draw3DSkeletonMotions(BasePlotterTask): + def __init__(self, skeleton_motion_tasks) -> None: + self._skeleton_motion_tasks = skeleton_motion_tasks + + @property + def name(self): + return "3DSkeletonMotions" + + def update(self, frame_index) -> None: + list(map(lambda x: x.update(frame_index), self._skeleton_motion_tasks)) + + def __iter__(self): + yield from self._skeleton_state_tasks diff --git a/poselib/poselib/visualization/tests/__init__.py b/poselib/poselib/visualization/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/poselib/poselib/visualization/tests/test_plotter.py b/poselib/poselib/visualization/tests/test_plotter.py new file mode 100644 index 0000000..7ef1fb6 --- /dev/null +++ b/poselib/poselib/visualization/tests/test_plotter.py @@ -0,0 +1,16 @@ +from typing import cast + +import matplotlib.pyplot as plt +import numpy as np + +from ..core import BasePlotterTask, BasePlotterTasks +from ..plt_plotter import Matplotlib3DPlotter +from ..simple_plotter_tasks import Draw3DDots, Draw3DLines + +task = Draw3DLines(task_name="test", + lines=np.array([[[0, 0, 0], [0, 0, 1]], [[0, 1, 1], [0, 1, 0]]]), color="blue") +task2 = Draw3DDots(task_name="test2", + dots=np.array([[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]]), color="red") +task3 = BasePlotterTasks([task, task2]) +plotter = Matplotlib3DPlotter(cast(BasePlotterTask, task3)) +plt.show() diff --git a/poselib/setup.py b/poselib/setup.py new file mode 100644 index 0000000..c041197 --- /dev/null +++ b/poselib/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup + +setup( + name="poselib", + packages=["poselib"], + version="0.0.42", + description="Framework Agnostic Tensor Programming", + author="Qiyang Li, Kelly Guo, Brendon Matusch", + classifiers=[ + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: GNU General Public License (GPL)", + "Operating System :: OS Independent", + "Development Status :: 1 - Planning", + "Environment :: Console", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: GIS", + ], +) diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000..825647f --- /dev/null +++ b/requirement.txt @@ -0,0 +1,39 @@ +mujoco +pytorch_lightning +numpy-stl +vtk +patchelf +termcolor +torchgeometry +scikit-image +numpy +scipy +ipdb +joblib>=1.2.0 +opencv-python==4.6.0.66 +tqdm +pyyaml +wandb +scikit-image +gym +git+https://github.com/ZhengyiLuo/smplx.git@master +git+https://github.com/ZhengyiLuo/SMPLSim.git@master +lxml +human_body_prior +autograd +scikit-learn +chumpy +patchelf +rl-games==1.1.4 +wandb +pyvirtualdisplay +chumpy +patchelf +lxml +ipdb +chardet +cchardet +imageio-ffmpeg +easydict +open3d +gdown \ No newline at end of file diff --git a/scripts/data_process/convert_amass_data.py b/scripts/data_process/convert_amass_data.py new file mode 100644 index 0000000..c0063c7 --- /dev/null +++ b/scripts/data_process/convert_amass_data.py @@ -0,0 +1,152 @@ +import glob +import os +import sys +import pdb +import os.path as osp +sys.path.append(os.getcwd()) + +import torch +from scipy.spatial.transform import Rotation as sRot +import numpy as np +import joblib +from tqdm import tqdm +import argparse +import cv2 +from poselib.poselib.skeleton.skeleton3d import SkeletonTree, SkeletonMotion, SkeletonState +from smpl_sim.smpllib.smpl_joint_names import SMPL_MUJOCO_NAMES, SMPL_BONE_ORDER_NAMES +from smpl_sim.smpllib.smpl_local_robot import SMPL_Robot as LocalRobot + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--debug", action="store_true", default=False) + parser.add_argument("--path", type=str, default="sample_data/amass_db_smplh.pt") + args = parser.parse_args() + + process_split = "train" + upright_start = True + robot_cfg = { + "mesh": False, + "rel_joint_lm": True, + "upright_start": upright_start, + "remove_toe": False, + "real_weight": True, + "real_weight_porpotion_capsules": True, + "real_weight_porpotion_boxes": True, + "replace_feet": True, + "masterfoot": False, + "big_ankle": True, + "freeze_hand": False, + "box_body": False, + "master_range": 50, + "body_params": {}, + "joint_params": {}, + "geom_params": {}, + "actuator_params": {}, + "model": "smpl", + } + + smpl_local_robot = LocalRobot(robot_cfg,) + all_pkls = glob.glob("AMASS_data/**/*.npz", recursive=True) + amass_occlusion = joblib.load("sample_data/amass_copycat_occlusion_v3.pkl") + amass_full_motion_dict = {} + amass_splits = { + 'vald': ['HumanEva', 'MPI_HDM05', 'SFU', 'MPI_mosh'], + 'test': ['Transitions_mocap', 'SSM_synced'], + 'train': ['CMU', 'MPI_Limits', 'TotalCapture', 'KIT', 'EKUT', 'TCD_handMocap', "BMLhandball", "DanceDB", "ACCAD", "BMLmovi", "BioMotionLab_NTroje", "Eyes_Japan_Dataset", "DFaust_67"] # Adding ACCAD + } + process_set = amass_splits[process_split] + length_acc = [] + for data_path in tqdm(all_pkls): + bound = 0 + splits = data_path.split("/")[7:] + key_name_dump = "0-" + "_".join(splits).replace(".npz", "") + + if (not splits[0] in process_set): + continue + + if key_name_dump in amass_occlusion: + issue = amass_occlusion[key_name_dump]["issue"] + if (issue == "sitting" or issue == "airborne") and "idxes" in amass_occlusion[key_name_dump]: + bound = amass_occlusion[key_name_dump]["idxes"][0] # This bounded is calucaled assuming 30 FPS..... + if bound < 10: + print("bound too small", key_name_dump, bound) + continue + else: + print("issue irrecoverable", key_name_dump, issue) + continue + + entry_data = dict(np.load(open(data_path, "rb"), allow_pickle=True)) + + if not 'mocap_framerate' in entry_data: + continue + framerate = entry_data['mocap_framerate'] + + if "0-KIT_442_PizzaDelivery02_poses" == key_name_dump: + bound = -2 + + skip = int(framerate/30) + root_trans = entry_data['trans'][::skip, :] + pose_aa = np.concatenate([entry_data['poses'][::skip, :66], np.zeros((root_trans.shape[0], 6))], axis = -1) + betas = entry_data['betas'] + gender = entry_data['gender'] + N = pose_aa.shape[0] + + if bound == 0: + bound = N + + root_trans = root_trans[:bound] + pose_aa = pose_aa[:bound] + N = pose_aa.shape[0] + if N < 10: + continue + + smpl_2_mujoco = [SMPL_BONE_ORDER_NAMES.index(q) for q in SMPL_MUJOCO_NAMES if q in SMPL_BONE_ORDER_NAMES] + pose_aa_mj = pose_aa.reshape(N, 24, 3)[:, smpl_2_mujoco] + pose_quat = sRot.from_rotvec(pose_aa_mj.reshape(-1, 3)).as_quat().reshape(N, 24, 4) + + beta = np.zeros((16)) + gender_number, beta[:], gender = [0], 0, "neutral" + # print("using neutral model") + smpl_local_robot.load_from_skeleton(betas=torch.from_numpy(beta[None,]), gender=gender_number, objs_info=None) + smpl_local_robot.write_xml(f"phc/data/assets/mjcf/{robot_cfg['model']}_humanoid.xml") + skeleton_tree = SkeletonTree.from_mjcf(f"phc/data/assets/mjcf/{robot_cfg['model']}_humanoid.xml") + root_trans_offset = torch.from_numpy(root_trans) + skeleton_tree.local_translation[0] + + new_sk_state = SkeletonState.from_rotation_and_root_translation( + skeleton_tree, # This is the wrong skeleton tree (location wise) here, but it's fine since we only use the parent relationship here. + torch.from_numpy(pose_quat), + root_trans_offset, + is_local=True) + + if robot_cfg['upright_start']: + pose_quat_global = (sRot.from_quat(new_sk_state.global_rotation.reshape(-1, 4).numpy()) * sRot.from_quat([0.5, 0.5, 0.5, 0.5]).inv()).as_quat().reshape(N, -1, 4) # should fix pose_quat as well here... + + new_sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat_global), root_trans_offset, is_local=False) + pose_quat = new_sk_state.local_rotation.numpy() + + + pose_quat_global = new_sk_state.global_rotation.numpy() + pose_quat = new_sk_state.local_rotation.numpy() + fps = 30 + + new_motion_out = {} + new_motion_out['pose_quat_global'] = pose_quat_global + new_motion_out['pose_quat'] = pose_quat + new_motion_out['trans_orig'] = root_trans + new_motion_out['root_trans_offset'] = root_trans_offset + new_motion_out['beta'] = beta + new_motion_out['gender'] = gender + new_motion_out['pose_aa'] = pose_aa + new_motion_out['fps'] = fps + + amass_full_motion_dict[key_name_dump] = new_motion_out + + import ipdb; ipdb.set_trace() + if upright_start: + joblib.dump(amass_full_motion_dict, "data/amass/amass_train_take6_upright.pkl", compress=True) + else: + joblib.dump(amass_full_motion_dict, "data/amass/amass_train_take6.pkl", compress=True) + # joblib.dump(amass_full_motion_dict, "data/amass/amass_test_take6.pkl", compress=True) + # joblib.dump(amass_full_motion_dict, "data/amass_x/singles/total_capture.pkl", compress=True) + # joblib.dump(amass_full_motion_dict, "data/amass_x/upright/singles/total_capture.pkl", compress=True) \ No newline at end of file diff --git a/scripts/data_process/convert_amass_isaac.py b/scripts/data_process/convert_amass_isaac.py new file mode 100644 index 0000000..e4d7fc8 --- /dev/null +++ b/scripts/data_process/convert_amass_isaac.py @@ -0,0 +1,156 @@ +from ast import Try +import torch +import joblib +import matplotlib.pyplot as plt +import numpy as np +from scipy import ndimage +from scipy.spatial.transform import Rotation as sRot +import glob +import os +import sys +import pdb +import os.path as osp +from pathlib import Path + +sys.path.append(os.getcwd()) + +from smpl_sim.khrylib.utils import get_body_qposaddr +from smpl_sim.smpllib.smpl_mujoco import SMPL_BONE_ORDER_NAMES as joint_names +from smpl_sim.smpllib.smpl_local_robot import SMPL_Robot as LocalRobot +import scipy.ndimage.filters as filters +from typing import List, Optional +from tqdm import tqdm +from poselib.poselib.skeleton.skeleton3d import SkeletonTree, SkeletonMotion, SkeletonState +import argparse + +def run(in_file: str, out_file: str): + + robot_cfg = { + "mesh": False, + "model": "smpl", + "upright_start": True, + "body_params": {}, + "joint_params": {}, + "geom_params": {}, + "actuator_params": {}, + } + print(robot_cfg) + + smpl_local_robot = LocalRobot( + robot_cfg, + data_dir="data/smpl", + ) + + amass_data = joblib.load(in_file) + + double = False + + mujoco_joint_names = ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + + + amass_full_motion_dict = {} + for key_name in tqdm(amass_data.keys()): + smpl_data_entry = amass_data[key_name] + B = smpl_data_entry['pose_aa'].shape[0] + + start, end = 0, 0 + + pose_aa = smpl_data_entry['pose_aa'].copy()[start:] + root_trans = smpl_data_entry['trans'].copy()[start:] + B = pose_aa.shape[0] + + beta = smpl_data_entry['beta'].copy() if "beta" in smpl_data_entry else smpl_data_entry['betas'].copy() + if len(beta.shape) == 2: + beta = beta[0] + + gender = smpl_data_entry.get("gender", "neutral") + fps = 30.0 + + if isinstance(gender, np.ndarray): + gender = gender.item() + + if isinstance(gender, bytes): + gender = gender.decode("utf-8") + if gender == "neutral": + gender_number = [0] + elif gender == "male": + gender_number = [1] + elif gender == "female": + gender_number = [2] + else: + import ipdb + ipdb.set_trace() + raise Exception("Gender Not Supported!!") + + smpl_2_mujoco = [joint_names.index(q) for q in mujoco_joint_names if q in joint_names] + batch_size = pose_aa.shape[0] + pose_aa = np.concatenate([pose_aa[:, :66], np.zeros((batch_size, 6))], axis=1) + pose_aa_mj = pose_aa.reshape(-1, 24, 3)[..., smpl_2_mujoco, :].copy() + + num = 1 + if double: + num = 2 + for idx in range(num): + pose_quat = sRot.from_rotvec(pose_aa_mj.reshape(-1, 3)).as_quat().reshape(batch_size, 24, 4) + + gender_number, beta[:], gender = [0], 0, "neutral" + print("using neutral model") + + smpl_local_robot.load_from_skeleton(betas=torch.from_numpy(beta[None,]), gender=gender_number, objs_info=None) + smpl_local_robot.write_xml("phc/data/assets/mjcf/smpl_humanoid_1.xml") + skeleton_tree = SkeletonTree.from_mjcf("phc/data/assets/mjcf/smpl_humanoid_1.xml") + + root_trans_offset = torch.from_numpy(root_trans) + skeleton_tree.local_translation[0] + + new_sk_state = SkeletonState.from_rotation_and_root_translation( + skeleton_tree, # This is the wrong skeleton tree (location wise) here, but it's fine since we only use the parent relationship here. + torch.from_numpy(pose_quat), + root_trans_offset, + is_local=True) + + if robot_cfg['upright_start']: + pose_quat_global = (sRot.from_quat(new_sk_state.global_rotation.reshape(-1, 4).numpy()) * sRot.from_quat([0.5, 0.5, 0.5, 0.5]).inv()).as_quat().reshape(B, -1, 4) # should fix pose_quat as well here... + + new_sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat_global), root_trans_offset, is_local=False) + pose_quat = new_sk_state.local_rotation.numpy() + + ############################################################ + # key_name_dump = key_name + f"_{idx}" + key_name_dump = key_name + if idx == 1: + left_to_right_index = [0, 5, 6, 7, 8, 1, 2, 3, 4, 9, 10, 11, 12, 13, 19, 20, 21, 22, 23, 14, 15, 16, 17, 18] + pose_quat_global = pose_quat_global[:, left_to_right_index] + pose_quat_global[..., 0] *= -1 + pose_quat_global[..., 2] *= -1 + + root_trans_offset[..., 1] *= -1 + ############################################################ + + new_motion_out = {} + new_motion_out['pose_quat_global'] = pose_quat_global + new_motion_out['pose_quat'] = pose_quat + new_motion_out['trans_orig'] = root_trans + new_motion_out['root_trans_offset'] = root_trans_offset + new_motion_out['beta'] = beta + new_motion_out['gender'] = gender + new_motion_out['pose_aa'] = pose_aa + new_motion_out['fps'] = fps + amass_full_motion_dict[key_name_dump] = new_motion_out + + Path(out_file).parents[0].mkdir(parents=True, exist_ok=True) + joblib.dump(amass_full_motion_dict, out_file) + return + +# import ipdb + +# ipdb.set_trace() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in_file", type=str, default="sample_data/amass_copycat_take6_train.pkl") + parser.add_argument("--out_file", type=str, default="data/amass/pkls/amass_isaac_im_train_take6_upright_slim.pkl") + args = parser.parse_args() + run( + in_file=args.in_file, + out_file=args.out_file + ) diff --git a/scripts/data_process/convert_data_mdm.py b/scripts/data_process/convert_data_mdm.py new file mode 100644 index 0000000..45d120e --- /dev/null +++ b/scripts/data_process/convert_data_mdm.py @@ -0,0 +1,158 @@ +from ast import Try +import torch +import joblib +import matplotlib.pyplot as plt +import numpy as np +from scipy import ndimage +from scipy.spatial.transform import Rotation as sRot +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +from smpl_sim.utils.config_utils.copycat_config import Config as CC_Config +from smpl_sim.khrylib.utils import get_body_qposaddr +from smpl_sim.smpllib.smpl_mujoco import SMPL_BONE_ORDER_NAMES as joint_names +from smpl_sim.smpllib.smpl_robot import Robot +from smpl_sim.smpllib.smpl_local_robot import SMPL_Robot as LocalRobot +import scipy.ndimage.filters as filters +from typing import List, Optional +from tqdm import tqdm +from poselib.poselib.skeleton.skeleton3d import SkeletonTree, SkeletonMotion, SkeletonState + +robot_cfg = { + "mesh": False, + "model": "smpl", + "upright_start": True, + "body_params": {}, + "joint_params": {}, + "geom_params": {}, + "actuator_params": {}, +} +print(robot_cfg) + +smpl_local_robot = LocalRobot( + robot_cfg, + data_dir="data/smpl", +) +# res_data = joblib.load("data/mdm/res.pk") +# res_data = joblib.load("data/mdm/res_wave.pk") +# res_data = joblib.load("data/mdm/res_phone.pk") +res_data = joblib.load("data/mdm/res_run.pk") + +ipdb.set_trace() +amass_data = {} +for i in range(len(res_data['json_file']['thetas'])): + pose_euler = np.array(res_data['json_file']['thetas'])[i].reshape(-1, 24, 3) + B = pose_euler.shape[0] + trans = np.array(res_data['json_file']['root_translation'])[i] + pose_aa = sRot.from_euler('XYZ', pose_euler.reshape(-1, 3), degrees=True).as_rotvec().reshape(B, 72) + + transform = sRot.from_euler('xyz', np.array([np.pi / 2, 0, 0]), degrees=False) + new_root = (transform * sRot.from_rotvec(pose_aa[:, :3])).as_rotvec() + pose_aa[:, :3] = new_root + + trans = trans.dot(transform.as_matrix().T) + trans[:, 2] = trans[:, 2] - (trans[0, 2] - 0.92) + + amass_data[f"{i}"] = {"pose_aa": pose_aa, "trans": trans, 'beta': np.zeros(10)} + +double = False + +mujoco_joint_names = ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] +amass_full_motion_dict = {} +for key_name in tqdm(amass_data.keys()): + key_name_dump = key_name + smpl_data_entry = amass_data[key_name] + file_name = f"data/amass/singles/{key_name}.npy" + B = smpl_data_entry['pose_aa'].shape[0] + + start, end = 0, 0 + + pose_aa = smpl_data_entry['pose_aa'].copy()[start:] + root_trans = smpl_data_entry['trans'].copy()[start:] + B = pose_aa.shape[0] + + beta = smpl_data_entry['beta'].copy() if "beta" in smpl_data_entry else smpl_data_entry['betas'].copy() + if len(beta.shape) == 2: + beta = beta[0] + + gender = smpl_data_entry.get("gender", "neutral") + fps = 30.0 + + if isinstance(gender, np.ndarray): + gender = gender.item() + + if isinstance(gender, bytes): + gender = gender.decode("utf-8") + if gender == "neutral": + gender_number = [0] + elif gender == "male": + gender_number = [1] + elif gender == "female": + gender_number = [2] + else: + import ipdb + ipdb.set_trace() + raise Exception("Gender Not Supported!!") + + smpl_2_mujoco = [joint_names.index(q) for q in mujoco_joint_names if q in joint_names] + batch_size = pose_aa.shape[0] + pose_aa = np.concatenate([pose_aa[:, :66], np.zeros((batch_size, 6))], axis=1) + pose_aa_mj = pose_aa.reshape(-1, 24, 3)[..., smpl_2_mujoco, :].copy() + + num = 1 + pose_quat = sRot.from_rotvec(pose_aa_mj.reshape(-1, 3)).as_quat().reshape(batch_size, 24, 4) + + gender_number, beta[:], gender = [0], 0, "neutral" + print("using neutral model") + + smpl_local_robot.load_from_skeleton(betas=torch.from_numpy(beta[None,]), gender=gender_number, objs_info=None) + smpl_local_robot.write_xml("pulse/data/assets/mjcf/smpl_humanoid_1.xml") + skeleton_tree = SkeletonTree.from_mjcf("pulse/data/assets/mjcf/smpl_humanoid_1.xml") + + root_trans_offset = torch.from_numpy(root_trans) + skeleton_tree.local_translation[0] + + new_sk_state = SkeletonState.from_rotation_and_root_translation( + skeleton_tree, # This is the wrong skeleton tree (location wise) here, but it's fine since we only use the parent relationship here. + torch.from_numpy(pose_quat), + root_trans_offset, + is_local=True) + + if robot_cfg['upright_start']: + pose_quat_global = (sRot.from_quat(new_sk_state.global_rotation.reshape(-1, 4).numpy()) * sRot.from_quat([0.5, 0.5, 0.5, 0.5]).inv()).as_quat().reshape(B, -1, 4) # should fix pose_quat as well here... + + print("############### filtering!!! ###############") + import scipy.ndimage.filters as filters + from smpl_sim.utils.transform_utils import quat_correct + root_trans_offset = filters.gaussian_filter1d(root_trans_offset, 3, axis=0, mode="nearest") + root_trans_offset = torch.from_numpy(root_trans_offset) + pose_quat_global = np.stack([quat_correct(pose_quat_global[:, i]) for i in range(pose_quat_global.shape[1])], axis=1) + + # select_quats = np.linalg.norm(pose_quat_global[:-1, :] - pose_quat_global[1:, :], axis=2) > np.linalg.norm(pose_quat_global[:-1, :] + pose_quat_global[1:, :], axis=2) # checkup + + filtered_quats = filters.gaussian_filter1d(pose_quat_global, 2, axis=0, mode="nearest") + pose_quat_global = filtered_quats / np.linalg.norm(filtered_quats, axis=-1)[..., None] + print("############### filtering!!! ###############") + new_sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat_global), root_trans_offset, is_local=False) + pose_quat = new_sk_state.local_rotation.numpy() + + new_motion_out = {} + new_motion_out['pose_quat_global'] = pose_quat_global + new_motion_out['pose_quat'] = pose_quat + new_motion_out['trans_orig'] = root_trans + new_motion_out['root_trans_offset'] = root_trans_offset + new_motion_out['beta'] = beta + new_motion_out['gender'] = gender + new_motion_out['pose_aa'] = pose_aa + new_motion_out['fps'] = fps + amass_full_motion_dict[key_name_dump] = new_motion_out + +import ipdb + +ipdb.set_trace() +joblib.dump(amass_full_motion_dict, "data/mdm/mdm_isaac_run.pkl") +# joblib.dump(amass_full_motion_dict, "data/amass/pkls/hybrik/test_hyberIK_sfv.pkl") diff --git a/scripts/data_process/convert_data_smpl.py b/scripts/data_process/convert_data_smpl.py new file mode 100644 index 0000000..79131b0 --- /dev/null +++ b/scripts/data_process/convert_data_smpl.py @@ -0,0 +1,140 @@ +from ast import Try +import torch +import joblib +import matplotlib.pyplot as plt +import numpy as np +from scipy import ndimage +from scipy.spatial.transform import Rotation as sRot +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +from smpl_sim.khrylib.utils import get_body_qposaddr +from smpl_sim.smpllib.smpl_mujoco import SMPL_BONE_ORDER_NAMES as joint_names +from smpl_sim.smpllib.smpl_local_robot import SMPL_Robot as LocalRobot +import scipy.ndimage.filters as filters +from typing import List, Optional +from tqdm import tqdm +from poselib.poselib.skeleton.skeleton3d import SkeletonTree, SkeletonMotion, SkeletonState + +robot_cfg = { + "mesh": False, + "model": "smpl", + "upright_start": True, + "body_params": {}, + "joint_params": {}, + "geom_params": {}, + "actuator_params": {}, +} +print(robot_cfg) + +smpl_local_robot = LocalRobot( + robot_cfg, + data_dir="data/smpl", +) + +amass_data = joblib.load("insert_your_data") + +double = False + +mujoco_joint_names = ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + + + +amass_remove_data = [] + +full_motion_dict = {} +for key_name in tqdm(amass_data.keys()): + smpl_data_entry = amass_data[key_name] + B = smpl_data_entry['pose_aa'].shape[0] + + start, end = 0, 0 + + pose_aa = smpl_data_entry['pose_aa'].copy()[start:] + root_trans = smpl_data_entry['trans'].copy()[start:] + B = pose_aa.shape[0] + + beta = smpl_data_entry['beta'].copy() if "beta" in smpl_data_entry else smpl_data_entry['betas'].copy() + if len(beta.shape) == 2: + beta = beta[0] + + gender = smpl_data_entry.get("gender", "neutral") + fps = smpl_data_entry.get("fps", 30.0) + + if isinstance(gender, np.ndarray): + gender = gender.item() + + if isinstance(gender, bytes): + gender = gender.decode("utf-8") + if gender == "neutral": + gender_number = [0] + elif gender == "male": + gender_number = [1] + elif gender == "female": + gender_number = [2] + else: + import ipdb + ipdb.set_trace() + raise Exception("Gender Not Supported!!") + + smpl_2_mujoco = [joint_names.index(q) for q in mujoco_joint_names if q in joint_names] + batch_size = pose_aa.shape[0] + pose_aa = np.concatenate([pose_aa[:, :66], np.zeros((batch_size, 6))], axis=1) + pose_aa_mj = pose_aa.reshape(-1, 24, 3)[..., smpl_2_mujoco, :].copy() + + num = 1 + if double: + num = 2 + for idx in range(num): + pose_quat = sRot.from_rotvec(pose_aa_mj.reshape(-1, 3)).as_quat().reshape(batch_size, 24, 4) + + gender_number, beta[:], gender = [0], 0, "neutral" + print("using neutral model") + + smpl_local_robot.load_from_skeleton(betas=torch.from_numpy(beta[None,]), gender=gender_number, objs_info=None) + smpl_local_robot.write_xml("pulse/data/assets/mjcf/smpl_humanoid_1.xml") + skeleton_tree = SkeletonTree.from_mjcf("pulse/data/assets/mjcf/smpl_humanoid_1.xml") + + root_trans_offset = torch.from_numpy(root_trans) + skeleton_tree.local_translation[0] + + new_sk_state = SkeletonState.from_rotation_and_root_translation( + skeleton_tree, # This is the wrong skeleton tree (location wise) here, but it's fine since we only use the parent relationship here. + torch.from_numpy(pose_quat), + root_trans_offset, + is_local=True) + + if robot_cfg['upright_start']: + pose_quat_global = (sRot.from_quat(new_sk_state.global_rotation.reshape(-1, 4).numpy()) * sRot.from_quat([0.5, 0.5, 0.5, 0.5]).inv()).as_quat().reshape(B, -1, 4) # should fix pose_quat as well here... + + new_sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat_global), root_trans_offset, is_local=False) + pose_quat = new_sk_state.local_rotation.numpy() + + ############################################################ + # key_name_dump = key_name + f"_{idx}" + key_name_dump = key_name + if idx == 1: + left_to_right_index = [0, 5, 6, 7, 8, 1, 2, 3, 4, 9, 10, 11, 12, 13, 19, 20, 21, 22, 23, 14, 15, 16, 17, 18] + pose_quat_global = pose_quat_global[:, left_to_right_index] + pose_quat_global[..., 0] *= -1 + pose_quat_global[..., 2] *= -1 + + root_trans_offset[..., 1] *= -1 + ############################################################ + + new_motion_out = {} + new_motion_out['pose_quat_global'] = pose_quat_global + new_motion_out['pose_quat'] = pose_quat + new_motion_out['trans_orig'] = root_trans + new_motion_out['root_trans_offset'] = root_trans_offset + new_motion_out['beta'] = beta + new_motion_out['gender'] = gender + new_motion_out['pose_aa'] = pose_aa + new_motion_out['fps'] = fps + full_motion_dict[key_name_dump] = new_motion_out + +import ipdb; ipdb.set_trace() +joblib.dump(full_motion_dict, "insert_your_data") diff --git a/scripts/data_process/process_amass_db.py b/scripts/data_process/process_amass_db.py new file mode 100644 index 0000000..175ded3 --- /dev/null +++ b/scripts/data_process/process_amass_db.py @@ -0,0 +1,289 @@ +import glob +import os +import sys +import pdb +import os.path as osp +sys.path.append(os.getcwd()) + +import numpy as np +import glob +import pickle as pk +import joblib +import torch +import argparse + +from tqdm import tqdm +from smpl_sim.utils.transform_utils import ( + convert_aa_to_orth6d, + convert_orth_6d_to_aa, + vertizalize_smpl_root, + rotation_matrix_to_angle_axis, + rot6d_to_rotmat, +) +from scipy.spatial.transform import Rotation as sRot +from smpl_sim.smpllib.smpl_parser import SMPL_Parser +from smpl_sim.utils.flags import flags + +np.random.seed(1) +left_right_idx = [ + 0, + 2, + 1, + 3, + 5, + 4, + 6, + 8, + 7, + 9, + 11, + 10, + 12, + 14, + 13, + 15, + 17, + 16, + 19, + 18, + 21, + 20, + 23, + 22, +] + + +def left_to_rigth_euler(pose_euler): + pose_euler[:, :, 0] = pose_euler[:, :, 0] * -1 + pose_euler[:, :, 2] = pose_euler[:, :, 2] * -1 + pose_euler = pose_euler[:, left_right_idx, :] + return pose_euler + + +def flip_smpl(pose, trans=None): + """ + Pose input batch * 72 + """ + curr_spose = sRot.from_rotvec(pose.reshape(-1, 3)) + curr_spose_euler = curr_spose.as_euler("ZXY", degrees=False).reshape(pose.shape[0], 24, 3) + curr_spose_euler = left_to_rigth_euler(curr_spose_euler) + curr_spose_rot = sRot.from_euler("ZXY", curr_spose_euler.reshape(-1, 3), degrees=False) + curr_spose_aa = curr_spose_rot.as_rotvec().reshape(pose.shape[0], 24, 3) + if trans != None: + pass + # target_root_mat = curr_spose.as_matrix().reshape(pose.shape[0], 24, 3, 3)[:, 0] + # root_mat = curr_spose_rot.as_matrix().reshape(pose.shape[0], 24, 3, 3)[:, 0] + # apply_mat = np.matmul(target_root_mat[0], np.linalg.inv(root_mat[0])) + + return curr_spose_aa.reshape(-1, 72) + + +def sample_random_hemisphere_root(): + rot = np.random.random() * np.pi * 2 + pitch = np.random.random() * np.pi / 3 + np.pi + r = sRot.from_rotvec([pitch, 0, 0]) + r2 = sRot.from_rotvec([0, rot, 0]) + root_vec = (r * r2).as_rotvec() + return root_vec + + +def sample_seq_length(seq, tran, seq_length=150): + if seq_length != -1: + num_possible_seqs = seq.shape[0] // seq_length + max_seq = seq.shape[0] + + start_idx = np.random.randint(0, 10) + start_points = [max(0, max_seq - (seq_length + start_idx))] + + for i in range(1, num_possible_seqs - 1): + start_points.append(i * seq_length + np.random.randint(-10, 10)) + + if num_possible_seqs >= 2: + start_points.append(max_seq - seq_length - np.random.randint(0, 10)) + + seqs = [seq[i:(i + seq_length)] for i in start_points] + trans = [tran[i:(i + seq_length)] for i in start_points] + else: + seqs = [seq] + trans = [tran] + start_points = [] + return seqs, trans, start_points + + +def get_random_shape(batch_size): + shape_params = torch.rand(1, 10).repeat(batch_size, 1) + s_id = torch.tensor(np.random.normal(scale=1.5, size=(3))) + shape_params[:, :3] = s_id + return shape_params + + + +def count_consec(lst): + consec = [1] + for x, y in zip(lst, lst[1:]): + if x == y - 1: + consec[-1] += 1 + else: + consec.append(1) + return consec + + + +def fix_height_smpl_vanilla(pose_aa, th_trans, th_betas, gender, seq_name): + # no filtering, just fix height + gender = gender.item() if isinstance(gender, np.ndarray) else gender + if isinstance(gender, bytes): + gender = gender.decode("utf-8") + + if gender == "neutral": + smpl_parser = smpl_parser_n + elif gender == "male": + smpl_parser = smpl_parser_m + elif gender == "female": + smpl_parser = smpl_parser_f + else: + print(gender) + raise Exception("Gender Not Supported!!") + + batch_size = pose_aa.shape[0] + verts, jts = smpl_parser.get_joints_verts(pose_aa[0:1], th_betas.repeat((1, 1)), th_trans=th_trans[0:1]) + + # vertices = verts[0].numpy() + gp = torch.min(verts[:, :, 2]) + + # if gp < 0: + th_trans[:, 2] -= gp + + return th_trans + +def process_qpos_list(qpos_list): + amass_res = {} + removed_k = [] + pbar = qpos_list + for (k, v) in tqdm(pbar): + # print("=" * 20) + k = "0-" + k + seq_name = k + betas = v["betas"] + gender = v["gender"] + amass_fr = v["mocap_framerate"] + skip = int(amass_fr / target_fr) + amass_pose = v["poses"][::skip] + amass_trans = v["trans"][::skip] + + bound = amass_pose.shape[0] + if k in amass_occlusion: + issue = amass_occlusion[k]["issue"] + if (issue == "sitting" or issue == "airborne") and "idxes" in amass_occlusion[k]: + bound = amass_occlusion[k]["idxes"][0] # This bounded is calucaled assuming 30 FPS..... + if bound < 10: + print("bound too small", k, bound) + continue + else: + print("issue irrecoverable", k, issue) + continue + + seq_length = amass_pose.shape[0] + if seq_length < 10: + continue + with torch.no_grad(): + amass_pose = amass_pose[:bound] + batch_size = amass_pose.shape[0] + amass_pose = np.concatenate([amass_pose[:, :66], np.zeros((batch_size, 6))], axis=1) # We use SMPL and not SMPLH + + pose_aa = torch.tensor(amass_pose) # After sampling the bound + + amass_trans = torch.tensor(amass_trans[:bound]) # After sampling the bound + betas = torch.from_numpy(betas) + + + amass_trans = fix_height_smpl_vanilla( + pose_aa=pose_aa, + th_betas=betas, + th_trans=amass_trans, + gender=gender, + seq_name=k, + ) + + pose_seq_6d = convert_aa_to_orth6d(torch.tensor(pose_aa)).reshape(batch_size, -1, 6) + + amass_res[seq_name] = { + "pose_aa": pose_aa.numpy(), + "pose_6d": pose_seq_6d.numpy(), + # "qpos": qpos, + "trans": amass_trans.numpy(), + "beta": betas.numpy(), + "seq_name": seq_name, + "gender": gender, + } + + if flags.debug and len(amass_res) > 10: + break + print(removed_k) + return amass_res + + +amass_splits = { + 'vald': ['HumanEva', 'MPI_HDM05', 'SFU', 'MPI_mosh'], + 'test': ['Transitions_mocap', 'SSM_synced'], + 'train': ['CMU', 'MPI_Limits', 'TotalCapture', 'Eyes_Japan_Dataset', 'KIT', 'BML', 'EKUT', 'TCD_handMocap', "BMLhandball", "DanceDB", "ACCAD", "BMLmovi", "BioMotionLab", "Eyes", "DFaust"] # Adding ACCAD +} + +amass_split_dict = {} +for k, v in amass_splits.items(): + for d in v: + amass_split_dict[d] = k + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--debug", action="store_true", default=False) + parser.add_argument("--path", type=str, default="sample_data/amass_db_smplh.pt") + args = parser.parse_args() + + np.random.seed(0) + flags.debug = args.debug + take_num = "copycat_take6" + amass_seq_data = {} + seq_length = -1 + + target_fr = 30 + video_annot = {} + counter = 0 + seq_counter = 0 + db_dataset = args.path + amass_db = joblib.load(db_dataset) + amass_occlusion = joblib.load("sample_data/amass_copycat_occlusion_v3.pkl") + + + qpos_list = list(amass_db.items()) + np.random.seed(0) + np.random.shuffle(qpos_list) + smpl_parser_n = SMPL_Parser(model_path="data/smpl", gender="neutral", use_pca=False, create_transl=False) + smpl_parser_m = SMPL_Parser(model_path="data/smpl", gender="male", use_pca=False, create_transl=False) + smpl_parser_f = SMPL_Parser(model_path="data/smpl", gender="female", use_pca=False, create_transl=False) + + amass_seq_data = process_qpos_list(qpos_list) + + + train_data = {} + test_data = {} + valid_data = {} + for k, v in amass_seq_data.items(): + start_name = k.split("-")[1] + found = False + for dataset_key in amass_split_dict.keys(): + if start_name.lower().startswith(dataset_key.lower()): + found = True + split = amass_split_dict[dataset_key] + if split == "test": + test_data[k] = v + elif split == "valid": + valid_data[k] = v + else: + train_data[k] = v + if not found: + print(f"Not found!! {start_name}") + + joblib.dump(train_data, f"sample_data/amass_{take_num}_train.pkl") + joblib.dump(test_data, f"sample_data/amass_{take_num}_test.pkl") + joblib.dump(valid_data, f"sample_data/amass_{take_num}_valid.pkl") diff --git a/scripts/data_process/process_amass_raw.py b/scripts/data_process/process_amass_raw.py new file mode 100644 index 0000000..e011c94 --- /dev/null +++ b/scripts/data_process/process_amass_raw.py @@ -0,0 +1,188 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + + +import os +import joblib +import argparse +import numpy as np +import os.path as osp +from tqdm import tqdm +from pathlib import Path + +dict_keys = ["betas", "dmpls", "gender", "mocap_framerate", "poses", "trans"] + +# extract SMPL joints from SMPL-H model +joints_to_use = np.array( + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 37, + ] +) +joints_to_use = np.arange(0, 156).reshape((-1, 3))[joints_to_use].reshape(-1) + +all_sequences = [ + "ACCAD", + "BMLmovi", + "BioMotionLab_NTroje", + "CMU", + "DFaust_67", + "EKUT", + "Eyes_Japan_Dataset", + "HumanEva", + "KIT", + "MPI_HDM05", + "MPI_Limits", + "MPI_mosh", + "SFU", + "SSM_synced", + "TCD_handMocap", + "TotalCapture", + "Transitions_mocap", + "BMLhandball", + "DanceDB" +] + +def read_data(folder, sequences): + # sequences = [osp.join(folder, x) for x in sorted(os.listdir(folder)) if osp.isdir(osp.join(folder, x))] + + if sequences == "all": + sequences = all_sequences + + db = {} + print(folder) + for seq_name in sequences: + print(f"Reading {seq_name} sequence...") + seq_folder = osp.join(folder, seq_name) + + datas = read_single_sequence(seq_folder, seq_name) + db.update(datas) + print(seq_name, "number of seqs", len(datas)) + + return db + + +def read_single_sequence(folder, seq_name): + subjects = os.listdir(folder) + + datas = {} + + for subject in tqdm(subjects): + if not osp.isdir(osp.join(folder, subject)): + continue + actions = [ + x for x in os.listdir(osp.join(folder, subject)) if x.endswith(".npz") + ] + + for action in actions: + fname = osp.join(folder, subject, action) + + if fname.endswith("shape.npz"): + continue + + data = dict(np.load(fname)) + # data['poses'] = pose = data['poses'][:, joints_to_use] + + # shape = np.repeat(data['betas'][:10][np.newaxis], pose.shape[0], axis=0) + # theta = np.concatenate([pose,shape], axis=1) + vid_name = f"{seq_name}_{subject}_{action[:-4]}" + + datas[vid_name] = data + # thetas.append(theta) + + return datas + + +def read_seq_data(folder, nsubjects, fps): + subjects = os.listdir(folder) + sequences = {} + + assert nsubjects < len(subjects), "nsubjects should be less than len(subjects)" + + for subject in subjects[:nsubjects]: + actions = os.listdir(osp.join(folder, subject)) + + for action in actions: + data = np.load(osp.join(folder, subject, action)) + mocap_framerate = int(data["mocap_framerate"]) + sampling_freq = mocap_framerate // fps + sequences[(subject, action)] = data["poses"][ + 0::sampling_freq, joints_to_use + ] + + train_set = {} + test_set = {} + + for i, (k, v) in enumerate(sequences.items()): + if i < len(sequences.keys()) - len(sequences.keys()) // 4: + train_set[k] = v + else: + test_set[k] = v + + return train_set, test_set + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dir", type=str, help="dataset directory", default="data/amass" + ) + parser.add_argument( + "--out_dir", type=str, help="dataset directory", default="out" + ) + parser.add_argument( + '--sequences', type=str, nargs='+', help='which sequences to use', default=all_sequences + ) + + args = parser.parse_args() + out_path = Path(args.out_dir) + out_path.mkdir(exist_ok=True) + db_file = osp.join(out_path, "amass_db_smplh.pt") + + db = read_data(args.dir, sequences=args.sequences) + + + print(f"Saving AMASS dataset to {db_file}") + joblib.dump(db, db_file) diff --git a/scripts/demo/video_to_pose_server.py b/scripts/demo/video_to_pose_server.py new file mode 100644 index 0000000..6f0a204 --- /dev/null +++ b/scripts/demo/video_to_pose_server.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 +import os +import cv2 +import joblib +import numpy as np +import time + +import tensorflow as tf +import tensorflow_hub as hub + + +import asyncio +from aiohttp import web +import cv2 +import aiohttp +import numpy as np +import threading +from scipy.spatial.transform import Rotation as sRot + +import time +import torch +from collections import deque +from datetime import datetime +from torchvision import transforms as T +import time +from ultralytics import YOLO +import scipy.interpolate as interpolate + +gpus = tf.config.experimental.list_physical_devices('GPU') +for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + +det_model = YOLO("yolov8s.pt") +# accepts all formats - image/dir/Path/URL/video/PIL/ndarray. 0 for webcam +STANDING_POSE = np.array([[[-0.1443, -0.9426, -0.2548], + [-0.2070, -0.8571, -0.2571], + [-0.0800, -0.8503, -0.2675], + [-0.1555, -1.0663, -0.3057], + [-0.2639, -0.5003, -0.2846], + [-0.0345, -0.4931, -0.3108], + [-0.1587, -1.2094, -0.2755], + [-0.2534, -0.1022, -0.3361], + [-0.0699, -0.1012, -0.3517], + [-0.1548, -1.2679, -0.2675], + [-0.2959, -0.0627, -0.2105], + [-0.0213, -0.0424, -0.2277], + [-0.1408, -1.4894, -0.2892], + [-0.2271, -1.3865, -0.2622], + [-0.0715, -1.3832, -0.2977], + [-0.1428, -1.5753, -0.2303], + [-0.3643, -1.3792, -0.2646], + [ 0.0509, -1.3730, -0.3271], + [-0.3861, -1.1423, -0.3032], + [ 0.0634, -1.1300, -0.3714], + [-0.4086, -0.9130, -0.2000], + [ 0.1203, -0.8943, -0.3002], + [-0.4000, -0.8282, -0.1817], + [ 0.1207, -0.8087, -0.2787]]]).repeat(5, axis = 0) + +def fps_20_to_30(mdm_jts): + jts = [] + N = mdm_jts.shape[0] + for i in range(24): + int_x = mdm_jts[:, i, 0] + int_y = mdm_jts[:, i, 1] + int_z = mdm_jts[:, i, 2] + x = np.arange(0, N) + f_x = interpolate.interp1d(x, int_x) + f_y = interpolate.interp1d(x, int_y) + f_z = interpolate.interp1d(x, int_z) + + new_x = f_x(np.linspace(0, N-1, int(N * 1.5))) + new_y = f_y(np.linspace(0, N-1, int(N * 1.5))) + new_z = f_z(np.linspace(0, N-1, int(N * 1.5))) + jts.append(np.stack([new_x, new_y, new_z], axis = 1)) + jts = np.stack(jts, axis = 1) + return jts + + +def xyxy2xywh(bbox): + x1, y1, x2, y2 = bbox + + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + w = x2 - x1 + h = y2 - y1 + return [cx, cy, w, h] + +def download_model(model_type): + server_prefix = 'https://omnomnom.vision.rwth-aachen.de/data/metrabs' + model_zippath = tf.keras.utils.get_file( + origin=f'{server_prefix}/{model_type}.zip', + extract=True, cache_subdir='models') + model_path = os.path.join(os.path.dirname(model_zippath), model_type) + return model_path + +def start_pose_estimate(): + global pose_mat, trans, dt, reset_offset, offset_height, superfast, j3d, j2d, num_ppl, bbox, frame, fps + offset = np.zeros((5, 1)) + + from scipy.spatial.transform import Rotation as sRot + global_transform = sRot.from_quat([0.5, 0.5, 0.5, 0.5]).inv().as_matrix() + transform = sRot.from_euler('xyz', np.array([-np.pi / 2, 0, 0]), degrees=False).as_matrix() + + prev_box = None + t_s = time.time() + print('### Run Model...') + + # model = tf.saved_model.load(download_model('metrabs_mob3l_y4')) + model = hub.load('https://bit.ly/metrabs_s') # or _l + + skeleton = 'smpl_24' + joint_names = model.per_skeleton_joint_names[skeleton].numpy().astype(str) + joint_edges = model.per_skeleton_joint_edges[skeleton].numpy() + # viz = poseviz.PoseViz(joint_names, joint_edges) + print("==================================> Metrabs model loaded <==================================") + + with torch.no_grad(): + while True: + if not frame is None: + # pred = model.detect_poses(frame, skeleton=skeleton, default_fov_degrees=55, detector_threshold=0.5, num_aug=5) + pred = model.estimate_poses(frame, tf.constant(bbox, dtype=tf.float32), skeleton=skeleton, default_fov_degrees=55, num_aug=1) + + dt = time.time() - t_s + fps = 1/dt + + # camera = poseviz.Camera.from_fov(55, frame.shape[:2]) + # viz.update(frame, pred['boxes'], pred['poses3d'], camera) + pred_j3d = pred['poses3d'].numpy() + num_ppl = min(pred_j3d.shape[0], 5) + + j3d_curr = pred_j3d[:num_ppl]/1000 + if num_ppl < 5: + j3d[num_ppl:, 0, 0] = np.arange(5 - num_ppl) + 1 + + j2d = pred['poses2d'].numpy() + t_s = time.time() + + if reset_offset: + offset[:num_ppl] = - offset_height - j3d_curr[:num_ppl, [0], 1] + reset_offset = False + + j3d_curr[:offset.shape[0], ..., 1] += offset[:num_ppl] + + j3d = j3d.copy() # Trying to handle race condition + j3d[:num_ppl] = j3d_curr + +def get_max_iou_box(det_output, prev_bbox, thrd=0.9): + max_score = 0 + max_bbox = None + for i in range(det_output['boxes'].shape[0]): + bbox = det_output['boxes'][i] + score = det_output['scores'][i] + # if float(score) < thrd: + # continue + # area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + iou = calc_iou(prev_bbox, bbox) + iou_score = float(score) * iou + if float(iou_score) > max_score: + max_bbox = [float(x) for x in bbox] + max_score = iou_score + + if max_bbox is None: + max_bbox = prev_bbox + + return max_bbox + +def calc_iou(bbox1, bbox2): + bbox1 = [float(x) for x in bbox1] + bbox2 = [float(x) for x in bbox2] + + xA = max(bbox1[0], bbox2[0]) + yA = max(bbox1[1], bbox2[1]) + xB = min(bbox1[2], bbox2[2]) + yB = min(bbox1[3], bbox2[3]) + + interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) + + box1Area = (bbox1[2] - bbox1[0] + 1) * (bbox1[3] - bbox1[1] + 1) + box2Area = (bbox2[2] - bbox2[0] + 1) * (bbox2[3] - bbox2[1] + 1) + + iou = interArea / float(box1Area + box2Area - interArea) + + return iou + +def get_one_box(det_output, thrd=0.9): + max_area = 0 + max_bbox = None + + if det_output['boxes'].shape[0] == 0 or thrd < 1e-5: + return None + + for i in range(det_output['boxes'].shape[0]): + bbox = det_output['boxes'][i] + score = det_output['scores'][i] + if float(score) < thrd: + continue + area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + if float(area) > max_area: + max_bbox = [float(x) for x in bbox] + max_area = area + + if max_bbox is None: + return get_one_box(det_output, thrd=thrd - 0.1) + + return max_bbox + +def commandline_input(): + global pose_mat, trans, dt, reset_offset, offset_height, superfast, j3d, j2d, num_ppl, bbox, frame, fps + + while True: + command = input('Type a message to send to the server: ') + if command == 'exit': + print('Exiting!') + raise SystemExit(0) + elif command.startswith("r"): + splits = command.split(":") + if len(splits) > 1: + offset_height = float(splits[-1]) + reset_offset = True + elif command.startswith("fps"): + print(fps) + else: + print("Unkonw command") + +def frames_from_webcam(): + global frame, images_acc, recording, j2d, bbox + cap = cv2.VideoCapture(-1) + prev_box = None + + while (cap.isOpened()): + # Capture frame-by-frame + ret, frame_orig = cap.read() + if not ret: + continue + # x1, y1, x2, y2 = bbox + detec_threshold = 0.6 + + frame = cv2.cvtColor(frame_orig, cv2.COLOR_BGR2RGB) # send to the detector & model + yolo_output = det_model.predict(source=frame, show=False, classes=[0], verbose=False) + + if len(yolo_output[0].boxes) > 0: + yolo_out_xyxy = yolo_output[0].boxes.xyxy.cpu().numpy() + + bbox = np.stack([yolo_out_xyxy[:, 0], yolo_out_xyxy[:, 1], (yolo_out_xyxy[:, 2] - yolo_out_xyxy[:, 0]), (yolo_out_xyxy[:, 3] - yolo_out_xyxy[:, 1])], axis = 1) + + for i in range(len(yolo_out_xyxy)): + x1, y1, x2, y2 = yolo_out_xyxy[i] + frame_orig = cv2.rectangle(frame_orig, (int(x1), int(y1)), (int(x2), int(y2)), (154, 201, 219), 5) + + if not j2d is None: + for pt in j2d.reshape(-1, 2): + x, y = pt + frame_orig = cv2.circle(frame_orig, (int(x), int(y)), 3, (255, 136, 132), 3) + + if recording: + images_acc.append(frame_orig.copy()) + + cv2.imshow('frame', frame_orig) + + if cv2.waitKey(1) == ord('q'): + break + # yield frame + +async def pose_getter(request): + # query env configurations + global pose_mat, trans, dt, j3d, superfast + curr_paths = {} + if superfast: + json_resp = { + "j3d": j3d.tolist(), + "dt": dt, + } + + else: + json_resp = { + "pose_mat": pose_mat.tolist(), + "trans": trans.tolist(), + "dt": dt, + } + + return web.json_response(json_resp) + +# async def commad_interface(request): + + +async def websocket_handler(request): + print('Websocket connection starting') + global pose_mat, trans, dt, sim_talker + sim_talker = aiohttp.web.WebSocketResponse() + + await sim_talker.prepare(request) + print('Websocket connection ready') + + async for msg in sim_talker: + if msg.type == aiohttp.WSMsgType.TEXT: + if msg.data == "get_pose": + await sim_talker.send_json({ + "pose_mat": pose_mat.tolist(), + "trans": trans.tolist(), + "dt": dt, + }) + + print('Websocket connection closed') + return sim_talker + +def write_frames_to_video(frames, out_file_name = "output.mp4", frame_rate = 30, add_text = None, text_color = (255, 255, 255)): + print(f"######################## Writing number of frames {len(frames)} ########################") + if len(frames) == 0: + return + y_shape, x_shape, _ = frames[0].shape + out = cv2.VideoWriter(out_file_name, cv2.VideoWriter_fourcc(*'FMP4'), frame_rate, (x_shape, y_shape)) + transform_dtype = False + transform_256 = False + + if frames[0].dtype != np.uint8: + transform_dtype = True + if np.max(frames[0]) < 1: + transform_256 = True + + for i in range(len(frames)): + curr_frame = frames[i] + + if transform_256: + curr_frame = curr_frame * 256 + if transform_dtype: + curr_frame = curr_frame.astype(np.uint8) + if not add_text is None: + cv2.putText(curr_frame, add_text , (0, 20), 3, 1, text_color) + + out.write(curr_frame) + out.release() + +async def talk_websocket_handler(request): + print('Websocket connection starting') + global reset_offset, trans, offset_height, recording, images_acc + ws_talker = aiohttp.web.WebSocketResponse() + + await ws_talker.prepare(request) + print('Websocket connection ready') + + async for msg in ws_talker: + # print(msg) + if msg.type == aiohttp.WSMsgType.TEXT: + print("\n" + msg.data) + if msg.data.startswith("r"): + splits = msg.data.split(":") + if len(splits) > 1: + offset_height = float(splits[-1]) + reset_offset = True + elif msg.data.startswith("s"): + recording = True + print(f"----------------> recording: {recording}") + # if recording: + # pass + # if recording and not sim_talker is None: + # await sim_talker.send_json({"action": "start_record"}) + elif msg.data.startswith("e"): + recording = False + print(f"----------------> recording: {recording}") + + elif msg.data.startswith("w"): + curr_date_time = datetime.now().strftime('%Y-%m-%d-%H:%M:%S') + out_file_name = f"output/hybrik_{curr_date_time}.mp4" + print(f"----------------> writing video: {out_file_name}") + write_frames_to_video(images_acc, out_file_name = out_file_name) + images_acc = deque(maxlen = 24000) + elif msg.data.startswith("get_pose"): + await sim_talker.send_json({ + "j3d": j3d.tolist(), + "dt": dt, + }) + + await ws_talker.send_str("Done!") + + print('Websocket connection closed') + return ws_talker + + +bbox, pose_mat, j3d, j2d, trans, dt, ws_talkers, reset_offset, offset_height, images_acc, recording, sim_talker, num_ppl, fps= np.zeros([5, 4]), np.zeros([24, 3, 3]), np.zeros([5, 24, 3]), None, np.zeros([3]), 1 / 10, [], True, 0.92, deque(maxlen = 24000), False, None, 0, 0 +frame = None +superfast = True +# main() +app = web.Application(client_max_size=1024**2) +app.router.add_route('GET', '/ws', websocket_handler) +app.router.add_route('GET', '/ws_talk', talk_websocket_handler) +app.router.add_route('GET', '/get_pose', pose_getter) +threading.Thread(target=frames_from_webcam, daemon=True).start() +threading.Thread(target=start_pose_estimate, daemon=True).start() +threading.Thread(target=commandline_input, daemon=True).start() +print("=================================================================") +print("r: reset offset (use r:0.91), s: start recording, e: end recording, w: write video") +print("=================================================================") +web.run_app(app, port=8080) \ No newline at end of file diff --git a/scripts/joint_monkey_smpl.py b/scripts/joint_monkey_smpl.py new file mode 100644 index 0000000..0b57d98 --- /dev/null +++ b/scripts/joint_monkey_smpl.py @@ -0,0 +1,295 @@ +""" +Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + +NVIDIA CORPORATION and its licensors retain all intellectual property +and proprietary rights in and to this software, related documentation +and any modifications thereto. Any use, reproduction, disclosure or +distribution of this software and related documentation without an express +license agreement from NVIDIA CORPORATION is strictly prohibited. + +Joint Monkey +------------ +- Animates degree-of-freedom ranges for a given asset. +- Demonstrates usage of DOF properties and states. +- Demonstrates line drawing utilities to visualize DOF frames (origin and axis). +""" + +import math +import numpy as np +from isaacgym import gymapi, gymutil + + +def clamp(x, min_value, max_value): + return max(min(x, max_value), min_value) + + +# simple asset descriptor for selecting from a list + + +class AssetDesc: + + def __init__(self, file_name, flip_visual_attachments=False): + self.file_name = file_name + self.flip_visual_attachments = flip_visual_attachments +# load asset +asset_root = "/" +asset_root = "./" + +asset_descriptors = [ + AssetDesc("test_good.xml", False), +] + +# parse arguments +args = gymutil.parse_arguments( + description="Joint monkey: Animate degree-of-freedom ranges", + custom_parameters=[{ + "name": + "--asset_id", + "type": + int, + "default": + 0, + "help": + "Asset id (0 - %d)" % (len(asset_descriptors) - 1) + }, { + "name": "--speed_scale", + "type": float, + "default": 1.0, + "help": "Animation speed scale" + }, { + "name": "--show_axis", + "action": "store_true", + "help": "Visualize DOF axis" + }]) + +if args.asset_id < 0 or args.asset_id >= len(asset_descriptors): + print("*** Invalid asset_id specified. Valid range is 0 to %d" % + (len(asset_descriptors) - 1)) + quit() + +# initialize gym +gym = gymapi.acquire_gym() + +# configure sim +sim_params = gymapi.SimParams() +sim_params.dt = dt = 1.0 / 60.0 +if args.physics_engine == gymapi.SIM_FLEX: + pass +elif args.physics_engine == gymapi.SIM_PHYSX: + sim_params.physx.solver_type = 1 + sim_params.physx.num_position_iterations = 6 + sim_params.physx.num_velocity_iterations = 0 + sim_params.physx.num_threads = args.num_threads + sim_params.physx.use_gpu = args.use_gpu + +sim_params.use_gpu_pipeline = False +if args.use_gpu_pipeline: + print("WARNING: Forcing CPU pipeline.") + +sim = gym.create_sim(args.compute_device_id, args.graphics_device_id, + args.physics_engine, sim_params) +if sim is None: + print("*** Failed to create sim") + quit() + +# add ground plane +plane_params = gymapi.PlaneParams() +gym.add_ground(sim, plane_params) + +# create viewer +viewer = gym.create_viewer(sim, gymapi.CameraProperties()) +if viewer is None: + print("*** Failed to create viewer") + quit() + + +asset_file = asset_descriptors[args.asset_id].file_name + +asset_options = gymapi.AssetOptions() +asset_options.fix_base_link = True +asset_options.flip_visual_attachments = asset_descriptors[ + args.asset_id].flip_visual_attachments +asset_options.use_mesh_materials = True +asset_options.replace_cylinder_with_capsule = True + +print("Loading asset '%s' from '%s'" % (asset_file, asset_root)) +asset = gym.load_asset(sim, asset_root, asset_file, asset_options) + +# get array of DOF names +dof_names = gym.get_asset_dof_names(asset) + +# get array of DOF properties +dof_props = gym.get_asset_dof_properties(asset) + +# create an array of DOF states that will be used to update the actors +num_dofs = gym.get_asset_dof_count(asset) +dof_states = np.zeros(num_dofs, dtype=gymapi.DofState.dtype) + +# get list of DOF types +dof_types = [gym.get_asset_dof_type(asset, i) for i in range(num_dofs)] + +# get the position slice of the DOF state array +dof_positions = dof_states['pos'] + +# get the limit-related slices of the DOF properties array +stiffnesses = dof_props['stiffness'] +dampings = dof_props['damping'] +armatures = dof_props['armature'] +has_limits = dof_props['hasLimits'] +lower_limits = dof_props['lower'] +upper_limits = dof_props['upper'] + +# initialize default positions, limits, and speeds (make sure they are in reasonable ranges) +defaults = np.zeros(num_dofs) +speeds = np.zeros(num_dofs) +for i in range(num_dofs): + if has_limits[i]: + if dof_types[i] == gymapi.DOF_ROTATION: + lower_limits[i] = clamp(lower_limits[i], -math.pi, math.pi) + upper_limits[i] = clamp(upper_limits[i], -math.pi, math.pi) + # make sure our default position is in range + if lower_limits[i] > 0.0: + defaults[i] = lower_limits[i] + elif upper_limits[i] < 0.0: + defaults[i] = upper_limits[i] + else: + # set reasonable animation limits for unlimited joints + if dof_types[i] == gymapi.DOF_ROTATION: + # unlimited revolute joint + lower_limits[i] = -math.pi + upper_limits[i] = math.pi + elif dof_types[i] == gymapi.DOF_TRANSLATION: + # unlimited prismatic joint + lower_limits[i] = -1.0 + upper_limits[i] = 1.0 + # set DOF position to default + dof_positions[i] = defaults[i] + # set speed depending on DOF type and range of motion + if dof_types[i] == gymapi.DOF_ROTATION: + speeds[i] = args.speed_scale * clamp( + 2 * + (upper_limits[i] - lower_limits[i]), 0.25 * math.pi, 3.0 * math.pi) + else: + speeds[i] = args.speed_scale * clamp( + 2 * (upper_limits[i] - lower_limits[i]), 0.1, 7.0) + +# Print DOF properties +for i in range(num_dofs): + print("DOF %d" % i) + print(" Name: '%s'" % dof_names[i]) + print(" Type: %s" % gym.get_dof_type_string(dof_types[i])) + print(" Stiffness: %r" % stiffnesses[i]) + print(" Damping: %r" % dampings[i]) + print(" Armature: %r" % armatures[i]) + print(" Limited? %r" % has_limits[i]) + if has_limits[i]: + print(" Lower %f" % lower_limits[i]) + print(" Upper %f" % upper_limits[i]) + +# set up the env grid +num_envs = 36 +num_per_row = 6 +spacing = 2.5 +env_lower = gymapi.Vec3(-spacing, 0.0, -spacing) +env_upper = gymapi.Vec3(spacing, spacing, spacing) + +# position the camera +cam_pos = gymapi.Vec3(17.2, 2.0, 16) +cam_target = gymapi.Vec3(5, -2.5, 13) +gym.viewer_camera_look_at(viewer, None, cam_pos, cam_target) + +# cache useful handles +envs = [] +actor_handles = [] + +print("Creating %d environments" % num_envs) +for i in range(num_envs): + # create env + env = gym.create_env(sim, env_lower, env_upper, num_per_row) + envs.append(env) + + # add actor + pose = gymapi.Transform() + pose.p = gymapi.Vec3(0.0, 1.32, 0.0) + pose.r = gymapi.Quat(-0.707107, 0.0, 0.0, 0.707107) + + actor_handle = gym.create_actor(env, asset, pose, "actor", i, 1) + actor_handles.append(actor_handle) + + # set default DOF positions + gym.set_actor_dof_states(env, actor_handle, dof_states, gymapi.STATE_ALL) + +# joint animation states +ANIM_SEEK_LOWER = 1 +ANIM_SEEK_UPPER = 2 +ANIM_SEEK_DEFAULT = 3 +ANIM_FINISHED = 4 + +# initialize animation state +anim_state = ANIM_SEEK_LOWER +current_dof = 0 +print("Animating DOF %d ('%s')" % (current_dof, dof_names[current_dof])) + +while not gym.query_viewer_has_closed(viewer): + + # step the physics + gym.simulate(sim) + gym.fetch_results(sim, True) + + speed = speeds[current_dof] + + # animate the dofs + if anim_state == ANIM_SEEK_LOWER: + dof_positions[current_dof] -= speed * dt + if dof_positions[current_dof] <= lower_limits[current_dof]: + dof_positions[current_dof] = lower_limits[current_dof] + anim_state = ANIM_SEEK_UPPER + elif anim_state == ANIM_SEEK_UPPER: + dof_positions[current_dof] += speed * dt + if dof_positions[current_dof] >= upper_limits[current_dof]: + dof_positions[current_dof] = upper_limits[current_dof] + anim_state = ANIM_SEEK_DEFAULT + if anim_state == ANIM_SEEK_DEFAULT: + dof_positions[current_dof] -= speed * dt + if dof_positions[current_dof] <= defaults[current_dof]: + dof_positions[current_dof] = defaults[current_dof] + anim_state = ANIM_FINISHED + elif anim_state == ANIM_FINISHED: + dof_positions[current_dof] = defaults[current_dof] + current_dof = (current_dof + 1) % num_dofs + anim_state = ANIM_SEEK_LOWER + print("Animating DOF %d ('%s')" % + (current_dof, dof_names[current_dof])) + + if args.show_axis: + gym.clear_lines(viewer) + + # clone actor state in all of the environments + for i in range(num_envs): + gym.set_actor_dof_states(envs[i], actor_handles[i], dof_states, + gymapi.STATE_POS) + + if args.show_axis: + # get the DOF frame (origin and axis) + dof_handle = gym.get_actor_dof_handle(envs[i], actor_handles[i], + current_dof) + frame = gym.get_dof_frame(envs[i], dof_handle) + + # draw a line from DOF origin along the DOF axis + p1 = frame.origin + p2 = frame.origin + frame.axis * 0.7 + color = gymapi.Vec3(1.0, 0.0, 0.0) + gymutil.draw_line(p1, p2, color, gym, viewer, envs[i]) + + # update the viewer + gym.step_graphics(sim) + gym.draw_viewer(viewer, sim, True) + + # Wait for dt to elapse in real time. + # This synchronizes the physics simulation with the rendering rate. + gym.sync_frame_time(sim) + +print("Done") + +gym.destroy_viewer(viewer) +gym.destroy_sim(sim) diff --git a/scripts/mdm_test.py b/scripts/mdm_test.py new file mode 100644 index 0000000..d7296d5 --- /dev/null +++ b/scripts/mdm_test.py @@ -0,0 +1,42 @@ +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) +# os.system("export REPLICATE_API_TOKEN=e47c32b4a1208437d0c5c02d85afb297353bab1b") + +import replicate +import joblib + +model = replicate.models.get("daanelson/motion_diffusion_model") +version = model.versions.get("3e2218c061c18b2a7388dd91b6677b6515529d4db4d719a6513a23522d23cfa7") + +# https://replicate.com/daanelson/motion_diffusion_model/versions/3e2218c061c18b2a7388dd91b6677b6515529d4db4d719a6513a23522d23cfa7#input +inputs = { + # Prompt + 'prompt': "the person walked forward and is picking up his toolbox.", + + # How many + 'num_repetitions': 3, + + # Choose the format of the output, either an animation or a json file + # of the animation data. The json format is: {"thetas": + # [...], "root_translation": [...], "joint_map": [...]}, where + # "thetas" is an [nframes x njoints x 3] array of + # joint rotations in degrees, "root_translation" is an [nframes x 3] + # array of (X, Y, Z) positions of the root, and "joint_map" is a list + # mapping the SMPL joint index to the corresponding + # HumanIK joint name + # 'output_format': "json_file", + 'output_format': "animation", +} + +# https://replicate.com/daanelson/motion_diffusion_model/versions/3e2218c061c18b2a7388dd91b6677b6515529d4db4d719a6513a23522d23cfa7#output-schema +output = version.predict(**inputs) +import ipdb + +ipdb.set_trace() + +joblib.dump(output, "data/mdm/res.pk") \ No newline at end of file diff --git a/scripts/mjcf_to_urdf.py b/scripts/mjcf_to_urdf.py new file mode 100644 index 0000000..32a32f7 --- /dev/null +++ b/scripts/mjcf_to_urdf.py @@ -0,0 +1,28 @@ +#rudimentary MuJoCo mjcf to ROS URDF converter using the UrdfEditor + +import pybullet_utils.bullet_client as bc +import pybullet_data as pd + +import pybullet_utils.urdfEditor as ed +import argparse +parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--mjcf', help='MuJoCo xml file to be converted to URDF', default='mjcf/humanoid.xml') +args = parser.parse_args() + +p = bc.BulletClient() +p.setAdditionalSearchPath(pd.getDataPath()) +objs = p.loadMJCF(args.mjcf, flags=p.URDF_USE_IMPLICIT_CYLINDER) + +for o in objs: + #print("o=",o, p.getBodyInfo(o), p.getNumJoints(o)) + humanoid = objs[o] + ed0 = ed.UrdfEditor() + ed0.initializeFromBulletBody(humanoid, p._client) + robotName = str(p.getBodyInfo(o)[1],'utf-8') + partName = str(p.getBodyInfo(o)[0], 'utf-8') + + print("robotName=",robotName) + print("partName=",partName) + + saveVisuals=False + ed0.saveUrdf(robotName+"_"+partName+".urdf", saveVisuals) \ No newline at end of file diff --git a/scripts/pmcp/forward_pmcp.py b/scripts/pmcp/forward_pmcp.py new file mode 100644 index 0000000..22db411 --- /dev/null +++ b/scripts/pmcp/forward_pmcp.py @@ -0,0 +1,67 @@ +import glob +import os +import sys +import pdb +import os.path as osp +sys.path.append(os.getcwd()) +from rl_games.algos_torch import torch_ext +import joblib +import numpy as np +import argparse + + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--exp', default='') + parser.add_argument('--idx', default=0) + parser.add_argument('--epoch', default=200000) + + args = parser.parse_args() + + trained_idx = int(args.idx) + exp_name = args.exp + epoch = int(args.epoch) + print(f"PNN Processing for: exp_name: {exp_name}, idx: {trained_idx}, epoch: {epoch}") + import ipdb; ipdb.set_trace() + + + checkpoint = torch_ext.load_checkpoint(f"output/dgx/{exp_name}/Humanoid_{epoch:08d}.pth") + amass_train_data_take6 = joblib.load("data/amass/pkls/amass_isaac_im_train_take6_upright_slim.pkl") + + failed_keys_dict = {} + termination_history_dict = {} + all_keys = set() + for failed_path in sorted(glob.glob(f"output/dgx/{exp_name}/failed_*"))[:]: + failed_idx = int(failed_path.split("/")[-1].split("_")[-1].split(".")[0]) + failed_keys_entry = joblib.load(failed_path) + failed_keys = failed_keys_entry['failed_keys'] + failed_keys_dict[failed_idx] = failed_keys + termination_history_dict[failed_idx] = failed_keys_entry['termination_history'] + [all_keys.add(k) for k in failed_keys] + + dump_keys = [] + for k, v in failed_keys_dict.items(): + if k <= epoch and k >= epoch - 2500 * 5: + dump_keys.append(v) + + dump_keys = np.concatenate(dump_keys) + + network_name_prefix = "a2c_network.pnn.actors" + + + loading_keys = [k for k in checkpoint['model'].keys() if k.startswith(f"{network_name_prefix}.{trained_idx}")] + copy_keys = [k for k in checkpoint['model'].keys() if k.startswith(f"{network_name_prefix}.{trained_idx + 1}")] + + + for idx, key_name in enumerate(copy_keys): + checkpoint['model'][key_name].copy_(checkpoint['model'][loading_keys[idx]]) + + torch_ext.save_checkpoint(f"output/dgx/{exp_name}/Humanoid_{epoch + 1:08d}", checkpoint) + + failed_dump = {key: amass_train_data_take6[key] for key in dump_keys if key in amass_train_data_take6} + + os.makedirs(f"data/amass/pkls/auto_pmcp", exist_ok=True) + print(f"dumping {len(failed_dump)} samples to data/amass/pkls/auto_pmcp/{exp_name}_{epoch}.pkl") + joblib.dump(failed_dump, f"data/amass/pkls/auto_pmcp/{exp_name}_{epoch}.pkl") diff --git a/scripts/quest_camera.py b/scripts/quest_camera.py new file mode 100644 index 0000000..b523116 --- /dev/null +++ b/scripts/quest_camera.py @@ -0,0 +1,42 @@ +import cv2 +import numpy as np +# Open camera 0 +cap = cv2.VideoCapture(0) +cap.set(cv2.CAP_PROP_CONVERT_RGB, 0) +# Get the default video size +width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) +height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + +# Define the codec and create a VideoWriter object +fourcc = cv2.VideoWriter_fourcc(*'mp4v') + +out = cv2.VideoWriter('output.mp4', fourcc, 30, (1920, 320)) + +# Start capturing and processing frames +while True: + # Capture frame-by-frame + ret, frame = cap.read() + + # If frame is not available, break the loop + if not ret: + break + # Write the frame to the output video filef + + + # # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + # gray_frame = np.concatenate([frame[:, :320].transpose(1, 0, 2)[::-1, ], frame[:, 320:640].transpose(1, 0, 2)[::-1, ], frame[:, 640:960].transpose(1, 0, 2)[::-1, ], frame[:, 960:].transpose(1, 0, 2)[::-1, ]], axis = 1) + + # # out.write(cv2.cvtColor(gray_frame, cv2.COLOR_GRAY2BGR)) + # out.write(gray_frame) + + # Display the resulting frame + cv2.imshow('frame', frame[..., 1]) + + # Wait for 1 millisecond for user to press 'q' key to exit + if cv2.waitKey(1) & 0xFF == ord('q'): + break + +# Release the capture and output objects, and close all windows +cap.release() +out.release() +cv2.destroyAllWindows() \ No newline at end of file diff --git a/scripts/render_smpl_o3d.py b/scripts/render_smpl_o3d.py new file mode 100644 index 0000000..c767aae --- /dev/null +++ b/scripts/render_smpl_o3d.py @@ -0,0 +1,142 @@ +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +import open3d as o3d +import open3d.visualization.rendering as rendering +import imageio +from tqdm import tqdm +import joblib +import numpy as np +import torch + +from smpl_sim.smpllib.smpl_parser import ( + SMPL_Parser, + SMPLH_Parser, + SMPLX_Parser, +) +import random + +from smpl_sim.smpllib.smpl_mujoco import SMPL_BONE_ORDER_NAMES as joint_names +from poselib.poselib.skeleton.skeleton3d import SkeletonTree, SkeletonMotion, SkeletonState +from scipy.spatial.transform import Rotation as sRot +import matplotlib.pyplot as plt +from tqdm import tqdm +import cv2 + +paused, reset, recording, image_list, writer, control, curr_zoom = False, False, False, [], None, None, 0.01 + + +def main(): + render = rendering.OffscreenRenderer(2560, 960) + # render.scene.set_clear_color(np.array([0, 0, 0, 1])) + ############ Load SMPL Data ############ + pkl_dir = "output/renderings/smpl_im_comp_8-2023-02-05-15:36:14.pkl" + mujoco_joint_names = ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + Name = pkl_dir.split("/")[-1].split(".")[0] + pkl_data = joblib.load(pkl_dir) + data_dir = "data/smpl" + mujoco_2_smpl = [mujoco_joint_names.index(q) for q in joint_names if q in mujoco_joint_names] + smpl_parser_n = SMPL_Parser(model_path=data_dir, gender="neutral") + smpl_parser_m = SMPL_Parser(model_path=data_dir, gender="male") + smpl_parser_f = SMPL_Parser(model_path=data_dir, gender="female") + + data_seq = pkl_data['0_0'] + pose_quat, trans = data_seq['body_quat'].numpy()[::2], data_seq['trans'].numpy()[::2] + skeleton_tree = SkeletonTree.from_dict(data_seq['skeleton_tree']) + offset = skeleton_tree.local_translation[0] + root_trans_offset = trans - offset.numpy() + gender, beta = data_seq['betas'][0], data_seq['betas'][1:] + + if gender == 0: + smpl_parser = smpl_parser_n + elif gender == 1: + smpl_parser = smpl_parser_m + else: + smpl_parser = smpl_parser_f + + sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat), torch.from_numpy(trans), is_local=True) + + global_rot = sk_state.global_rotation + B, J, N = global_rot.shape + pose_quat = (sRot.from_quat(global_rot.reshape(-1, 4).numpy()) * sRot.from_quat([0.5, 0.5, 0.5, 0.5])).as_quat().reshape(B, -1, 4) + B_down = pose_quat.shape[0] + new_sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat), torch.from_numpy(trans), is_local=False) + local_rot = new_sk_state.local_rotation + pose_aa = sRot.from_quat(local_rot.reshape(-1, 4).numpy()).as_rotvec().reshape(B_down, -1, 3) + pose_aa = pose_aa[:, mujoco_2_smpl, :].reshape(B_down, -1) + root_trans_offset[..., :2] = root_trans_offset[..., :2] - root_trans_offset[0:1, :2] + with torch.no_grad(): + vertices, joints = smpl_parser.get_joints_verts(pose=torch.from_numpy(pose_aa), th_trans=torch.from_numpy(root_trans_offset), th_betas=torch.from_numpy(beta[None,])) + # vertices, joints = smpl_parser.get_joints_verts(pose=torch.from_numpy(pose_aa), th_betas=torch.from_numpy(beta[None,])) + vertices = vertices.numpy() + faces = smpl_parser.faces + smpl_mesh = o3d.geometry.TriangleMesh() + smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[0]) + smpl_mesh.triangles = o3d.utility.Vector3iVector(faces) + # smpl_mesh.compute_triangle_normals() + smpl_mesh.compute_vertex_normals() + + groun_plane = rendering.MaterialRecord() + groun_plane.base_color = [1, 1, 1, 1] + # groun_plane.shader = "defaultLit" + + box = o3d.geometry.TriangleMesh() + ground_size = 10 + box = box.create_box(width=ground_size, height=1, depth=ground_size) + box.compute_triangle_normals() + # box.compute_vertex_normals() + box.translate(np.array([-ground_size / 2, -1, -ground_size / 2])) + box.rotate(sRot.from_euler('x', 90, degrees=True).as_matrix(), center=(0, 0, 0)) + render.scene.add_geometry("box", box, groun_plane) + + # cyl.compute_vertex_normals() + # cyl.translate([-2, 0, 1.5]) + + ending_color = rendering.MaterialRecord() + ending_color.base_color = np.array([35, 102, 218, 256]) / 256 + ending_color.shader = "defaultLit" + + render.scene.add_geometry("cyl", smpl_mesh, ending_color) + eye_level = 1 + render.setup_camera(60.0, [0, 0, eye_level], [0, -3, eye_level], [0, 0, 1]) # center (lookat), eye (pos), up + + # render.scene.scene.set_sun_light([0, 1, 0], [1.0, 1.0, 1.0], 100000) + # render.scene.scene.enable_sun_light(True) + # render.scene.scene.enable_light_shadow("sun", True) + + for i in tqdm(range(0, 50, 5)): + smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[i]) + color_rgb = np.array([35, 102, 218, 256]) / 256 * (1 - i / 50) + color_rgb[-1] = 1 + ending_color.base_color = color_rgb + render.scene.add_geometry(f"cly_{i}", smpl_mesh, ending_color) + break + + # render.scene.show_axes(True) + img = render.render_to_image() + cv2.imwrite("output/renderings/iccv2023/test_data.png", np.asarray(img)[..., ::-1]) + plt.figure(dpi=400) + plt.imshow(img) + plt.show() + + # writer = imageio.get_writer("output/renderings/test_data.mp4", fps=30, macro_block_size=None) + + # for i in tqdm(range(B_down)): + + # smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[i]) + + # render.scene.remove_geometry('cyl') + # render.scene.add_geometry("cyl", smpl_mesh, color) + # img = render.render_to_image() + # writer.append_data(np.asarray(img)) + + # writer.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/vis/vis_motion.py b/scripts/vis/vis_motion.py new file mode 100644 index 0000000..64ed802 --- /dev/null +++ b/scripts/vis/vis_motion.py @@ -0,0 +1,382 @@ +""" +Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + +NVIDIA CORPORATION and its licensors retain all intellectual property +and proprietary rights in and to this software, related documentation +and any modifications thereto. Any use, reproduction, disclosure or +distribution of this software and related documentation without an express +license agreement from NVIDIA CORPORATION is strictly prohibited. + +Visualize motion library +""" +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +import joblib +import numpy as np +from isaacgym import gymapi, gymutil, gymtorch +import torch +from phc.utils.motion_lib_smpl import MotionLibSMPL as MotionLibSMPL +from smpl_sim.smpllib.smpl_local_robot import SMPL_Robot +from poselib.poselib.skeleton.skeleton3d import SkeletonTree +from phc.utils.flags import flags + +flags.test = True +flags.im_eval = True + + +def clamp(x, min_value, max_value): + return max(min(x, max_value), min_value) + + +# simple asset descriptor for selecting from a list + + +class AssetDesc: + + def __init__(self, file_name, flip_visual_attachments=False): + self.file_name = file_name + self.flip_visual_attachments = flip_visual_attachments + + +masterfoot = False +# masterfoot = True +robot_cfg = { + "mesh": True, + "rel_joint_lm": False, + "masterfoot": masterfoot, + "upright_start": True, + "remove_toe": False, + "real_weight_porpotion_capsules": True, + "model": "smpl", + "body_params": {}, + "joint_params": {}, + "geom_params": {}, + "actuator_params": {}, +} +smpl_robot = SMPL_Robot( + robot_cfg, + data_dir="data/smpl", +) + +gender_beta = np.array([1.0000, -0.2141, -0.1140, 0.3848, 0.9583, 1.7619, 1.5040, 0.5765, 0.9636, 0.2636, -0.4202, 0.5075, -0.7371, -2.6490, 0.0867, 1.4699, -1.1865]) +smpl_robot.load_from_skeleton(betas=torch.from_numpy(gender_beta[None, 1:]), gender=gender_beta[0:1], objs_info=None) +test_good = f"/tmp/smpl/test_good.xml" +smpl_robot.write_xml(test_good) +sk_tree = SkeletonTree.from_mjcf(test_good) + +asset_descriptors = [ + AssetDesc(test_good, False), +] + +# parse arguments +args = gymutil.parse_arguments(description="Joint monkey: Animate degree-of-freedom ranges", + custom_parameters=[{ + "name": "--asset_id", + "type": int, + "default": 0, + "help": "Asset id (0 - %d)" % (len(asset_descriptors) - 1) + }, { + "name": "--speed_scale", + "type": float, + "default": 1.0, + "help": "Animation speed scale" + }, { + "name": "--show_axis", + "action": "store_true", + "help": "Visualize DOF axis" + }]) + +if args.asset_id < 0 or args.asset_id >= len(asset_descriptors): + print("*** Invalid asset_id specified. Valid range is 0 to %d" % (len(asset_descriptors) - 1)) + quit() + +# initialize gym +gym = gymapi.acquire_gym() + +# configure sim +sim_params = gymapi.SimParams() +sim_params.dt = dt = 1.0 / 60.0 +sim_params.up_axis = gymapi.UP_AXIS_Z +sim_params.gravity = gymapi.Vec3(0.0, 0.0, -9.81) +if args.physics_engine == gymapi.SIM_FLEX: + pass +elif args.physics_engine == gymapi.SIM_PHYSX: + sim_params.physx.solver_type = 1 + sim_params.physx.num_position_iterations = 6 + sim_params.physx.num_velocity_iterations = 0 + sim_params.physx.num_threads = args.num_threads + sim_params.physx.use_gpu = args.use_gpu + sim_params.use_gpu_pipeline = args.use_gpu_pipeline + +if not args.use_gpu_pipeline: + print("WARNING: Forcing CPU pipeline.") + +sim = gym.create_sim(args.compute_device_id, args.graphics_device_id, args.physics_engine, sim_params) +if sim is None: + print("*** Failed to create sim") + quit() + +# add ground plane +plane_params = gymapi.PlaneParams() +plane_params.normal = gymapi.Vec3(0.0, 0.0, 1.0) +gym.add_ground(sim, plane_params) + +# create viewer +viewer = gym.create_viewer(sim, gymapi.CameraProperties()) +if viewer is None: + print("*** Failed to create viewer") + quit() + +# load asset +# asset_root = "amp/data/assets" +# asset_root = "./" +asset_root = "/" +asset_file = asset_descriptors[args.asset_id].file_name + +asset_options = gymapi.AssetOptions() +# asset_options.fix_base_link = True +# asset_options.flip_visual_attachments = asset_descriptors[ +# args.asset_id].flip_visual_attachments +asset_options.use_mesh_materials = True + +print("Loading asset '%s' from '%s'" % (asset_file, asset_root)) +asset = gym.load_asset(sim, asset_root, asset_file, asset_options) + +# set up the env grid +num_envs = 1 +num_per_row = 5 +spacing = 5 +env_lower = gymapi.Vec3(-spacing, spacing, 0) +env_upper = gymapi.Vec3(spacing, spacing, spacing) + +# position the camera +cam_pos = gymapi.Vec3(0, -10.0, 3) +cam_target = gymapi.Vec3(0, 0, 0) +gym.viewer_camera_look_at(viewer, None, cam_pos, cam_target) + +# cache useful handles +envs = [] +actor_handles = [] + +num_dofs = gym.get_asset_dof_count(asset) +print("Creating %d environments" % num_envs) +for i in range(num_envs): + # create env + env = gym.create_env(sim, env_lower, env_upper, num_per_row) + envs.append(env) + + # add actor + pose = gymapi.Transform() + pose.p = gymapi.Vec3(0.0, 0, 0.0) + pose.r = gymapi.Quat(0, 0.0, 0.0, 1) + + actor_handle = gym.create_actor(env, asset, pose, "actor", i, 1) + actor_handles.append(actor_handle) + + # set default DOF positions + dof_states = np.zeros(num_dofs, dtype=gymapi.DofState.dtype) + gym.set_actor_dof_states(env, actor_handle, dof_states, gymapi.STATE_ALL) + +# Setup Motion +body_ids = [] +key_body_names = ["R_Ankle", "L_Ankle", "R_Wrist", "L_Wrist"] +for body_name in key_body_names: + body_id = gym.find_actor_rigid_body_handle(envs[0], actor_handles[0], body_name) + assert (body_id != -1) + body_ids.append(body_id) +gym.prepare_sim(sim) +body_ids = np.array(body_ids) + +motion_file = "data/amass/pkls/amass_isaac_im_patch_upright_slim.pkl" +# motion_file = "data/amass/pkls/amass_isaac_im_train_upright_slim.pkl" +# motion_file = "data/amass/pkls/amass_isaac_locomotion_upright.pkl" +# motion_file = "data/amass/pkls/amass_isaac_slowalk_upright.pkl" +# motion_file = "data/amass/pkls/amass_isaac_slowalk_upright_slim.pkl" +# motion_file = "data/amass/pkls/singles/hard1_upright_slim.pkl" +# motion_file = "data/amass/pkls/amass_isaac_slowalk_upright_slim_double.pkl" +# motion_file = "data/amass/pkls/amass_isaac_run_upright_slim.pkl" +# motion_file = "data/amass/pkls/singles/0-BioMotionLab_NTroje_rub077_0027_circle_walk_poses_upright_slim.pkl" +# motion_file = "data/amass/pkls/amass_isaac_run_upright_slim_double.pkl" +# motion_file = "data/amass/pkls/amass_isaac_walk_upright_test_slim.pkl" +# motion_file = "data/amass/pkls/amass_isaac_crawl_upright_slim.pkl" +# motion_file = "data/amass/pkls/singles/test_test_test.pkl" +# motion_file = "data/amass/pkls/singles/long_upright_slim.pkl" +# motion_file = "data/amass/pkls/test_hyberIK.pkl" +motion_data = joblib.load(motion_file) +# print(motion_keys) + +if masterfoot: + _body_names_orig = ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + _body_names = [ + 'Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'L_Toe_1', 'L_Toe_1_1', 'L_Toe_2', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'R_Toe_1', 'R_Toe_1_1', 'R_Toe_2', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', + 'R_Elbow', 'R_Wrist', 'R_Hand' + ] + _body_to_orig = [_body_names.index(name) for name in _body_names_orig] + _body_to_orig_without_toe = [_body_names.index(name) for name in _body_names_orig if name not in ['L_Toe', 'R_Toe']] + orig_to_orig_without_toe = [_body_names_orig.index(name) for name in _body_names_orig if name not in ['L_Toe', 'R_Toe']] + + _masterfoot_config = { + "body_names_orig": _body_names_orig, + "body_names": _body_names, + "body_to_orig": _body_to_orig, + "body_to_orig_without_toe": _body_to_orig_without_toe, + "orig_to_orig_without_toe": orig_to_orig_without_toe, + } +else: + _masterfoot_config = None + +device = (torch.device("cuda", index=0) if torch.cuda.is_available() else torch.device("cpu")) + +motion_lib = MotionLibSMPL(motion_file=motion_file, key_body_ids=body_ids, device=device, masterfoot_conifg=_masterfoot_config, fix_height=False, multi_thread=False) +num_motions = 30 +curr_start = 0 +motion_lib.load_motions(skeleton_trees=[sk_tree] * num_motions, gender_betas=[torch.zeros(17)] * num_motions, limb_weights=[np.zeros(10)] * num_motions, random_sample=False) +motion_keys = motion_lib.curr_motion_keys + +current_dof = 0 +speeds = np.zeros(num_dofs) + +time_step = 0 +rigidbody_state = gym.acquire_rigid_body_state_tensor(sim) +rigidbody_state = gymtorch.wrap_tensor(rigidbody_state) +rigidbody_state = rigidbody_state.reshape(num_envs, -1, 13) + +actor_root_state = gym.acquire_actor_root_state_tensor(sim) +actor_root_state = gymtorch.wrap_tensor(actor_root_state) + +gym.subscribe_viewer_keyboard_event(viewer, gymapi.KEY_LEFT, "previous") +gym.subscribe_viewer_keyboard_event(viewer, gymapi.KEY_RIGHT, "next") +gym.subscribe_viewer_keyboard_event(viewer, gymapi.KEY_G, "add") +gym.subscribe_viewer_keyboard_event(viewer, gymapi.KEY_P, "print") +gym.subscribe_viewer_keyboard_event(viewer, gymapi.KEY_T, "next_batch") +motion_id = 0 +motion_acc = set() +if masterfoot: + left_to_right_index = [7, 8, 9, 10, 11, 12, 13, 0, 1, 2, 3, 4, 5, 6, 14, 15, 16, 17, 18, 24, 25, 26, 27, 28, 19, 20, 21, 22, 23] +else: + left_to_right_index = [4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 18, 19, 20, 21, 22, 13, 14, 15, 16, 17] +env_ids = torch.arange(num_envs).int().to(args.sim_device) +while not gym.query_viewer_has_closed(viewer): + # step the physics + + motion_len = motion_lib.get_motion_length(motion_id).item() + motion_time = time_step % motion_len + # motion_time = 0 + + motion_res = motion_lib.get_motion_state(torch.tensor([motion_id]).to(args.compute_device_id), torch.tensor([motion_time]).to(args.compute_device_id)) + + root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, smpl_params, limb_weights, pose_aa, rb_pos, rb_rot, body_vel, body_ang_vel = \ + motion_res["root_pos"], motion_res["root_rot"], motion_res["dof_pos"], motion_res["root_vel"], motion_res["root_ang_vel"], motion_res["dof_vel"], \ + motion_res["motion_bodies"], motion_res["motion_limb_weights"], motion_res["motion_aa"], motion_res["rg_pos"], motion_res["rb_rot"], motion_res["body_vel"], motion_res["body_ang_vel"] + + if args.show_axis: + gym.clear_lines(viewer) + + #################### Heading invarance check: #################### + # from phc.env.tasks.humanoid_im import compute_imitation_observations + # from phc.env.tasks.humanoid import compute_humanoid_observations_smpl_max + # from phc.env.tasks.humanoid_amp import build_amp_observations_smpl + + # motion_res_10 = motion_lib.get_motion_state(torch.tensor([motion_id]).to(args.compute_device_id), torch.tensor([0]).to(args.compute_device_id)) + # motion_res_100 = motion_lib.get_motion_state(torch.tensor([motion_id]).to(args.compute_device_id), torch.tensor([3]).to(args.compute_device_id)) + + # root_pos_10, root_rot_10, dof_pos_10, root_vel_10, root_ang_vel_10, dof_vel_10, key_pos_10, smpl_params_10, limb_weights_10, pose_aa_10, rb_pos_10, rb_rot_10, body_vel_10, body_ang_vel_10 = \ + # motion_res_10["root_pos"], motion_res_10["root_rot"], motion_res_10["dof_pos"], motion_res_10["root_vel"], motion_res_10["root_ang_vel"], motion_res_10["dof_vel"], \ + # motion_res_10["key_pos"], motion_res_10["motion_bodies"], motion_res_10["motion_limb_weights"], motion_res_10["motion_aa"], motion_res_10["rg_pos"], motion_res_10["rb_rot"], motion_res_10["body_vel"], motion_res_10["body_ang_vel"] + + # root_pos_100, root_rot_100, dof_pos_100, root_vel_100, root_ang_vel_100, dof_vel_100, key_pos_100, smpl_params_100, limb_weights_100, pose_aa_100, rb_pos_100, rb_rot_100, body_vel_100, body_ang_vel_100 = \ + # motion_res_100["root_pos"], motion_res_100["root_rot"], motion_res_100["dof_pos"], motion_res_100["root_vel"], motion_res_100["root_ang_vel"], motion_res_100["dof_vel"], \ + # motion_res_100["key_pos"], motion_res_100["motion_bodies"], motion_res_100["motion_limb_weights"], motion_res_100["motion_aa"], motion_res_100["rg_pos"], motion_res_100["rb_rot"], motion_res_100["body_vel"], motion_res_100["body_ang_vel"] + + # # obs = compute_imitation_observations(root_pos_100, root_rot_100, rb_pos_100, rb_rot_100, body_vel_100, body_ang_vel_100, rb_pos_10, rb_rot_10, body_vel_10, body_ang_vel_10, 1, True) + # # obs_im = compute_humanoid_observations_smpl_max(rb_pos_100, rb_rot_100, body_vel_100, body_ang_vel_100, smpl_params_100, limb_weights_100, True, False, True, True, True) + # obs_amp = build_amp_observations_smpl( + # root_pos_100, root_rot_100, body_vel_100[:, 0, :], + # body_ang_vel_100[:, 0, :], dof_pos_100, dof_vel_100, rb_pos_100, + # smpl_params_100, limb_weights_100, None, True, False, False, True, True, True) + + # motion_lib.load_motions(skeleton_trees = [sk_tree] * num_motions, gender_betas = [torch.zeros(17)] * num_motions, limb_weights = [np.zeros(10)] * num_motions, random_sample=False) + # # joblib.dump(obs_amp, "a.pkl") + # import ipdb + # ipdb.set_trace() + + #################### Heading invarance check: #################### + + ########################################################################### + # root_pos[:, 1] *= -1 + # key_pos[:, 1] *= -1 # Will need to flip these as well + # root_rot[:, 0] *= -1 + # root_rot[:, 2] *= -1 + + # dof_vel = dof_vel.reshape(len(left_to_right_index), 3)[left_to_right_index] + # dof_vel[:, 0] = dof_vel[:, 0] * -1 + # dof_vel[:, 2] = dof_vel[:, 2] * -1 + # dof_vel = dof_vel.reshape(1, len(left_to_right_index) * 3) + + # dof_pos = dof_pos.reshape(len(left_to_right_index), 3)[left_to_right_index] + # dof_pos[:, 0] = dof_pos[:, 0] * -1 + # dof_pos[:, 2] = dof_pos[:, 2] * -1 + # dof_pos = dof_pos.reshape(1, len(left_to_right_index) * 3) + ########################################################################### + root_states = torch.cat([root_pos, root_rot, root_vel, root_ang_vel], dim=-1).repeat(num_envs, 1) + # gym.set_actor_root_state_tensor(sim, gymtorch.unwrap_tensor(root_states)) + gym.set_actor_root_state_tensor_indexed(sim, gymtorch.unwrap_tensor(root_states), gymtorch.unwrap_tensor(env_ids), len(env_ids)) + + gym.refresh_actor_root_state_tensor(sim) + + # dof_pos = dof_pos.cpu().numpy() + # dof_states['pos'] = dof_pos + # speed = speeds[current_dof] + + dof_state = torch.stack([dof_pos, torch.zeros_like(dof_pos)], dim=-1).squeeze().repeat(num_envs, 1) + gym.set_dof_state_tensor_indexed(sim, gymtorch.unwrap_tensor(dof_state), gymtorch.unwrap_tensor(env_ids), len(env_ids)) + + gym.simulate(sim) + gym.refresh_rigid_body_state_tensor(sim) + gym.fetch_results(sim, True) + + # print((rigidbody_state[None, ] - rigidbody_state[:, None]).sum().abs()) + # print((actor_root_state[None, ] - actor_root_state[:, None]).sum().abs()) + + # pose_quat = motion_lib._motion_data['0-ACCAD_Female1Running_c3d_C5 - walk to run_poses']['pose_quat_global'] + # diff = quat_mul(quat_inverse(rb_rot[0, :]), rigidbody_state[0, :, 3:7]); np.set_printoptions(precision=4, suppress=1); print(diff.cpu().numpy()); print(torch_utils.quat_to_angle_axis(diff)[0]) + + # update the viewer + gym.step_graphics(sim) + gym.draw_viewer(viewer, sim, True) + + # Wait for dt to elapse in real time. + # This synchronizes the physics simulation with the rendering rate. + gym.sync_frame_time(sim) + # time_step += 1/5 + time_step += dt + + for evt in gym.query_viewer_action_events(viewer): + if evt.action == "previous" and evt.value > 0: + motion_id = (motion_id - 1) % num_motions + print(f"Motion ID: {motion_id}. Motion length: {motion_len:.3f}. Motion Name: {motion_keys[motion_id]}") + elif evt.action == "next" and evt.value > 0: + motion_id = (motion_id + 1) % num_motions + print(f"Motion ID: {motion_id}. Motion length: {motion_len:.3f}. Motion Name: {motion_keys[motion_id]}") + elif evt.action == "add" and evt.value > 0: + motion_acc.add(motion_keys[motion_id]) + print(f"Adding motion {motion_keys[motion_id]}") + elif evt.action == "print" and evt.value > 0: + print(motion_acc) + elif evt.action == "next_batch" and evt.value > 0: + curr_start += num_motions + motion_lib.load_motions(skeleton_trees=[sk_tree] * num_motions, gender_betas=[torch.zeros(17)] * num_motions, limb_weights=[np.zeros(10)] * num_motions, random_sample=False, start_idx=curr_start) + motion_keys = motion_lib.curr_motion_keys + print(f"Next batch {curr_start}") + + time_step = 0 +print("Done") + +gym.destroy_viewer(viewer) +gym.destroy_sim(sim) diff --git a/scripts/vis/vis_smpl_o3d.py b/scripts/vis/vis_smpl_o3d.py new file mode 100644 index 0000000..9e8b2fc --- /dev/null +++ b/scripts/vis/vis_smpl_o3d.py @@ -0,0 +1,262 @@ +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +import open3d as o3d +import open3d.visualization.rendering as rendering +import imageio +from tqdm import tqdm +import joblib +import numpy as np +import torch + +from smpl_sim.smpllib.smpl_parser import ( + SMPL_Parser, + SMPLH_Parser, + SMPLX_Parser, +) +import random + +from smpl_sim.smpllib.smpl_mujoco import SMPL_BONE_ORDER_NAMES as joint_names +from poselib.poselib.skeleton.skeleton3d import SkeletonTree, SkeletonMotion, SkeletonState +from scipy.spatial.transform import Rotation as sRot +import matplotlib.pyplot as plt +from tqdm import tqdm +import math + +paused, reset, recording, image_list, writer, control, curr_zoom = False, False, False, [], None, None, 0.01 + + +def pause_func(action): + global paused + paused = not paused + print(f"Paused: {paused}") + return True + + +def reset_func(action): + global reset + reset = not reset + print(f"Reset: {reset}") + return True + + +def record_func(action): + global recording, writer + if not recording: + fps = 30 + curr_video_file_name = "test.mp4" + writer = imageio.get_writer(curr_video_file_name, fps=fps, macro_block_size=None) + elif not writer is None: + writer.close() + writer = None + + recording = not recording + + print(f"Recording: {recording}") + return True + + +def zoom_func(action): + global control, curr_zoom + curr_zoom = curr_zoom * 0.9 + control.set_zoom(curr_zoom) + print(f"Reset: {reset}") + return True + + + + +Name = "getting_started" +Title = "Getting Started" + +data_dir = "data/smpl" +smpl_parser_n = SMPL_Parser(model_path=data_dir, gender="neutral") +smpl_parser_m = SMPL_Parser(model_path=data_dir, gender="male") +smpl_parser_f = SMPL_Parser(model_path=data_dir, gender="female") + +# pkl_dir = "output/renderings/smpl_ego_long_8-2023-01-20-11:28:00.pkl" +pkl_dir = "output/renderings/smpl_ego_7-2023-07-13-16:15:23.pkl" +# pkl_dir = "output/renderings/smpl_im_comp_pnn_1_1_demo-2023-03-13-11:36:14.pkl" +Name = pkl_dir.split("/")[-1].split(".")[0] +pkl_data = joblib.load(pkl_dir) +mujoco_joint_names = ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] +mujoco_2_smpl = [mujoco_joint_names.index(q) for q in joint_names if q in mujoco_joint_names] + +# data_file = "data/quest/home1_isaac.pkl" +# sk_tree = SkeletonTree.from_mjcf(f"/tmp/smpl/test_good.xml") +# motion_lib = MotionLibSMPLTest("data/quest/home1_isaac.pkl", [7, 3, 22, 17],torch.device("cpu")) +# motion_lib.load_motions(skeleton_trees=[sk_tree], +# gender_betas=[torch.zeros(17)] , +# limb_weights=[np.zeros(10)] , +# random_sample=False) + + +def main(): + global reset, paused, recording, image_list, control + o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Debug) + vis = o3d.visualization.VisualizerWithKeyCallback() + vis.create_window() + + ############ Loading texture ############ + texture_path = "/hdd/zen/data/SURREAL/smpl_data/" + faces_uv = np.load(os.path.join(texture_path, 'final_faces_uv_mapping.npy')) + uv_sampler = torch.from_numpy(faces_uv.reshape(-1, 2, 2, 2)) + uv_sampler = uv_sampler.view(-1, 13776, 2 * 2, 2) + texture_img_path_male = osp.join(texture_path, "textures", "male") + texture_img_path_female = osp.join(texture_path, "textures", "female") + ############ Loading texture ############ + + smpl_meshes = dict() + items = list(pkl_data.items()) + + for entry_key, data_seq in tqdm(items): + gender, beta = data_seq['betas'][0], data_seq['betas'][1:] + if gender == 0: + smpl_parser = smpl_parser_n + texture_image_path = texture_img_path_male + elif gender == 1: + smpl_parser = smpl_parser_m + texture_image_path = texture_img_path_male + else: + smpl_parser = smpl_parser_f + texture_image_path = texture_img_path_female + + pose_quat, trans = data_seq['body_quat'].numpy()[::2], data_seq['trans'].numpy()[::2] + # if pose_quat.shape[0] < 200: + # continue + skeleton_tree = SkeletonTree.from_dict(data_seq['skeleton_tree']) + offset = skeleton_tree.local_translation[0] + root_trans_offset = trans - offset.numpy() + + sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat), torch.from_numpy(trans), is_local=True) + + global_rot = sk_state.global_rotation + B, J, N = global_rot.shape + pose_quat = (sRot.from_quat(global_rot.reshape(-1, 4).numpy()) * sRot.from_quat([0.5, 0.5, 0.5, 0.5])).as_quat().reshape(B, -1, 4) + B_down = pose_quat.shape[0] + new_sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat), torch.from_numpy(trans), is_local=False) + local_rot = new_sk_state.local_rotation + pose_aa = sRot.from_quat(local_rot.reshape(-1, 4).numpy()).as_rotvec().reshape(B_down, -1, 3) + pose_aa = pose_aa[:, mujoco_2_smpl, :].reshape(B_down, -1) + + vertices, joints = smpl_parser.get_joints_verts(pose=torch.from_numpy(pose_aa), th_trans=torch.from_numpy(root_trans_offset), th_betas=torch.from_numpy(beta[None,])) + vertices = vertices.numpy() + faces = smpl_parser.faces + smpl_mesh = o3d.geometry.TriangleMesh() + smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[0]) + smpl_mesh.triangles = o3d.utility.Vector3iVector(faces) + + ######################## Smampling texture ######################## + batch_size = 1 + uv_sampler = uv_sampler.repeat(batch_size, 1, 1, 1) ##torch.Size([B, 13776, 4, 2]) + full_path = "nongrey_male_0237.jpg" + # full_path = random.choice(os.listdir(texture_image_path)) + texture_image = plt.imread(osp.join(texture_image_path, full_path)) + + texture_image = np.transpose(texture_image, (2, 0, 1)) + texture_image = torch.from_numpy(texture_image).float() / 255.0 + textures = torch.nn.functional.grid_sample(texture_image[None,], uv_sampler, align_corners=True) #torch.Size([N, 3, 13776, 4]) + textures = textures.permute(0, 2, 3, 1) #torch.Size([N, 13776, 4, 3]) + # textures = textures.view(-1, 13776, 2, 2, 3) #torch.Size([N, 13776, 2, 2, 3]) + textures = textures.squeeze().numpy() + + vertex_colors = {} + for idx, f in enumerate(faces): + colors = textures[idx] + for vidx, vid in enumerate(f): + vertex_colors[vid] = colors[vidx] + vertex_colors = np.array([vertex_colors[i] for i in range(len(vertex_colors))]) + smpl_mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_colors) + smpl_mesh.compute_triangle_normals() + # smpl_mesh.compute_vertex_normals() + ######################## Smampling texture ######################## + vis.add_geometry(smpl_mesh) + smpl_meshes[entry_key] = { + 'mesh': smpl_mesh, + "vertices": vertices, + } + + box = o3d.geometry.TriangleMesh() + ground_size, height = 10, 0.01 + box = box.create_box(width=ground_size, height=height, depth=ground_size) + box.translate(np.array([-ground_size / 2, -height, -ground_size / 2])) + box.rotate(sRot.from_euler("xyz", [np.pi / 2, 0, 0]).as_matrix()) + box.compute_vertex_normals() + box.vertex_colors = o3d.utility.Vector3dVector(np.array([[1, 1, 1]]).repeat(8, axis=0)) + vis.add_geometry(box) + + # spheres = [] + # for _ in range(3): + # sphere = o3d.geometry.TriangleMesh() + # sphere = sphere.create_sphere(radius=0.1) + # sphere.compute_vertex_normals() + # sphere.vertex_colors = o3d.utility.Vector3dVector(np.array([[0.1, 0.9, 0.1]]).repeat(len(sphere.vertices), axis=0)) + # spheres.append(sphere) + + # sphere_pos = np.zeros([3, 3]) + # [vis.add_geometry(sphere) for sphere in spheres] + + control = vis.get_view_control() + + control.unset_constant_z_far() + control.unset_constant_z_near() + i = 0 + N = vertices.shape[0] + + vis.register_key_callback(32, pause_func) + vis.register_key_callback(82, reset_func) + vis.register_key_callback(76, record_func) + vis.register_key_callback(90, zoom_func) + + control.set_up(np.array([0, 0, 1])) + control.set_front(np.array([1, 0, 0])) + control.set_lookat(vertices[0, 0]) + + control.set_zoom(0.5) + dt = 1 / 30 + + tracker_pos = pkl_data['0_0']['ref_body_pos_subset'][::2].cpu().numpy() + + while True: + vis.poll_events() + for smpl_mesh_key, smpl_mesh_data in smpl_meshes.items(): + verts = smpl_mesh_data["vertices"] + smpl_mesh_data["mesh"].vertices = o3d.utility.Vector3dVector(verts[i % verts.shape[0]]) + vis.update_geometry(smpl_mesh_data["mesh"]) + + # motion_res = motion_lib.get_motion_state(torch.tensor([0]), torch.tensor([(i % verts.shape[0]) * dt])) + # curr_pos = motion_res['rg_pos'][0, [13, 18 ,23]].numpy() + curr_pos = tracker_pos[i % verts.shape[0]] + + # for idx, s in enumerate(spheres): + # s.translate((curr_pos - sphere_pos)[idx]) + # vis.update_geometry(s) + # sphere_pos = curr_pos + # sphere.translate(verts[0, 0]) + + # vis.update_geometry(sphere) + + if not paused: + i = (i + 1) + + if reset: + i = 0 + reset = False + if recording: + rgb = vis.capture_screen_float_buffer() + rgb = (np.asarray(rgb) * 255).astype(np.uint8) + w, h, _ = rgb.shape + w, h = math.floor(w / 2.) * 2, math.floor(h / 2.) * 2 + rgb = rgb[:w, :h, :] + writer.append_data(rgb) + + vis.update_renderer() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/vis/vis_smpl_o3d_ego.py b/scripts/vis/vis_smpl_o3d_ego.py new file mode 100644 index 0000000..753d19c --- /dev/null +++ b/scripts/vis/vis_smpl_o3d_ego.py @@ -0,0 +1,260 @@ +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) +import open3d as o3d +import open3d.visualization.rendering as rendering +import imageio +from tqdm import tqdm +import joblib +import numpy as np +import torch + +from smpl_sim.smpllib.smpl_parser import ( + SMPL_Parser, + SMPLH_Parser, + SMPLX_Parser, +) +import random + +from smpl_sim.smpllib.smpl_mujoco import SMPL_BONE_ORDER_NAMES as joint_names +from poselib.poselib.skeleton.skeleton3d import SkeletonTree, SkeletonMotion, SkeletonState +from scipy.spatial.transform import Rotation as sRot +import matplotlib.pyplot as plt +from tqdm import tqdm +import matplotlib as mpl +from datetime import datetime +paused, reset, recording, image_list, writer, control, curr_zoom = False, False, False, [], None, None, 0.01 + + +def pause_func(action): + global paused + paused = not paused + print(f"Paused: {paused}") + return True + + +def reset_func(action): + global reset + reset = not reset + print(f"Reset: {reset}") + return True + + +def record_func(action): + global recording, writer + if not recording: + fps = 30 + curr_date_time = datetime.now().strftime('%Y-%m-%d-%H:%M:%S') + curr_video_file_name = f"output/renderings/o3d/{curr_date_time}-test.mp4" + writer = imageio.get_writer(curr_video_file_name, fps=fps, macro_block_size=None) + elif not writer is None: + writer.close() + writer = None + + recording = not recording + + print(f"Recording: {recording}") + return True + + +def zoom_func(action): + global control, curr_zoom + curr_zoom = curr_zoom * 0.9 + control.set_zoom(curr_zoom) + print(f"Reset: {reset}") + return True + +colorpicker = mpl.colormaps['Blues'] +mujoco_joint_names = ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + +Name = "getting_started" +Title = "Getting Started" + +data_dir = "data/smpl" +smpl_parser_n = SMPL_Parser(model_path=data_dir, gender="neutral") +smpl_parser_m = SMPL_Parser(model_path=data_dir, gender="male") +smpl_parser_f = SMPL_Parser(model_path=data_dir, gender="female") + +pkl_dir = "output/renderings/smpl_im_comp_pnn_1_1_demo-2023-03-13-14:28:47.pkl" +Name = pkl_dir.split("/")[-1].split(".")[0] +pkl_data = joblib.load(pkl_dir) +mujoco_2_smpl = [mujoco_joint_names.index(q) for q in joint_names if q in mujoco_joint_names] + +# data_file = "data/quest/home1_isaac.pkl" +# sk_tree = SkeletonTree.from_mjcf(f"/tmp/smpl/test_good.xml") +# motion_lib.load_motions(skeleton_trees=[sk_tree], +# gender_betas=[torch.zeros(17)] , +# limb_weights=[np.zeros(10)] , +# random_sample=False) + + +def main(): + global reset, paused, recording, image_list, control + o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Debug) + vis = o3d.visualization.VisualizerWithKeyCallback() + vis.create_window() + + ############ Loading texture ############ + texture_path = "/hdd/zen/data/SURREAL/smpl_data/" + faces_uv = np.load(os.path.join(texture_path, 'final_faces_uv_mapping.npy')) + uv_sampler = torch.from_numpy(faces_uv.reshape(-1, 2, 2, 2)) + uv_sampler = uv_sampler.view(-1, 13776, 2 * 2, 2) + texture_img_path_male = osp.join(texture_path, "textures", "male") + texture_img_path_female = osp.join(texture_path, "textures", "female") + ############ Loading texture ############ + + smpl_meshes = dict() + items = list(pkl_data.items()) + + for entry_key, data_seq in tqdm(items): + gender, beta = data_seq['betas'][0], data_seq['betas'][1:] + if gender == 0: + smpl_parser = smpl_parser_n + texture_image_path = texture_img_path_male + elif gender == 1: + smpl_parser = smpl_parser_m + texture_image_path = texture_img_path_male + else: + smpl_parser = smpl_parser_f + texture_image_path = texture_img_path_female + + pose_quat, trans = data_seq['body_quat'].numpy()[::2], data_seq['trans'].numpy()[::2] + if pose_quat.shape[0] < 200: + continue + skeleton_tree = SkeletonTree.from_dict(data_seq['skeleton_tree']) + offset = skeleton_tree.local_translation[0] + root_trans_offset = trans - offset.numpy() + + sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat), torch.from_numpy(trans), is_local=True) + + global_rot = sk_state.global_rotation + B, J, N = global_rot.shape + pose_quat = (sRot.from_quat(global_rot.reshape(-1, 4).numpy()) * sRot.from_quat([0.5, 0.5, 0.5, 0.5])).as_quat().reshape(B, -1, 4) + B_down = pose_quat.shape[0] + new_sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat), torch.from_numpy(trans), is_local=False) + local_rot = new_sk_state.local_rotation + pose_aa = sRot.from_quat(local_rot.reshape(-1, 4).numpy()).as_rotvec().reshape(B_down, -1, 3) + pose_aa = pose_aa[:, mujoco_2_smpl, :].reshape(B_down, -1) + + vertices, joints = smpl_parser.get_joints_verts(pose=torch.from_numpy(pose_aa), th_trans=torch.from_numpy(root_trans_offset), th_betas=torch.from_numpy(beta[None,])) + vertices = vertices.numpy() + faces = smpl_parser.faces + smpl_mesh = o3d.geometry.TriangleMesh() + smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[0]) + smpl_mesh.triangles = o3d.utility.Vector3iVector(faces) + + ######################## Smampling texture ######################## + batch_size = 1 + # uv_sampler = uv_sampler.repeat(batch_size, 1, 1, 1) ##torch.Size([B, 13776, 4, 2]) + # full_path = "nongrey_male_0237.jpg" + # # full_path = random.choice(os.listdir(texture_image_path)) + # texture_image = plt.imread(osp.join(texture_image_path, full_path)) + + # texture_image = np.transpose(texture_image, (2, 0, 1)) + # texture_image = torch.from_numpy(texture_image).float() / 255.0 + # textures = torch.nn.functional.grid_sample(texture_image[None,], uv_sampler, align_corners=True) #torch.Size([N, 3, 13776, 4]) + # textures = textures.permute(0, 2, 3, 1) #torch.Size([N, 13776, 4, 3]) + # # textures = textures.view(-1, 13776, 2, 2, 3) #torch.Size([N, 13776, 2, 2, 3]) + # textures = textures.squeeze().numpy() + + # vertex_colors = {} + # for idx, f in enumerate(faces): + # colors = textures[idx] + # for vidx, vid in enumerate(f): + # vertex_colors[vid] = colors[vidx] + # vertex_colors = np.array([vertex_colors[i] for i in range(len(vertex_colors))]) + # smpl_mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_colors) + + vertex_colors = colorpicker(0.6)[:3] + smpl_mesh.paint_uniform_color(vertex_colors) + + smpl_mesh.compute_vertex_normals() + ######################## Smampling texture ######################## + vis.add_geometry(smpl_mesh) + smpl_meshes[entry_key] = { + 'mesh': smpl_mesh, + "vertices": vertices, + } + break + + box = o3d.geometry.TriangleMesh() + ground_size, height = 50, 0.01 + box = box.create_box(width=ground_size, height=height, depth=ground_size) + box.translate(np.array([-ground_size / 2, -height, -ground_size / 2])) + box.rotate(sRot.from_euler("xyz", [np.pi / 2, 0, 0]).as_matrix()) + box.compute_vertex_normals() + box.vertex_colors = o3d.utility.Vector3dVector(np.array([[1, 1, 1]]).repeat(8, axis=0)) + + spheres = [] + for _ in range(24): + sphere = o3d.geometry.TriangleMesh() + sphere = sphere.create_sphere(radius=0.05) + sphere.compute_vertex_normals() + sphere.vertex_colors = o3d.utility.Vector3dVector(np.array([[0.1, 0.9, 0.1]]).repeat(len(sphere.vertices), axis=0)) + spheres.append(sphere) + + sphere_pos = np.zeros([24, 3]) + [vis.add_geometry(sphere) for sphere in spheres] + vis.add_geometry(box) + + control = vis.get_view_control() + + control.unset_constant_z_far() + control.unset_constant_z_near() + i = 0 + N = vertices.shape[0] + + vis.register_key_callback(32, pause_func) + vis.register_key_callback(82, reset_func) + vis.register_key_callback(76, record_func) + vis.register_key_callback(90, zoom_func) + + control.set_up(np.array([0, 0, 1])) + control.set_front(np.array([1, 0, 0])) + control.set_lookat(vertices[0, 0]) + + control.set_zoom(0.5) + dt = 1 / 30 + + to_isaac_mat = sRot.from_euler('xyz', np.array([-np.pi / 2, 0, 0]), degrees=False).as_matrix() + tracker_pos = pkl_data['0_0']['ref_body_pos_subset'][::2].cpu().numpy() + tracker_pos = np.matmul(tracker_pos, to_isaac_mat.T) + + while True: + vis.poll_events() + for smpl_mesh_key, smpl_mesh_data in smpl_meshes.items(): + verts = smpl_mesh_data["vertices"] + smpl_mesh_data["mesh"].vertices = o3d.utility.Vector3dVector(verts[i % verts.shape[0]]) + vis.update_geometry(smpl_mesh_data["mesh"]) + + # motion_res = motion_lib.get_motion_state(torch.tensor([0]), torch.tensor([(i % verts.shape[0]) * dt])) + # curr_pos = motion_res['rg_pos'][0, [13, 18 ,23]].numpy() + curr_pos = tracker_pos[i % verts.shape[0]] + + for idx, s in enumerate(spheres): + s.translate((curr_pos - sphere_pos)[idx]) + vis.update_geometry(s) + sphere_pos = curr_pos + # sphere.translate(verts[0, 0]) + # vis.update_geometry(sphere) + + if not paused: + i = (i + 1) + + if reset: + i = 0 + reset = False + if recording: + rgb = vis.capture_screen_float_buffer() + rgb = (np.asarray(rgb) * 255).astype(np.uint8) + writer.append_data(rgb) + + vis.update_renderer() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/vis/vis_smpl_o3d_multi.py b/scripts/vis/vis_smpl_o3d_multi.py new file mode 100644 index 0000000..0d8457a --- /dev/null +++ b/scripts/vis/vis_smpl_o3d_multi.py @@ -0,0 +1,279 @@ +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +import open3d as o3d +import open3d.visualization.rendering as rendering +import imageio +from tqdm import tqdm +import joblib +import numpy as np +import torch + +from smpl_sim.smpllib.smpl_parser import ( + SMPL_Parser, + SMPLH_Parser, + SMPLX_Parser, +) +import random + +from smpl_sim.smpllib.smpl_mujoco import SMPL_BONE_ORDER_NAMES as joint_names +from poselib.poselib.skeleton.skeleton3d import SkeletonTree, SkeletonMotion, SkeletonState +from scipy.spatial.transform import Rotation as sRot +import matplotlib.pyplot as plt +from tqdm import tqdm +import cv2 +import matplotlib as mpl + +paused, reset, recording, image_list, writer, control, curr_zoom = False, False, False, [], None, None, 0.01 + + +def pause_func(action): + global paused + paused = not paused + print(f"Paused: {paused}") + return True + + +def reset_func(action): + global reset + reset = not reset + print(f"Reset: {reset}") + return True + + +def record_func(action): + global recording, writer + if not recording: + fps = 30 + curr_video_file_name = "test.mp4" + writer = imageio.get_writer(curr_video_file_name, fps=fps, macro_block_size=None) + elif not writer is None: + writer.close() + writer = None + + recording = not recording + + print(f"Recording: {recording}") + return True + + +def capture_func(action): + global capture + + capture = not capture + + return True + + +def zoom_func(action): + global control, curr_zoom + curr_zoom = curr_zoom * 0.9 + control.set_zoom(curr_zoom) + print(f"Reset: {reset}") + return True + + +mujoco_joint_names = ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + +Name = "getting_started" +Title = "Getting Started" + +data_dir = "data/smpl" +smpl_parser_n = SMPL_Parser(model_path=data_dir, gender="neutral") +smpl_parser_m = SMPL_Parser(model_path=data_dir, gender="male") +smpl_parser_f = SMPL_Parser(model_path=data_dir, gender="female") + +# pkl_dir = "output/renderings/smpl_ego_long_8-2023-01-20-11:28:00.pkl" +# pkl_dir = "output/renderings/smpl_im_comp_8-2023-02-05-15:36:14.pkl" +pkl_dir = "output/renderings/smpl_im_comp_pnn_3-2023-03-07-14:31:50.pkl" +Name = pkl_dir.split("/")[-1].split(".")[0] +pkl_data = joblib.load(pkl_dir) +mujoco_2_smpl = [mujoco_joint_names.index(q) for q in joint_names if q in mujoco_joint_names] + +# data_file = "data/quest/home1_isaac.pkl" +# sk_tree = SkeletonTree.from_mjcf(f"/tmp/smpl/test_good.xml") +# motion_lib = MotionLibSMPLTest("data/quest/home1_isaac.pkl", [7, 3, 22, 17],torch.device("cpu")) +# motion_lib.load_motions(skeleton_trees=[sk_tree], +# gender_betas=[torch.zeros(17)] , +# limb_weights=[np.zeros(10)] , +# random_sample=False) + + +def main(): + global reset, paused, recording, image_list, control, capture + capture = False + o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Debug) + vis = o3d.visualization.VisualizerWithKeyCallback() + vis.create_window() + + opt = vis.get_render_option() + vis.get_render_option().mesh_shade_option = o3d.visualization.MeshShadeOption.Color + + # opt.background_color = [0.5, 0.5, 0.5] # gray color + + smpl_meshes = dict() + items = list(pkl_data.items()) + color_picker = [mpl.colormaps['YlOrBr'], mpl.colormaps['Reds']] + idx = 0 + print(len(items)) + for entry_key, data_seq in tqdm(items): + + gender, beta = data_seq['betas'][0], data_seq['betas'][1:] + if gender == 0: + smpl_parser = smpl_parser_n + elif gender == 1: + smpl_parser = smpl_parser_m + else: + smpl_parser = smpl_parser_f + + pose_quat, trans = data_seq['body_quat'].numpy()[::2], data_seq['trans'].numpy()[::2] + skeleton_tree = SkeletonTree.from_dict(data_seq['skeleton_tree']) + offset = skeleton_tree.local_translation[0] + root_trans_offset = trans - offset.numpy() + + sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat), torch.from_numpy(trans), is_local=True) + + global_rot = sk_state.global_rotation + B, J, N = global_rot.shape + pose_quat = (sRot.from_quat(global_rot.reshape(-1, 4).numpy()) * sRot.from_quat([0.5, 0.5, 0.5, 0.5])).as_quat().reshape(B, -1, 4) + B_down = pose_quat.shape[0] + new_sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat), torch.from_numpy(trans), is_local=False) + local_rot = new_sk_state.local_rotation + pose_aa = sRot.from_quat(local_rot.reshape(-1, 4).numpy()).as_rotvec().reshape(B_down, -1, 3) + pose_aa = pose_aa[:, mujoco_2_smpl, :].reshape(B_down, -1) + with torch.no_grad(): + vertices, joints = smpl_parser.get_joints_verts(pose=torch.from_numpy(pose_aa), th_trans=torch.from_numpy(root_trans_offset), th_betas=torch.from_numpy(beta[None,])) + + vertices = vertices.numpy() + faces = smpl_parser.faces + max_frames = vertices.shape[0] + + for i in tqdm(range(0, max_frames - 270, 15)): + smpl_mesh = o3d.geometry.TriangleMesh() + smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[i]) + smpl_mesh.triangles = o3d.utility.Vector3iVector(faces) + # vertex_colors = np.array([35, 102, 218]) / 256 * (1 - i / vertices.shape[0]) + # vertex_colors = color_picker[idx] * ((i + 60) / max_frames) + vertex_colors = color_picker[idx](1 - ((i / max_frames) * 0.5 + 0.2))[:3] + + smpl_mesh.paint_uniform_color(vertex_colors) + smpl_mesh.compute_vertex_normals() + vis.add_geometry(smpl_mesh) + + for i in tqdm(range(max_frames - 280, max_frames - 230, 15)): + smpl_mesh = o3d.geometry.TriangleMesh() + smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[i]) + smpl_mesh.triangles = o3d.utility.Vector3iVector(faces) + # vertex_colors = np.array([35, 102, 218]) / 256 * (1 - i / vertices.shape[0]) + # vertex_colors = color_picker[idx] * ((i + 60) / max_frames) + vertex_colors = color_picker[idx](1 - ((i / max_frames) * 0.5 + 0.2))[:3] + + smpl_mesh.paint_uniform_color(vertex_colors) + smpl_mesh.compute_vertex_normals() + vis.add_geometry(smpl_mesh) + + for i in tqdm(range(max_frames - 230, max_frames - 200, 8)): + smpl_mesh = o3d.geometry.TriangleMesh() + smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[i]) + smpl_mesh.triangles = o3d.utility.Vector3iVector(faces) + # vertex_colors = np.array([35, 102, 218]) / 256 * (1 - i / vertices.shape[0]) + # vertex_colors = color_picker[idx] * ((i + 60) / max_frames) + vertex_colors = color_picker[idx](1 - ((i / max_frames) * 0.5 + 0.2))[:3] + + smpl_mesh.paint_uniform_color(vertex_colors) + smpl_mesh.compute_vertex_normals() + vis.add_geometry(smpl_mesh) + + for i in tqdm(range(max_frames - 200, max_frames - 30, 15)): + smpl_mesh = o3d.geometry.TriangleMesh() + smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[i]) + smpl_mesh.triangles = o3d.utility.Vector3iVector(faces) + # vertex_colors = np.array([35, 102, 218]) / 256 * (1 - i / vertices.shape[0]) + # vertex_colors = color_picker[idx] * ((i + 60) / max_frames) + vertex_colors = color_picker[idx](1 - ((i / max_frames) * 0.5 + 0.2))[:3] + + smpl_mesh.paint_uniform_color(vertex_colors) + smpl_mesh.compute_vertex_normals() + vis.add_geometry(smpl_mesh) + + for i in tqdm(range(max_frames - 30, max_frames, 15)): + smpl_mesh = o3d.geometry.TriangleMesh() + smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[i]) + smpl_mesh.triangles = o3d.utility.Vector3iVector(faces) + # vertex_colors = np.array([35, 102, 218]) / 256 * (1 - i / vertices.shape[0]) + vertex_colors = color_picker[idx](1 - ((i / max_frames) * 0.5 + 0.2))[:3] + + smpl_mesh.paint_uniform_color(vertex_colors) + smpl_mesh.compute_vertex_normals() + vis.add_geometry(smpl_mesh) + + smpl_meshes[entry_key] = { + 'mesh': smpl_mesh, + "vertices": vertices, + } + idx += 1 + + box = o3d.geometry.TriangleMesh() + ground_size, height = 50, 0.01 + box = box.create_box(width=ground_size, height=height, depth=ground_size) + box.translate(np.array([-ground_size / 2, -height, -ground_size / 2])) + box.rotate(sRot.from_euler("xyz", [np.pi / 2, 0, 0]).as_matrix()) + box.compute_vertex_normals() + box.compute_triangle_normals() + box.vertex_colors = o3d.utility.Vector3dVector(np.array([[1, 1, 1]]).repeat(8, axis=0)) + # box.paint_uniform_color(vertex_colors) + vis.add_geometry(box) + + control = vis.get_view_control() + + control.unset_constant_z_far() + control.unset_constant_z_near() + i = 0 + + vis.register_key_callback(32, pause_func) + vis.register_key_callback(82, reset_func) + vis.register_key_callback(76, record_func) + vis.register_key_callback(67, capture_func) + vis.register_key_callback(90, zoom_func) + + control.set_up(np.array([0, 0, 1])) + control.set_front(np.array([-5, 0, 1])) + control.set_lookat(np.array([0, 0, 1])) + + control.set_zoom(0.1) + dt = 1 / 30 + + tracker_pos = pkl_data['0_0']['ref_body_pos_subset'][::2].cpu().numpy() + + while True: + vis.poll_events() + + # if not paused: + # i = (i + 1) + + if reset: + i = 0 + reset = False + if recording: + rgb = vis.capture_screen_float_buffer() + rgb = (np.asarray(rgb) * 255).astype(np.uint8) + writer.append_data(rgb) + if capture: + rgb = vis.capture_screen_float_buffer() + rgb = (np.asarray(rgb) * 255).astype(np.uint8) + name = input("Enter image name:") + img_name = f"output/renderings/iccv2023/{name}.png" + print("Captruing image to {}".format(img_name)) + cv2.imwrite(img_name, np.asarray(rgb)[..., ::-1]) + capture = False + + vis.update_renderer() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/vis/vis_smpl_o3d_single.py b/scripts/vis/vis_smpl_o3d_single.py new file mode 100644 index 0000000..976d961 --- /dev/null +++ b/scripts/vis/vis_smpl_o3d_single.py @@ -0,0 +1,262 @@ +import glob +import os +import sys +import pdb +import os.path as osp + +sys.path.append(os.getcwd()) + +import open3d as o3d +import open3d.visualization.rendering as rendering +import imageio +from tqdm import tqdm +import joblib +import numpy as np +import torch + +from smpl_sim.smpllib.smpl_parser import ( + SMPL_Parser, + SMPLH_Parser, + SMPLX_Parser, +) +import random + +from smpl_sim.smpllib.smpl_mujoco import SMPL_BONE_ORDER_NAMES as joint_names +from poselib.poselib.skeleton.skeleton3d import SkeletonTree, SkeletonMotion, SkeletonState +from scipy.spatial.transform import Rotation as sRot +import matplotlib.pyplot as plt +from tqdm import tqdm +import cv2 +import matplotlib as mpl +from datetime import datetime + +colorpicker = mpl.colormaps['Blues'] + +paused, reset, recording, image_list, writer, control, curr_zoom = False, False, False, [], None, None, 0.01 + + +def pause_func(action): + global paused + paused = not paused + print(f"Paused: {paused}") + return True + + +def reset_func(action): + global reset + reset = not reset + print(f"Reset: {reset}") + return True + + +def record_func(action): + global recording, writer + if not recording: + fps = 30 + curr_date_time = datetime.now().strftime('%Y-%m-%d-%H:%M:%S') + curr_video_file_name = f"output/renderings/o3d/{curr_date_time}-test.mp4" + print(f"==================== writing to videl {curr_video_file_name} ====================") + writer = imageio.get_writer(curr_video_file_name, fps=fps, macro_block_size=None) + elif not writer is None: + writer.close() + writer = None + + recording = not recording + + print(f"Recording: {recording}") + return True + + +def capture_func(action): + global capture + + capture = not capture + + return True + + +def zoom_func(action): + global control, curr_zoom + curr_zoom = curr_zoom * 0.9 + control.set_zoom(curr_zoom) + print(f"Reset: {reset}") + return True + + +mujoco_joint_names = ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe', 'Torso', 'Spine', 'Chest', 'Neck', 'Head', 'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] + +Name = "getting_started" +Title = "Getting Started" + +data_dir = "data/smpl" +smpl_parser_n = SMPL_Parser(model_path=data_dir, gender="neutral") +smpl_parser_m = SMPL_Parser(model_path=data_dir, gender="male") +smpl_parser_f = SMPL_Parser(model_path=data_dir, gender="female") + +# pkl_dir = "output/renderings/smpl_ego_long_8-2023-01-20-11:28:00.pkl" +# pkl_dir = "output/renderings/smpl_im_comp_8-2023-02-05-15:36:14.pkl" +# pkl_dir = "output/renderings/smpl_im_comp_pnn_1_1_demo-2023-03-12-18:57:01.pkl" +pkl_dir = "output/renderings/smpl_im_comp_pnn_1_1_demo-2023-03-14-14:40:46.pkl" +Name = pkl_dir.split("/")[-1].split(".")[0] +pkl_data = joblib.load(pkl_dir) +mujoco_2_smpl = [mujoco_joint_names.index(q) for q in joint_names if q in mujoco_joint_names] + +# data_file = "data/quest/home1_isaac.pkl" +# sk_tree = SkeletonTree.from_mjcf(f"/tmp/smpl/test_good.xml") +# motion_lib = MotionLibSMPLTest("data/quest/home1_isaac.pkl", [7, 3, 22, 17],torch.device("cpu")) +# motion_lib.load_motions(skeleton_trees=[sk_tree], +# gender_betas=[torch.zeros(17)] , +# limb_weights=[np.zeros(10)] , +# random_sample=False) + + +def main(): + global reset, paused, recording, image_list, control, capture + capture = False + o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Debug) + vis = o3d.visualization.VisualizerWithKeyCallback() + vis.create_window() + + opt = vis.get_render_option() + # vis.get_render_option().mesh_shade_option = o3d.visualization.MeshShadeOption.Color + + opt.background_color = [1, 1, 1] + + smpl_meshes = dict() + items = list(pkl_data.items()) + idx = 0 + print(len(items)) + vertices_acc = [] + + for entry_key, data_seq in tqdm(items): + + gender, beta = data_seq['betas'][0], data_seq['betas'][1:] + if gender == 0: + smpl_parser = smpl_parser_n + elif gender == 1: + smpl_parser = smpl_parser_m + else: + smpl_parser = smpl_parser_f + + pose_quat, trans = data_seq['body_quat'].numpy()[::2], data_seq['trans'].numpy()[::2] + skeleton_tree = SkeletonTree.from_dict(data_seq['skeleton_tree']) + offset = skeleton_tree.local_translation[0] + root_trans_offset = trans - offset.numpy() + + sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat), torch.from_numpy(trans), is_local=True) + + global_rot = sk_state.global_rotation + B, J, N = global_rot.shape + pose_quat = (sRot.from_quat(global_rot.reshape(-1, 4).numpy()) * sRot.from_quat([0.5, 0.5, 0.5, 0.5])).as_quat().reshape(B, -1, 4) + B_down = pose_quat.shape[0] + new_sk_state = SkeletonState.from_rotation_and_root_translation(skeleton_tree, torch.from_numpy(pose_quat), torch.from_numpy(trans), is_local=False) + local_rot = new_sk_state.local_rotation + pose_aa = sRot.from_quat(local_rot.reshape(-1, 4).numpy()).as_rotvec().reshape(B_down, -1, 3) + pose_aa = pose_aa[:, mujoco_2_smpl, :].reshape(B_down, -1) + with torch.no_grad(): + vertices, joints = smpl_parser.get_joints_verts(pose=torch.from_numpy(pose_aa), th_trans=torch.from_numpy(root_trans_offset), th_betas=torch.from_numpy(beta[None,])) + + vertices = vertices.numpy() + faces = smpl_parser.faces + + smpl_mesh = o3d.geometry.TriangleMesh() + smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[0]) + smpl_mesh.triangles = o3d.utility.Vector3iVector(faces) + vertex_colors = colorpicker(0.6 - idx * 0.3)[:3] + smpl_mesh.paint_uniform_color(vertex_colors) + # smpl_mesh.compute_triangle_normals() + smpl_mesh.compute_vertex_normals() + ######################## Smampling texture ######################## + vis.add_geometry(smpl_mesh) + smpl_meshes[entry_key] = { + 'mesh': smpl_mesh, + "vertices": vertices, + } + idx += 1 + # vertices_acc.append(vertices) + + # faces = smpl_parser.faces + # vertices = np.concatenate(vertices_acc) + # max_frames = vertices.shape[0] + + # smpl_mesh = o3d.geometry.TriangleMesh() + # smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices[0]) + # smpl_mesh.triangles = o3d.utility.Vector3iVector(faces) + # # vertex_colors = np.array([35, 102, 218]) / 256 * (1 - i / vertices.shape[0]) + # # vertex_colors = color_picker[idx] * ((i + 60) / max_frames) + # vertex_colors = colorpicker(0.6)[:3] + + # smpl_mesh.paint_uniform_color(vertex_colors) + # smpl_mesh.compute_vertex_normals() + # vis.add_geometry(smpl_mesh) + + # smpl_meshes[entry_key] = { + # 'mesh': smpl_mesh, + # "vertices": vertices, + # } + + + box = o3d.geometry.TriangleMesh() + ground_size, height = 50, 0.01 + box = box.create_box(width=ground_size, height=height, depth=ground_size) + box.translate(np.array([-ground_size / 2, -height, -ground_size / 2])) + box.rotate(sRot.from_euler("xyz", [np.pi / 2, 0, 0]).as_matrix()) + box.compute_vertex_normals() + # box.compute_triangle_normals() + box.vertex_colors = o3d.utility.Vector3dVector(np.array([[1, 1, 1]]).repeat(8, axis=0)) + # box.paint_uniform_color(vertex_colors) + vis.add_geometry(box) + + control = vis.get_view_control() + + control.unset_constant_z_far() + control.unset_constant_z_near() + i = 0 + + vis.register_key_callback(32, pause_func) + vis.register_key_callback(82, reset_func) + vis.register_key_callback(76, record_func) + vis.register_key_callback(67, capture_func) + vis.register_key_callback(90, zoom_func) + + control.set_up(np.array([0, 0, 1])) + control.set_front(np.array([0, 5, 1])) + control.set_lookat(np.array([0, 0, 1])) + + control.set_zoom(1) + dt = 1 / 30 + + tracker_pos = pkl_data['0_0']['ref_body_pos_subset'][::2].cpu().numpy() + + while True: + vis.poll_events() + for smpl_mesh_key, smpl_mesh_data in smpl_meshes.items(): + verts = smpl_mesh_data["vertices"] + smpl_mesh_data["mesh"].vertices = o3d.utility.Vector3dVector(verts[i % verts.shape[0]]) + vis.update_geometry(smpl_mesh_data["mesh"]) + + + if not paused: + i = (i + 1) + + if reset: + i = 0 + reset = False + if recording: + rgb = vis.capture_screen_float_buffer() + rgb = (np.asarray(rgb) * 255).astype(np.uint8) + writer.append_data(rgb) + if capture: + rgb = vis.capture_screen_float_buffer() + rgb = (np.asarray(rgb) * 255).astype(np.uint8) + name = input("Enter image name:") + img_name = f"output/renderings/iccv2023/{name}.png" + print("Captruing image to {}".format(img_name)) + cv2.imwrite(img_name, np.asarray(rgb)[..., ::-1]) + capture = False + + vis.update_renderer() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/ws_client.py b/scripts/ws_client.py new file mode 100644 index 0000000..fe865a6 --- /dev/null +++ b/scripts/ws_client.py @@ -0,0 +1,69 @@ +import asyncio +import os + +import aiohttp +import json +import numpy as np + +import subprocess + +HOST = os.getenv('HOST', '172.29.229.220') +# HOST = os.getenv('HOST', '0.0.0.0') +# HOST = os.getenv('HOST', 'KLAB-BUTTER.PC.CS.CMU.EDU') +PORT = int(os.getenv('PORT', 8080)) + + +async def main(): + session = aiohttp.ClientSession() + URL = f'http://{HOST}:{PORT}/ws_talk' + async with session.ws_connect(URL) as ws: + + await prompt_and_send(ws) + async for msg in ws: + print('Message received from server:', msg.data) + await prompt_and_send(ws) + if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR): + break + + # session = aiohttp.ClientSession() + # URL = f'http://{HOST}:{PORT}/ws' + # import time + # async with session.ws_connect(URL) as ws: + # await ws.send_str("get_pose") + # async for msg in ws: + # t_s = time.time() + # json_data = json.loads(msg.data) + # print(json_data['pose_mat'][0]) + + # await ws.send_str("get_pose") + + # if msg.type in (aiohttp.WSMsgType.CLOSED, + # aiohttp.WSMsgType.ERROR): + # break + + # await asyncio.sleep(1/30) + + # dt = time.time() - t_s + # print(1/dt) + + +async def prompt_and_send(ws): + new_msg_to_send = input('Type a message to send to the server: ') + if new_msg_to_send == 'exit': + print('Exiting!') + raise SystemExit(0) + elif new_msg_to_send == "s": + # subprocess.Popen(["simplescreenrecorder", "--start-recording"]) + pass + elif new_msg_to_send == "e": + pass + + await ws.send_str(new_msg_to_send) + return new_msg_to_send + + +if __name__ == '__main__': + print('Type "exit" to quit') + # loop = asyncio.get_event_loop() + # loop.run_forever(main()) + asyncio.run(main()) \ No newline at end of file