From 93bc09a32a4415d926c5edfc065c1e6609bf4dd5 Mon Sep 17 00:00:00 2001 From: Kelvin Lee Date: Fri, 5 Apr 2024 07:49:50 -0700 Subject: [PATCH] git: readying for public release --- .github/dependabot.yml | 11 + .gitignore | 160 +++ .pre-commit-config.yaml | 16 + CHANGELOG.md | 6 + LICENSE | 202 ++++ README.md | 127 +++ docker/.gitkeep | 0 docker/Dockerfile.intel | 22 + intel-requirements.txt | 5 + ...l harmonics symbolic differentiation.ipynb | 948 +++++++++++++++++ pyproject.toml | 88 ++ scripts/benchmark.py | 134 +++ scripts/dynamic_shapes.py | 98 ++ scripts/measure_numerical_error.py | 178 ++++ scripts/profile_script.py | 109 ++ src/equitriton/__init__.py | 29 + src/equitriton/benchmark.py | 144 +++ src/equitriton/patch.py | 65 ++ src/equitriton/sph_harm/__init__.py | 9 + src/equitriton/sph_harm/bindings.py | 340 +++++++ src/equitriton/sph_harm/main.py | 94 ++ .../sph_harm/tests/test_correctness.py | 111 ++ src/equitriton/sph_harm/tests/test_main.py | 58 ++ src/equitriton/sph_harm/tests/test_utils.py | 16 + src/equitriton/sph_harm/triton_kernels.py | 962 ++++++++++++++++++ src/equitriton/tests/test_benchmark.py | 16 + src/equitriton/utils.py | 45 + 27 files changed, 3993 insertions(+) create mode 100644 .github/dependabot.yml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 CHANGELOG.md create mode 100644 LICENSE create mode 100644 README.md create mode 100644 docker/.gitkeep create mode 100644 docker/Dockerfile.intel create mode 100644 intel-requirements.txt create mode 100644 notebooks/Spherical harmonics symbolic differentiation.ipynb create mode 100644 pyproject.toml create mode 100644 scripts/benchmark.py create mode 100644 scripts/dynamic_shapes.py create mode 100644 scripts/measure_numerical_error.py create mode 100644 scripts/profile_script.py create mode 100644 src/equitriton/__init__.py create mode 100644 src/equitriton/benchmark.py create mode 100644 src/equitriton/patch.py create mode 100644 src/equitriton/sph_harm/__init__.py create mode 100644 src/equitriton/sph_harm/bindings.py create mode 100644 src/equitriton/sph_harm/main.py create mode 100644 src/equitriton/sph_harm/tests/test_correctness.py create mode 100644 src/equitriton/sph_harm/tests/test_main.py create mode 100644 src/equitriton/sph_harm/tests/test_utils.py create mode 100644 src/equitriton/sph_harm/triton_kernels.py create mode 100644 src/equitriton/tests/test_benchmark.py create mode 100644 src/equitriton/utils.py diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..9d866e3 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. 
+# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..68bc17f --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..4dd9d9e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,16 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: debug-statements +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.0 + hooks: + - id: ruff + args: [ --fix ] + types_or: [ python, pyi, jupyter ] + - id: ruff-format + types_or: [ python, pyi, jupyter ] diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..c4bf327 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,6 @@ +# CHANGELOG + +## v0.1.0 (2024-06-28) + +Initial release, includes up to $l=4$ forward/backward kernels for +spherical harmonics, based on the original implementation for `e3nn`. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2e90d33 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2024 Intel Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..090ddf5 --- /dev/null +++ b/README.md @@ -0,0 +1,127 @@ +# EquiTriton + +
+
+[![pytorch](https://img.shields.io/badge/PyTorch-v2.1.0-red?logo=pytorch)](https://pytorch.org/get-started/locally/)
+[![License: Apache2.0](https://img.shields.io/badge/License-Apache-yellow.svg)](https://opensource.org/licenses/apache-2-0)
+![python-support](https://img.shields.io/badge/Python-3.10%7C3.11%7C3.12-3?logo=python)
+![triton](https://img.shields.io/badge/Triton-2.1.0-2?link=https%3A%2F%2Fgithub.com%2Fintel%2Fintel-xpu-backend-for-triton%2Freleases%2Ftag%2Fv2.1.0)
+
+
+_Performant kernels for equivariant neural networks in Triton-lang_
+
+## Introduction
+
+_EquiTriton_ is a project that seeks to implement high-performance kernels
+for commonly used building blocks in equivariant neural networks, enabling
+compute-efficient training and inference. The advantage of Triton-lang is
+portability across GPU architectures: the kernels here have been tested on
+GPUs from multiple vendors, including the Nvidia A100/H100 and the Intel®️
+Data Center GPU Max Series 1550.
+
+Our current scope includes components such as spherical harmonics (including
+derivatives, up to $l=4$), and we intend to expand this set quickly. If you
+feel that a particular set of kernels would be valuable, please feel free
+to submit an issue or pull request!
+
+
+## Getting Started
+
+For users, run `pip install git+https://github.com/IntelLabs/EquiTriton`. For those who
+are using Intel XPUs, we recommend reading the section on Intel XPU usage first,
+and setting up an environment with PyTorch, IPEX, and Triton for XPU before installing
+_EquiTriton_.
+
+For developers/contributors, please clone this repository and install it in editable mode:
+
+```console
+git clone https://github.com/IntelLabs/EquiTriton
+cd EquiTriton
+pip install -e './[dev]'
+```
+
+...which will include development dependencies such as `pre-commit` (used for linting
+and formatting) and `jupyter` (used for the symbolic differentiation notebook that
+supports kernel development).
+
+Finally, we provide `Dockerfile`s for users who prefer containers.
+
+## Usage
+
+As a drop-in replacement for `e3nn` spherical harmonics, simply include the
+following in your code:
+
+```python
+from equitriton import patch
+```
+
+This will dynamically replace the `e3nn` spherical harmonics implementation
+with the _EquiTriton_ kernels.
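+
+The following is a minimal sketch of the drop-in behaviour; the tensor shape
+and device string are illustrative only:
+
+```python
+import torch
+from e3nn import o3
+
+# importing the module applies the patch as a side effect
+from equitriton import patch  # noqa: F401
+
+xyz = torch.randn(1024, 3, device="xpu")  # or "cuda"
+# downstream e3nn code is unchanged; l=2 gives 2l + 1 = 5 projections
+harmonics = o3.spherical_harmonics(2, xyz, normalize=True)
+assert harmonics.shape == (1024, 5)
+```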
+
+There are two important things to consider before replacing:
+
+1. Numerically, there are small differences between implementations, primarily
+in the backward pass. Because terms in the gradients are implemented as literals,
+they can be more susceptible to rounding errors at lower precision. In most
+(not all!) instances, the implementations are numerically equivalent for `torch.float32`,
+and basically _always_ different for `torch.float16`. At double precision
+(`torch.float64`) this does not seem to be an issue, which makes it ideal for use in
+simulation loops, but please be aware that if it is used for training, the optimization
+trajectory may not be exactly the same; we have not tested for divergence and
+encourage experimentation.
+2. Triton kernels are compiled just-in-time and cached every time they encounter
+a new input tensor shape. In `equitriton.sph_harm.SphericalHarmonics`, the `pad_tensor`
+argument (default is `True`) is used to try to maximize cache re-use by padding
+nodes and masking in the forward pass; a sketch of this interface follows this list.
+The script `scripts/dynamic_shapes.py` will let you test the performance over a range
+of shapes; we encourage you to test it before performing full-scale training/inference.
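+
+As a hedged sketch of the padding behaviour (the constructor arguments and the
+call convention below, other than `pad_tensor`, are assumptions for illustration;
+consult `equitriton.sph_harm` and `scripts/dynamic_shapes.py` for the actual
+interface):
+
+```python
+import torch
+from equitriton.sph_harm import SphericalHarmonics
+
+# pad_tensor=True (the default) pads the node dimension and masks in the
+# forward pass, so similar shapes can reuse already-compiled kernels
+sph_harm = SphericalHarmonics(pad_tensor=True)  # other arguments assumed/omitted
+
+for num_nodes in (1000, 1017, 1024):
+    coords = torch.randn(num_nodes, 3)
+    outputs = sph_harm(coords)  # call convention assumed
+```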
+### Development and usage on Intel XPU
+
+Development on Intel XPUs such as the Data Center GPU Max Series 1550 requires
+a number of manually installed components on bare metal. The core dependency
+to consider is the [Intel XPU backend for Triton][triton-git], which will
+dictate the version of oneAPI, PyTorch, and Intel Extension for PyTorch to
+install. At the time of release, _EquiTriton_ has been developed on the following:
+
+- oneAPI 2024.0
+- PyTorch 2.1.0
+- IPEX 2.1.10+xpu
+- Intel XPU backend for Triton [2.1.0](https://github.com/intel/intel-xpu-backend-for-triton/releases/tag/v2.1.0)
+
+Due to the way that wheels are distributed, please install PyTorch
+and IPEX per `intel-requirements.txt`. Alternatively, use the provided
+Docker image for development. To verify the installation:
+
+```python
+>>> import intel_extension_for_pytorch
+>>> import torch
+>>> torch.xpu.device_count()
+# should be greater than zero
+```
+[triton-git]: https://github.com/intel/intel-xpu-backend-for-triton/releases/tag/v2.1.0
+
+## Useful commands for Intel GPUs
+
+- `xpu-smi` (which might not be installed by default) is, as the name suggests,
+the equivalent of `nvidia-smi`, with a bit more functionality specific to the
+Intel GPU architecture.
+- `sycl-ls` is provided by the `dpcpp` runtime, and lists out all devices that are OpenCL
+and SYCL capable. Notably, this can be used to quickly check how many GPUs are available.
+- [pti-gpu](https://github.com/intel/pti-gpu) provides a set of tools that you can compile for profiling. Notably,
+`unitrace` and `oneprof` allow you to do low-level profiling on the device.
+
+
+## Contributing
+
+We welcome contributions from the open-source community! If you have any
+questions or suggestions, feel free to create an issue in our
+repository. We will be happy to work with you to make this project even
+better.
+
+## License
+
+The code and documentation in this repository are licensed under the Apache 2.0
+license. By contributing to this project, you agree that your
+contributions will be licensed under this license.
diff --git a/docker/.gitkeep b/docker/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/docker/Dockerfile.intel b/docker/Dockerfile.intel
new file mode 100644
index 0000000..b7827aa
--- /dev/null
+++ b/docker/Dockerfile.intel
@@ -0,0 +1,22 @@
+# pulls a docker image with a tested PyTorch+IPEX+Triton stack
+FROM intel/intel-extension-for-pytorch:2.1.10-xpu
+
+LABEL org.opencontainers.image.title="equitriton"
+LABEL org.opencontainers.image.description="Docker image with Intel XPU support for EquiTriton."
+LABEL org.opencontainers.image.licenses="Apache-2.0"
+LABEL org.opencontainers.image.source="https://github.com/IntelLabs/EquiTriton/tree/main/docker/Dockerfile.intel"
+LABEL org.opencontainers.image.url="https://github.com/IntelLabs/EquiTriton"
+LABEL org.opencontainers.image.documentation="https://github.com/IntelLabs/EquiTriton/tree/main/README.md"
+LABEL org.opencontainers.image.version="0.1.0"
+LABEL org.opencontainers.image.created="2024-07-09"
+
+LABEL software.python.version="3.10.12"
+LABEL software.pytorch.version="2.1.0"
+LABEL software.ipex.version="2.1.10+xpu"
+LABEL software.triton.version="2.1.0"
+
+RUN pip install -U setuptools==69.5 pip
+RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+RUN pip install git+https://github.com/IntelLabs/EquiTriton
+
+HEALTHCHECK NONE
diff --git a/intel-requirements.txt b/intel-requirements.txt
new file mode 100644
index 0000000..3601a12
--- /dev/null
+++ b/intel-requirements.txt
@@ -0,0 +1,5 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+torch==2.1.0a0
+intel-extension-for-pytorch==2.1.10+xpu
+oneccl_bind_pt==2.1.200
+https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
diff --git a/notebooks/Spherical harmonics symbolic differentiation.ipynb b/notebooks/Spherical harmonics symbolic differentiation.ipynb
new file mode 100644
index 0000000..96a66f9
--- /dev/null
+++ b/notebooks/Spherical harmonics symbolic differentiation.ipynb
@@ -0,0 +1,948 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "fe44c9ad-908d-4c5b-8f1b-3899b3edde67",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import math\n",
+    "from sympy import symbols, sqrt, diff, Symbol, latex"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0713ba07-311b-4032-b9c8-e152b4259811",
+   "metadata": {},
+   "source": [
+    "This notebook uses `sympy` to perform symbolic differentiation to help with writing the manual backward pass for each order. To maintain consistency, the forward pass equations are direct copies/transcriptions of the `e3nn` spherical harmonics functions, and for that reason we break up the sections in this notebook so that they \"match\" the `e3nn` implementation.\n",
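+    "\n",
+    "As a small worked statement of what each backward kernel needs to assemble (added here for clarity): given upstream gradients $\\partial L / \\partial Y_l^m$ for every projection $m$, backprop produces the Cartesian gradients via the chain rule,\n",
+    "\n",
+    "$$ \\frac{\\partial L}{\\partial x} = \\sum_m \\frac{\\partial L}{\\partial Y_l^m} \\frac{\\partial Y_l^m}{\\partial x}, $$\n",
+    "\n",
+    "and similarly for $y$ and $z$. The `collect_derivatives` helper below groups the symbolic terms in exactly this form, with a placeholder symbol standing in for each $\\partial L / \\partial Y_l^m$."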
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "fbf064e8-6cc1-44e6-8a4d-391d64463f30",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x, y, z = symbols(\"x y z\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "6a45fc5c-ad63-4423-a8ec-933cadfc7878",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def take_derivative(expr, symbols: list[Symbol], simplify: bool = True):\n",
+    "    \"\"\"\n",
+    "    Function to take the derivative of a symbolic equation with respect\n",
+    "    to a list of symbols.\n",
+    "\n",
+    "    We loop through each symbol, and if it is used in the equation,\n",
+    "    we take the first derivative with respect to that symbol.\n",
+    "    \"\"\"\n",
+    "    return_dict = {}\n",
+    "    for symbol in symbols:\n",
+    "        if symbol in expr.free_symbols:\n",
+    "            deriv = diff(expr, symbol)\n",
+    "            if simplify:\n",
+    "                deriv = deriv.simplify()\n",
+    "            return_dict[str(symbol)] = deriv\n",
+    "    if len(return_dict) == 0:\n",
+    "        raise RuntimeError(\"None of the requested symbols were used in the expression!\")\n",
+    "    return return_dict"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "16ff41cd-d31e-48dc-a153-99a11f2457eb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def collect_derivatives(derivs: dict, l: int, simplify: bool = True) -> dict:\n",
+    "    \"\"\"Collect up derivatives from each component in terms of a cartesian axis\"\"\"\n",
+    "    joint = {}\n",
+    "    for axis in [\"x\", \"y\", \"z\"]:\n",
+    "        for m, components in derivs.items():\n",
+    "            # we use a dYl^m symbol to denote that you need to multiply\n",
+    "            # by this particular component's gradient\n",
+    "            m_symbol = symbols(f\"dY{l}^{m}\")\n",
+    "            # not every component contributes gradients to an axis\n",
+    "            if axis in components:\n",
+    "                if axis not in joint:\n",
+    "                    joint[axis] = components[axis] * m_symbol\n",
+    "                else:\n",
+    "                    joint[axis] += components[axis] * m_symbol\n",
+    "    if simplify:\n",
+    "        for axis, expr in joint.items():\n",
+    "            joint[axis] = expr.simplify()\n",
+    "    return joint"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "664e4abd-6af7-4d68-83e3-be1286082e50",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\"\"\"\n",
+    "This cell implements helper functions that will reformat the sympy/Python string result\n",
+    "into something that can be more readily copy-pasted for the Triton implementation.\n",
+    "\"\"\"\n",
+    "\n",
+    "\n",
+    "def replace_terms_for_implementation(e: str, mapping: dict[str, str]) -> str:\n",
+    "    \"\"\"\n",
+    "    Implements a function that remaps a sympy expression to match\n",
+    "    the syntax used in the actual Triton implementation\n",
+    "    \"\"\"\n",
+    "    for key, value in mapping.items():\n",
+    "        e = e.replace(key, value)\n",
+    "    return e\n",
+    "\n",
+    "\n",
+    "mapping = {\n",
+    "    \"sqrt(15)\": \"sqrt_15\",\n",
+    "    \"x**4.0\": \"sq_x * sq_x\",\n",
+    "    \"z**4.0\": \"sq_z * sq_z\",\n",
+    "    \"x**4\": \"sq_x * sq_x\",\n",
+    "    \"z**4\": \"sq_z * sq_z\",\n",
+    "}\n",
+    "\n",
+    "for power, prefix in zip([1, 2, 3], [\"\", \"sq_\", \"cu_\"]):\n",
+    "    for axis in [\"x\", \"y\", \"z\"]:\n",
+    "        # the decimal form comes first, since otherwise the integer\n",
+    "        # replacement would leave a dangling .0\n",
+    "        mapping[f\"{axis}**{power}.0\"] = f\"{prefix}{axis}\"\n",
+    "        mapping[f\"{axis}**{power}\"] = f\"{prefix}{axis}\"\n",
+    "\n",
+    "\"\"\"\n",
+    "This adds the gradient terms to the re-mapping dictionary; ideally\n",
+    "this makes it so that everything is literally just copy paste!\n",
+    "\"\"\"\n",
+    "\n",
+    "\n",
+    "def num_projections(l: int) -> int:\n",
+    "    return 2 * l + 1\n",
+    "\n",
+    "\n",
+    "for l_max in range(1, 4 + 1):\n",
+    "    for m in range(num_projections(l_max)):\n",
+    "        mapping[f\"dY{l_max}^{m}\"] = f\"g_{l_max}_{m}\"\n",
+    "\n",
+    "\n",
+    "def export_to_latex(expr) -> None:\n",
+    "    \"\"\"Function to format a sympy expression for copy/pasting into LaTeX\"\"\"\n",
+    "    expr_str = latex(expr)\n",
+    "    expr_mapping = {\n",
+    "        \"dY\": \"\\\\nabla Y\",\n",
+    "        \"^{1.0}\": \"\",\n",
+    "        \"^{2.0}\": \"^2\",\n",
+    "        \"^{3.0}\": \"^3\",\n",
+    "        \"^{4.0}\": \"^4\",\n",
+    "    }\n",
+    "    for key, value in expr_mapping.items():\n",
+    "        expr_str = expr_str.replace(key, value)\n",
+    "    print(expr_str)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0149cca3-3dc5-4781-9883-e208e4c161e9",
+   "metadata": {},
+   "source": [
+    "## First order spherical harmonics\n",
+    "\n",
+    "...the derivative is just a constant, because the harmonics are just $\\sqrt{3}(x,y,z)$"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "6eaa0261-6faf-4994-ba6e-918106acde75",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "zeroth_x = x * sqrt(3.0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "7a32af2a-6dc2-424b-ae5d-f83ff244d905",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/latex": [
+       "$\\displaystyle 1.73205080756888$"
+      ],
+      "text/plain": [
+       "1.73205080756888"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "diff(zeroth_x, x)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a865cb38-cf12-4e96-b716-bc689de0d493",
+   "metadata": {},
+   "source": [
+    "## Second order spherical harmonics\n",
+    "\n",
+    "A little bit more involved; note I'm butchering the syntax where it's normally meant to be $y_l^m$, but I'm using $m$ indexing as if it was written in `e3nn` code.\n",
+    "\n",
+    "$$ y_2^0 = \\sqrt{15}xz $$\n",
+    "\n",
+    "$$ y_2^1 = \\sqrt{15}xy $$\n",
+    "\n",
+    "$$ y_2^2 = \\sqrt{15}yz $$\n",
+    "\n",
+    "$$ y_2^3 = \\sqrt{5} (y^2 - \\frac{1}{2}(x^2 + z^2)) $$\n",
+    "\n",
+    "$$ y_2^4 = \\frac{\\sqrt{15}}{2} (z^2 - x^2) $$"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "95710a47-8292-4ec0-83f4-50679af6092f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "y_2_0 = sqrt(15) * x * z\n",
+    "y_2_1 = sqrt(15) * x * y\n",
+    "y_2_2 = sqrt(15) * y * z\n",
+    "y_2_3 = sqrt(5) * (y**2 - 0.5 * (x**2 + z**2))\n",
+    "y_2_4 = (sqrt(15) / 2) * (z**2.0 - x**2.0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "164f7da2-43ac-48d8-86bd-3ffe2f2e5710",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "second_order = {}\n",
+    "for index, expr in enumerate([y_2_0, y_2_1, y_2_2, y_2_3, y_2_4]):\n",
+    "    second_order[str(index)] = take_derivative(expr, [x, y, z])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fe92b282-70f8-4909-b917-54b971963d41",
+   "metadata": {},
+   "source": [
+    "The cell below shows how each projection of $l$ contributes to the $x,y,z$ axes. It's a convenient way to inspect the appearance of terms, but not so much for the actual implementation.\n",
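+    "\n",
+    "As a quick worked check of a single entry (added for clarity): for $y_2^4 = \\frac{\\sqrt{15}}{2}(z^2 - x^2)$ we have $\\partial y_2^4 / \\partial x = \\frac{\\sqrt{15}}{2} \\cdot (-2x) = -\\sqrt{15}x$, which matches the `'4'` entry in the dictionary below."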
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "5cef3f35-cd17-42a8-8f20-1dced30533e9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'0': {'x': sqrt(15)*z, 'z': sqrt(15)*x},\n", + " '1': {'x': sqrt(15)*y, 'y': sqrt(15)*x},\n", + " '2': {'y': sqrt(15)*z, 'z': sqrt(15)*y},\n", + " '3': {'x': -1.0*sqrt(5)*x, 'y': 2*sqrt(5)*y, 'z': -1.0*sqrt(5)*z},\n", + " '4': {'x': -1.0*sqrt(15)*x**1.0, 'z': 1.0*sqrt(15)*z**1.0}}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "second_order" + ] + }, + { + "cell_type": "markdown", + "id": "f10d5a36-72d8-4806-8adf-d03cf46027cb", + "metadata": {}, + "source": [ + "Instead, we call `collect_derivatives` to aggregate each expression, and reorder it so that it's in terms of the derivatives of _each projection_ with respect to $x,y,z$. This makes it so that the Triton implementation is straightforward, since it maps onto how backprop actually produces gradients." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a31b03d1-70c6-473d-9061-8a526f5ac2b2", + "metadata": {}, + "outputs": [], + "source": [ + "second_order_cartesian = collect_derivatives(second_order, 2)" + ] + }, + { + "cell_type": "markdown", + "id": "bec25cc1-6c91-49af-827b-636cdd2a7811", + "metadata": {}, + "source": [ + "So the cell below derives $\\frac{\\partial Y_2}{\\partial x}$; i.e. how each projection of 2nd order spherical harmonic contributes to the total gradient of $x$." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "9f93b6e3-f8e4-40e7-b01f-782e99b95348", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.0*sqrt_15*g_2_0*z + 1.0*sqrt_15*g_2_1*y - 1.0*sqrt(5)*g_2_3*x - 1.0*sqrt_15*g_2_4*x'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(second_order_cartesian[\"x\"]), mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "012152ab-9227-447c-8b0a-7b39004935e7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'sqrt_15*g_2_1*x + sqrt_15*g_2_2*z + 2*sqrt(5)*g_2_3*y'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(second_order_cartesian[\"y\"]), mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "598c64b5-8182-4c73-92ec-2b2c9d263372", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.0*sqrt_15*g_2_0*x + 1.0*sqrt_15*g_2_2*y - 1.0*sqrt(5)*g_2_3*z + 1.0*sqrt_15*g_2_4*z'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(second_order_cartesian[\"z\"]), mapping)" + ] + }, + { + "cell_type": "markdown", + "id": "b534b165-a9e2-44ed-8125-d512185e70d6", + "metadata": {}, + "source": [ + "The remainder of this notebook follows the exact same recipe, and will not be documented as thoroughly as it is mainly for the implementation/reference. Some cells include a `print(latex(...))` line, which is used to pretty-print/format for $\\LaTeX$ outputs." 
+ ] + }, + { + "cell_type": "markdown", + "id": "3afd34dd-1e89-4adb-877d-b5d53b2f6bb1", + "metadata": {}, + "source": [ + "## Third order spherical harmonics\n", + "\n", + "$$ y_3^0 = \\frac{1}{6} \\sqrt{42} ({y_2^0} z + {y_2^4} x) $$\n", + "\n", + "$$ y_3^1 = \\sqrt{7} y_2^0 y $$\n", + "\n", + "$$ y_3^2 = \\frac{1}{8} \\sqrt{168} (4y^2 - x^2 + z^2) x $$\n", + "\n", + "$$ y_3^3 = \\frac{1}{2} \\sqrt{7} y (2y^2 - 3(x^2 + z^2)) $$\n", + "\n", + "$$ y_3^4 = \\frac{1}{8} \\sqrt{168} z(4y^2 - (x^2 + z^2)) $$\n", + "\n", + "$$ y_3^5 = \\sqrt{7} y_2^4 y $$\n", + "\n", + "$$ y_3^6 = \\frac{1}{6} \\sqrt{42} (y_2^4 z - y_2^0 x) $$" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "4715f559-3660-4718-9834-0f14010d65b4", + "metadata": {}, + "outputs": [], + "source": [ + "y2 = y**2.0\n", + "x2z2 = x**2.0 + z**2.0\n", + "\n", + "y_3_0 = (1 / 6) * math.sqrt(42) * (y_2_0 * z + y_2_4 * x)\n", + "y_3_1 = math.sqrt(7) * y_2_0 * y\n", + "y_3_2 = (1 / 8) * math.sqrt(168) * (4.0 * y2 - x2z2) * x\n", + "y_3_3 = (1 / 2) * math.sqrt(7) * y * (2.0 * y2 - 3.0 * x2z2)\n", + "y_3_4 = (1 / 8) * math.sqrt(168) * z * (4.0 * y2 - x2z2)\n", + "y_3_5 = math.sqrt(7) * y_2_4 * y\n", + "y_3_6 = (1 / 6) * math.sqrt(42) * (y_2_4 * z - y_2_0 * x)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ab0b268c-6be1-4ee0-9be3-a31fcee50f8a", + "metadata": {}, + "outputs": [], + "source": [ + "third_order = {}\n", + "for index, expr in enumerate([y_3_0, y_3_1, y_3_2, y_3_3, y_3_4, y_3_5, y_3_6]):\n", + " third_order[str(index)] = take_derivative(expr, [x, y, z])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "8946edc6-b32f-4e02-8ff4-a5ba4beec44d", + "metadata": {}, + "outputs": [], + "source": [ + "third_order_cartesian = collect_derivatives(third_order, 3)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "fb7dda98-ec95-49a5-bc43-a1436299fa8a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\sqrt{15} \\nabla Y^{0}_{3} \\left(- 1.62018517460196 x^2 + 1.08012344973464 z^{2} + 0.540061724867322 z^2\\right) + 2.64575131106459 \\sqrt{15} \\nabla Y^{1}_{3} y z - \\nabla Y^{2}_{3} \\cdot \\left(4.8605555238059 x^2 - 6.48074069840786 y^2 + 1.62018517460197 z^2\\right) - 7.93725393319377 \\nabla Y^{3}_{3} x y - 3.24037034920393 \\nabla Y^{4}_{3} x z - 2.64575131106459 \\sqrt{15} \\nabla Y^{5}_{3} x y - \\sqrt{15} \\nabla Y^{6}_{3} z \\left(1.08012344973464 x + 2.16024689946929 x\\right)\n" + ] + } + ], + "source": [ + "export_to_latex(third_order_cartesian[\"x\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "1876415f-8a2f-4643-965d-b3428bd3c3cd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'sqrt_15*g_3_0*(-1.62018517460196*sq_x + 1.08012344973464*sq_z + 0.540061724867322*sq_z) + 2.64575131106459*sqrt_15*g_3_1*y*z - g_3_2*(4.8605555238059*sq_x - 6.48074069840786*sq_y + 1.62018517460197*sq_z) - 7.93725393319377*g_3_3*x*y - 3.24037034920393*g_3_4*x*z - 2.64575131106459*sqrt_15*g_3_5*x*y - sqrt_15*g_3_6*z*(1.08012344973464*x + 2.16024689946929*x)'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(third_order_cartesian[\"x\"]), mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "2c2f2bc9-70bf-4dc5-b564-0a8c74cb240d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.64575131106459 
\\sqrt{15} \\nabla Y^{1}_{3} x z + 12.9614813968157 \\nabla Y^{2}_{3} x y - \\nabla Y^{3}_{3} \\cdot \\left(3.96862696659689 x^2 - 7.93725393319377 y^2 + 3.96862696659689 z^2\\right) + 12.9614813968157 \\nabla Y^{4}_{3} y z - 1.3228756555323 \\sqrt{15} \\nabla Y^{5}_{3} \\left(x^2 - z^2\\right)\n" + ] + } + ], + "source": [ + "export_to_latex(third_order_cartesian[\"y\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "e4764b7f-9a3b-4060-8035-4be6047bbf2e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2.64575131106459*sqrt_15*g_3_1*x*z + 12.9614813968157*g_3_2*x*y - g_3_3*(3.96862696659689*sq_x - 7.93725393319377*sq_y + 3.96862696659689*sq_z) + 12.9614813968157*g_3_4*y*z - 1.3228756555323*sqrt_15*g_3_5*(sq_x - sq_z)'" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(third_order_cartesian[\"y\"]), mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "605858c9-55c9-419b-a099-0dc0a42f66a7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\sqrt{15} \\nabla Y^{0}_{3} x \\left(1.08012344973464 z + 2.16024689946929 z\\right) + 2.64575131106459 \\sqrt{15} \\nabla Y^{1}_{3} x y - 3.24037034920393 \\nabla Y^{2}_{3} x z - 7.93725393319377 \\nabla Y^{3}_{3} y z - \\nabla Y^{4}_{3} \\cdot \\left(1.62018517460197 x^2 - 6.48074069840786 y^2 + 4.8605555238059 z^2\\right) + 2.64575131106459 \\sqrt{15} \\nabla Y^{5}_{3} y z - \\sqrt{15} \\nabla Y^{6}_{3} \\cdot \\left(1.08012344973464 x^{2} + 0.540061724867322 x^2 - 1.62018517460196 z^2\\right)\n" + ] + } + ], + "source": [ + "export_to_latex(third_order_cartesian[\"z\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "269fbe07-0d33-4a0a-a629-df94d655e8b3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'sqrt_15*g_3_0*x*(1.08012344973464*z + 2.16024689946929*z) + 2.64575131106459*sqrt_15*g_3_1*x*y - 3.24037034920393*g_3_2*x*z - 7.93725393319377*g_3_3*y*z - g_3_4*(1.62018517460197*sq_x - 6.48074069840786*sq_y + 4.8605555238059*sq_z) + 2.64575131106459*sqrt_15*g_3_5*y*z - sqrt_15*g_3_6*(1.08012344973464*sq_x + 0.540061724867322*sq_x - 1.62018517460196*sq_z)'" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(third_order_cartesian[\"z\"]), mapping)" + ] + }, + { + "cell_type": "markdown", + "id": "499bb11f-f7bd-46d1-abb2-d099425f9cf1", + "metadata": {}, + "source": [ + "## Fourth order spherical harmonics" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "43345b97-4533-47a5-a2cf-cece105c8df5", + "metadata": {}, + "outputs": [], + "source": [ + "y_4_0 = (3 / 4) * math.sqrt(2) * (y_3_0 * z + y_3_6 * x)\n", + "y_4_1 = (\n", + " (3 / 4) * y_3_0 * y\n", + " + (3 / 8) * math.sqrt(6) * y_3_1 * z\n", + " + (3 / 8) * math.sqrt(6) * y_3_5 * x\n", + ")\n", + "y_4_2 = (\n", + " -3 / 56 * math.sqrt(14) * y_3_0 * z\n", + " + (3 / 14) * math.sqrt(21) * y_3_1 * y\n", + " + (3 / 56) * math.sqrt(210) * y_3_2 * z\n", + " + (3 / 56) * math.sqrt(210) * y_3_4 * x\n", + " + (3 / 56) * math.sqrt(14) * y_3_6 * x\n", + ")\n", + "y_4_3 = (\n", + " -3 / 56 * math.sqrt(42) * y_3_1 * z\n", + " + (3 / 28) * math.sqrt(105) * y_3_2 * y\n", + " + (3 / 28) * math.sqrt(70) * y_3_3 * x\n", + " + (3 / 56) * math.sqrt(42) * y_3_5 * x\n", + ")\n", + "y_4_4 = (\n", + " -3 / 28 * math.sqrt(42) * y_3_2 * x\n", + " + 
(3 / 7) * math.sqrt(7) * y_3_3 * y\n", + " - 3 / 28 * math.sqrt(42) * y_3_4 * z\n", + ")\n", + "y_4_5 = (\n", + " -3 / 56 * math.sqrt(42) * y_3_1 * x\n", + " + (3 / 28) * math.sqrt(70) * y_3_3 * z\n", + " + (3 / 28) * math.sqrt(105) * y_3_4 * y\n", + " - 3 / 56 * math.sqrt(42) * y_3_5 * z\n", + ")\n", + "y_4_6 = (\n", + " -3 / 56 * math.sqrt(14) * y_3_0 * x\n", + " - 3 / 56 * math.sqrt(210) * y_3_2 * x\n", + " + (3 / 56) * math.sqrt(210) * y_3_4 * z\n", + " + (3 / 14) * math.sqrt(21) * y_3_5 * y\n", + " - 3 / 56 * math.sqrt(14) * y_3_6 * z\n", + ")\n", + "y_4_7 = (\n", + " -3 / 8 * math.sqrt(6) * y_3_1 * x\n", + " + (3 / 8) * math.sqrt(6) * y_3_5 * z\n", + " + (3 / 4) * y_3_6 * y\n", + ")\n", + "y_4_8 = (3 / 4) * math.sqrt(2) * (-y_3_0 * x + y_3_6 * z)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "f2304c2c-1416-475d-b4a6-105ebfaadd60", + "metadata": {}, + "outputs": [], + "source": [ + "fourth_order = {}\n", + "for index, expr in enumerate(\n", + " [y_4_0, y_4_1, y_4_2, y_4_3, y_4_4, y_4_5, y_4_6, y_4_7, y_4_8]\n", + "):\n", + " fourth_order[str(index)] = take_derivative(expr, [x, y, z])" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "baa112cb-04ca-40ac-b422-9ea72044c9bf", + "metadata": {}, + "outputs": [], + "source": [ + "fourth_order_cartesian = collect_derivatives(fourth_order, 4)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "12b5a3b5-0388-4690-a4a8-8d0aa5235028", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'-sqrt_15*g_4_0*(3.43693177121688*sq_x*z + 3.43693177121688*sq_x*z - 1.14564392373896*cu_z - 1.14564392373896*cu_z) + sqrt_15*g_4_1*y*(-4.8605555238059*sq_x + 3.24037034920393*sq_z + 1.62018517460197*sq_z) - g_4_2*(0.649519052838329*sqrt_15*sq_x*z - 2.77555756156289e-17*sqrt_15*sq_x*z + 7.54672942406179*sq_x*z - 2.59807621135332*sqrt_15*sq_y*z - 10.0623058987491*sq_y*z + 0.21650635094611*sqrt_15*cu_z + 2.51557647468726*cu_z) - g_4_3*y*(0.918558653543692*sqrt_15*sq_x + 16.0090306546024*sq_x - 9.48683298050514*sq_y + 0.918558653543692*sqrt_15*sq_z + 5.33634355153414*sq_z + 0.459279326771846*sqrt_15*(sq_x - sq_z)) + g_4_4*(-9.0*x*sq_y + 2.25*x*sq_z - 9.0*x*sq_y + 2.25*x*sq_z + 4.5*cu_x) - g_4_5*y*z*(-0.918558653543692*sqrt_15*x + 10.6726871030683*x + 1.83711730708738*sqrt_15*x) - g_4_6*(2.59807621135332*sqrt_15*x*sq_y - 0.21650635094611*sqrt_15*x*sq_z + 2.51557647468726*x*sq_z + 10.0623058987491*x*sq_y - 2.51557647468726*x*sq_z + 0.21650635094611*sqrt_15*x*sq_z - 5.03115294937453*cu_x - 0.433012701892219*sqrt_15*cu_x) - sqrt_15*g_4_7*y*z*(3.24037034920393*x + 6.48074069840786*x) - sqrt_15*g_4_8*(1.14564392373896*x*sq_z + 4.58257569495584*x*sq_z + 1.14564392373896*x*sq_z - 2.29128784747792*cu_x)'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(fourth_order_cartesian[\"x\"]), mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "de352112-a997-4fd6-ad24-91c98ef0c2d6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "- \\sqrt{15} \\nabla Y^{0}_{4} \\cdot \\left(3.43693177121688 x^{2} z + 3.43693177121688 x^2 z - 1.14564392373896 z^{3} - 1.14564392373896 z^3\\right) + \\sqrt{15} \\nabla Y^{1}_{4} y \\left(- 4.8605555238059 x^2 + 3.24037034920393 z^{2} + 1.62018517460197 z^2\\right) - \\nabla Y^{2}_{4} \\cdot \\left(0.649519052838329 \\sqrt{15} x^{2} z - 2.77555756156289 \\cdot 10^{-17} \\sqrt{15} x^2 z + 
7.54672942406179 x^2 z - 2.59807621135332 \\sqrt{15} y^{2} z - 10.0623058987491 y^2 z + 0.21650635094611 \\sqrt{15} z^{3} + 2.51557647468726 z^3\\right) - \\nabla Y^{3}_{4} y \\left(0.918558653543692 \\sqrt{15} x^2 + 16.0090306546024 x^2 - 9.48683298050514 y^2 + 0.918558653543692 \\sqrt{15} z^{2} + 5.33634355153414 z^2 + 0.459279326771846 \\sqrt{15} \\left(x^2 - z^2\\right)\\right) + \\nabla Y^{4}_{4} \\left(- 9.0 x y^{2} + 2.25 x z^{2} - 9.0 x y^2 + 2.25 x z^2 + 4.5 x^3\\right) - \\nabla Y^{5}_{4} y z \\left(- 0.918558653543692 \\sqrt{15} x + 10.6726871030683 x + 1.83711730708738 \\sqrt{15} x\\right) - \\nabla Y^{6}_{4} \\cdot \\left(2.59807621135332 \\sqrt{15} x y^{2} - 0.21650635094611 \\sqrt{15} x z^{2} + 2.51557647468726 x z^{2} + 10.0623058987491 x y^2 - 2.51557647468726 x z^2 + 0.21650635094611 \\sqrt{15} x z^2 - 5.03115294937453 x^3 - 0.433012701892219 \\sqrt{15} x^3\\right) - \\sqrt{15} \\nabla Y^{7}_{4} y z \\left(3.24037034920393 x + 6.48074069840786 x\\right) - \\sqrt{15} \\nabla Y^{8}_{4} \\cdot \\left(1.14564392373896 x z^{2} + 4.58257569495584 x z^{2} + 1.14564392373896 x z^2 - 2.29128784747792 x^3\\right)\n" + ] + } + ], + "source": [ + "export_to_latex(fourth_order_cartesian[\"x\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "64ec431e-712a-4c77-a4d6-987de5d9de47", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'sqrt_15*g_4_1*x*(-1.62018517460197*sq_x + 3.24037034920393*sq_z + 1.62018517460197*sq_z) + g_4_2*x*z*(5.19615242270663*sqrt_15*y + 20.1246117974981*y) - g_4_3*x*(5.33634355153414*sq_x - 28.4604989415154*sq_y + 0.918558653543692*sqrt_15*sq_z + 5.33634355153414*sq_z + 0.459279326771846*sqrt_15*(sq_x - sq_z)) - g_4_4*(9.0*sq_x*y + 9.0*sq_x*y + 9.0*y*sq_z + 9.0*y*sq_z - 12.0*cu_y) - g_4_5*z*(0.918558653543692*sqrt_15*sq_x + 5.33634355153414*sq_x - 28.4604989415154*sq_y + 5.33634355153414*sq_z - 0.459279326771846*sqrt_15*(sq_x - sq_z)) - g_4_6*(10.0623058987491*sq_x*y - 10.0623058987491*y*sq_z + 2.59807621135332*sqrt_15*y*(sq_x - sq_z)) - sqrt_15*g_4_7*z*(3.24037034920393*sq_x + 1.62018517460197*sq_x - 1.62018517460197*sq_z)'" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(fourth_order_cartesian[\"y\"]), mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "766d8625-a5a5-4763-b3e9-7f4534f00a19", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\sqrt{15} \\nabla Y^{1}_{4} x \\left(- 1.62018517460197 x^2 + 3.24037034920393 z^{2} + 1.62018517460197 z^2\\right) + \\nabla Y^{2}_{4} x z \\left(5.19615242270663 \\sqrt{15} y + 20.1246117974981 y\\right) - \\nabla Y^{3}_{4} x \\left(5.33634355153414 x^2 - 28.4604989415154 y^2 + 0.918558653543692 \\sqrt{15} z^{2} + 5.33634355153414 z^2 + 0.459279326771846 \\sqrt{15} \\left(x^2 - z^2\\right)\\right) - \\nabla Y^{4}_{4} \\cdot \\left(9.0 x^{2} y + 9.0 x^2 y + 9.0 y z^{2} + 9.0 y z^2 - 12.0 y^3\\right) - \\nabla Y^{5}_{4} z \\left(0.918558653543692 \\sqrt{15} x^{2} + 5.33634355153414 x^2 - 28.4604989415154 y^2 + 5.33634355153414 z^2 - 0.459279326771846 \\sqrt{15} \\left(x^2 - z^2\\right)\\right) - \\nabla Y^{6}_{4} \\cdot \\left(10.0623058987491 x^{2} y - 10.0623058987491 y z^{2} + 2.59807621135332 \\sqrt{15} y \\left(x^2 - z^2\\right)\\right) - \\sqrt{15} \\nabla Y^{7}_{4} z \\left(3.24037034920393 x^{2} + 1.62018517460197 x^2 - 1.62018517460197 z^2\\right)\n" + ] + } + ], + "source": [ + 
"export_to_latex(fourth_order_cartesian[\"y\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "08441fb6-d16d-456c-887a-c8a9a1566cbb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'-sqrt_15*g_4_0*(1.14564392373896*cu_x - 3.43693177121688*x*sq_z - 3.43693177121688*x*sq_z + 1.14564392373896*cu_x) + sqrt_15*g_4_1*x*y*(3.24037034920393*z + 6.48074069840786*z) - g_4_2*(0.21650635094611*sqrt_15*cu_x - 2.59807621135332*sqrt_15*x*sq_y - 10.0623058987491*x*sq_y + 0.649519052838329*sqrt_15*x*sq_z - 2.77555756156289e-17*sqrt_15*x*sq_z + 7.54672942406179*x*sq_z + 2.51557647468726*cu_x) - g_4_3*x*y*(-0.918558653543692*sqrt_15*z + 10.6726871030683*z + 1.83711730708738*sqrt_15*z) + g_4_4*(2.25*sq_x*z + 2.25*sq_x*z - 9.0*sq_y*z - 9.0*sq_y*z + 4.5*cu_z) - g_4_5*y*(0.918558653543692*sqrt_15*sq_x + 5.33634355153414*sq_x - 9.48683298050514*sq_y + 0.918558653543692*sqrt_15*sq_z + 16.0090306546024*sq_z - 0.459279326771846*sqrt_15*(sq_x - sq_z)) + g_4_6*(-0.21650635094611*sqrt_15*sq_x*z + 2.51557647468726*sq_x*z - 2.51557647468726*sq_x*z + 0.21650635094611*sqrt_15*sq_x*z + 2.59807621135332*sqrt_15*sq_y*z + 10.0623058987491*sq_y*z - 5.03115294937453*cu_z - 0.433012701892219*sqrt_15*cu_z) - sqrt_15*g_4_7*y*(3.24037034920393*sq_x + 1.62018517460197*sq_x - 4.8605555238059*sq_z) - sqrt_15*g_4_8*(1.14564392373896*sq_x*z + 4.58257569495584*sq_x*z + 1.14564392373896*sq_x*z - 2.29128784747792*cu_z)'" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(fourth_order_cartesian[\"z\"]), mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "09594841-4dd0-4f32-92cc-2234c5fdc0f9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "- \\sqrt{15} \\nabla Y^{0}_{4} \\cdot \\left(1.14564392373896 x^{3} - 3.43693177121688 x z^{2} - 3.43693177121688 x z^2 + 1.14564392373896 x^3\\right) + \\sqrt{15} \\nabla Y^{1}_{4} x y \\left(3.24037034920393 z + 6.48074069840786 z\\right) - \\nabla Y^{2}_{4} \\cdot \\left(0.21650635094611 \\sqrt{15} x^{3} - 2.59807621135332 \\sqrt{15} x y^{2} - 10.0623058987491 x y^2 + 0.649519052838329 \\sqrt{15} x z^{2} - 2.77555756156289 \\cdot 10^{-17} \\sqrt{15} x z^2 + 7.54672942406179 x z^2 + 2.51557647468726 x^3\\right) - \\nabla Y^{3}_{4} x y \\left(- 0.918558653543692 \\sqrt{15} z + 10.6726871030683 z + 1.83711730708738 \\sqrt{15} z\\right) + \\nabla Y^{4}_{4} \\cdot \\left(2.25 x^{2} z + 2.25 x^2 z - 9.0 y^{2} z - 9.0 y^2 z + 4.5 z^3\\right) - \\nabla Y^{5}_{4} y \\left(0.918558653543692 \\sqrt{15} x^{2} + 5.33634355153414 x^2 - 9.48683298050514 y^2 + 0.918558653543692 \\sqrt{15} z^2 + 16.0090306546024 z^2 - 0.459279326771846 \\sqrt{15} \\left(x^2 - z^2\\right)\\right) + \\nabla Y^{6}_{4} \\left(- 0.21650635094611 \\sqrt{15} x^{2} z + 2.51557647468726 x^{2} z - 2.51557647468726 x^2 z + 0.21650635094611 \\sqrt{15} x^2 z + 2.59807621135332 \\sqrt{15} y^{2} z + 10.0623058987491 y^2 z - 5.03115294937453 z^3 - 0.433012701892219 \\sqrt{15} z^3\\right) - \\sqrt{15} \\nabla Y^{7}_{4} y \\left(3.24037034920393 x^{2} + 1.62018517460197 x^2 - 4.8605555238059 z^2\\right) - \\sqrt{15} \\nabla Y^{8}_{4} \\cdot \\left(1.14564392373896 x^{2} z + 4.58257569495584 x^{2} z + 1.14564392373896 x^2 z - 2.29128784747792 z^3\\right)\n" + ] + } + ], + "source": [ + "export_to_latex(fourth_order_cartesian[\"z\"])" + ] + }, + { + "cell_type": "markdown", + "id": "757b0874-b3ae-4c3c-b3e9-aa5bb9a5277f", + "metadata": 
{}, + "source": [ + "## Fifth order spherical harmonics\n", + "\n", + "These have yet to be implemented in EquiTriton, but goes to show that higher order derivatives can be easily obtained." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "20d7bfb9-1613-4014-9cec-e6ca24d6be3a", + "metadata": {}, + "outputs": [], + "source": [ + "y_5_0 = (1 / 10) * math.sqrt(110) * (y_4_0 * z + y_4_8 * x)\n", + "y_5_1 = (\n", + " (1 / 5) * math.sqrt(11) * y_4_0 * y\n", + " + (1 / 5) * math.sqrt(22) * y_4_1 * z\n", + " + (1 / 5) * math.sqrt(22) * y_4_7 * x\n", + ")\n", + "y_5_2 = (\n", + " -1 / 30 * math.sqrt(22) * y_4_0 * z\n", + " + (4 / 15) * math.sqrt(11) * y_4_1 * y\n", + " + (1 / 15) * math.sqrt(154) * y_4_2 * z\n", + " + (1 / 15) * math.sqrt(154) * y_4_6 * x\n", + " + (1 / 30) * math.sqrt(22) * y_4_8 * x\n", + ")\n", + "y_5_3 = (\n", + " -1 / 30 * math.sqrt(66) * y_4_1 * z\n", + " + (1 / 15) * math.sqrt(231) * y_4_2 * y\n", + " + (1 / 30) * math.sqrt(462) * y_4_3 * z\n", + " + (1 / 30) * math.sqrt(462) * y_4_5 * x\n", + " + (1 / 30) * math.sqrt(66) * y_4_7 * x\n", + ")\n", + "y_5_4 = (\n", + " -1 / 15 * math.sqrt(33) * y_4_2 * z\n", + " + (2 / 15) * math.sqrt(66) * y_4_3 * y\n", + " + (1 / 15) * math.sqrt(165) * y_4_4 * x\n", + " + (1 / 15) * math.sqrt(33) * y_4_6 * x\n", + ")\n", + "y_5_5 = (\n", + " -1 / 15 * math.sqrt(110) * y_4_3 * x\n", + " + (1 / 3) * math.sqrt(11) * y_4_4 * y\n", + " - 1 / 15 * math.sqrt(110) * y_4_5 * z\n", + ")\n", + "y_5_6 = (\n", + " -1 / 15 * math.sqrt(33) * y_4_2 * x\n", + " + (1 / 15) * math.sqrt(165) * y_4_4 * z\n", + " + (2 / 15) * math.sqrt(66) * y_4_5 * y\n", + " - 1 / 15 * math.sqrt(33) * y_4_6 * z\n", + ")\n", + "y_5_7 = (\n", + " -1 / 30 * math.sqrt(66) * y_4_1 * x\n", + " - 1 / 30 * math.sqrt(462) * y_4_3 * x\n", + " + (1 / 30) * math.sqrt(462) * y_4_5 * z\n", + " + (1 / 15) * math.sqrt(231) * y_4_6 * y\n", + " - 1 / 30 * math.sqrt(66) * y_4_7 * z\n", + ")\n", + "y_5_8 = (\n", + " -1 / 30 * math.sqrt(22) * y_4_0 * x\n", + " - 1 / 15 * math.sqrt(154) * y_4_2 * x\n", + " + (1 / 15) * math.sqrt(154) * y_4_6 * z\n", + " + (4 / 15) * math.sqrt(11) * y_4_7 * y\n", + " - 1 / 30 * math.sqrt(22) * y_4_8 * z\n", + ")\n", + "y_5_9 = (\n", + " -1 / 5 * math.sqrt(22) * y_4_1 * x\n", + " + (1 / 5) * math.sqrt(22) * y_4_7 * z\n", + " + (1 / 5) * math.sqrt(11) * y_4_8 * y\n", + ")\n", + "y_5_10 = (1 / 10) * math.sqrt(110) * (-y_4_0 * x + y_4_8 * z)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "02741bb4-80c6-425d-8396-09025b5ded62", + "metadata": {}, + "outputs": [], + "source": [ + "fifth_order = {}\n", + "for index, expr in enumerate(\n", + " [y_5_0, y_5_1, y_5_2, y_5_3, y_5_4, y_5_5, y_5_6, y_5_7, y_5_8, y_5_9, y_5_10]\n", + "):\n", + " fifth_order[str(index)] = take_derivative(expr, [x, y, z])" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "4da426fe-816f-4594-80f4-44ef06d3347f", + "metadata": {}, + "outputs": [], + "source": [ + "fifth_order_cartesian = collect_derivatives(fifth_order, 5)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "ac8ebc65-1000-443b-a543-b7d7690f4fa4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'-sqrt_15*dY5^0*(10.8140533566281*sq_x*sq_z + 1.80234222610469*sq_x*sq_z + 5.40702667831406*sq_x*sq_z - 3.00390371017448*sq_x * sq_x - 1.20156148406979*sq_z * sq_z - 1.80234222610469*sq_z * sq_z) - sqrt_15*dY5^1*y*(11.399013115178*sq_x*z + 11.399013115178*sq_x*z - 3.79967103839267*cu_z - 3.79967103839267*cu_z) - 
sqrt_15*dY5^10*(1.20156148406979*x*cu_z - 4.80624593627917*cu_x*z + 7.20936890441875*x*cu_z + 3.60468445220938*x*cu_z - 7.20936890441875*cu_x*z) - dY5^2*(12.4869932329604*sq_x*sq_y + 1.07470926301023*sqrt_15*sq_x*sq_z - 3.12174830824011*sq_x*sq_z + 0.537354631505117*sqrt_15*sq_x*sq_z + 7.52296484107164*sqrt_15*sq_x*sq_y - 0.537354631505117*sqrt_15*sq_x*sq_z + 9.36524492472033*sq_x*sq_z - 5.20291384706685*sq_x * sq_x - 0.895591052508528*sqrt_15*sq_x * sq_x - 5.01530989404776*sqrt_15*sq_y*sq_z - 2.50765494702388*sqrt_15*sq_y*sq_z - 12.4869932329604*sq_y*sq_z + 0.358236421003411*sqrt_15*sq_z * sq_z + 0.179118210501706*sqrt_15*sq_z * sq_z + 3.12174830824011*sq_z * sq_z) - dY5^3*(5.26497863243527*sqrt_15*sq_x*y*z + 2.77555756156289e-17*sqrt_15*sq_x*y*z + 30.5867618423396*sq_x*y*z - 2.63248931621764*sqrt_15*cu_y*z + 1.75499287747842*sqrt_15*y*cu_z + 5.55111512312578e-17*sqrt_15*y*cu_z + 10.1955872807799*y*cu_z - 23.789703655153*cu_y*z) + dY5^4*(-17.3410639811979*sq_x*sq_y + 0.248746859276655*sqrt_15*sq_x*sq_z - 0.124373429638327*sqrt_15*sq_x*sq_z + 4.33526599529948*sq_x*sq_z - 28.9017733019965*sq_x*sq_y - 2.98496231131986*sqrt_15*sq_x*sq_y + 0.124373429638327*sqrt_15*sq_x*sq_z + 4.33526599529948*sq_x*sq_z + 0.207289049397212*sqrt_15*sq_x * sq_x + 7.22544332549914*sq_x * sq_x - 1.98997487421324*sqrt_15*sq_y*sq_z - 9.63392443399885*sq_y*sq_z + 0.99498743710662*sqrt_15*sq_y*sq_z - 5.78035466039931*sq_y*sq_z + 12.8452325786651*y**4.0 + 0.082915619758885*sqrt_15*sq_z * sq_z - 0.0414578098794425*sqrt_15*sq_z * sq_z + 1.44508866509983*sq_z * sq_z) + dY5^5*(-9.9498743710662*x*cu_y - 0.642261628933256*sqrt_15*x*y*sq_z + 9.9498743710662*x*y*sq_z + 2.56904651573303*sqrt_15*x*y*sq_z - 0.642261628933256*sqrt_15*x*y*sq_z + 9.9498743710662*x*y*sq_z - 23.2163735324878*x*cu_y + 1.28452325786651*sqrt_15*cu_x*y + 19.8997487421324*cu_x*y) + dY5^6*(-19.2678488679977*x*sq_y*z + 1.98997487421324*sqrt_15*x*sq_y*z - 0.082915619758885*sqrt_15*x*cu_z + 2.89017733019965*x*cu_z + 0.33166247903554*sqrt_15*cu_x*z - 3.97994974842648*sqrt_15*x*sq_y*z - 11.5607093207986*x*sq_y*z + 0.16583123951777*sqrt_15*x*cu_z + 0.082915619758885*sqrt_15*x*cu_z + 2.89017733019965*x*cu_z - 0.16583123951777*sqrt_15*cu_x*z + 5.78035466039931*cu_x*z) - dY5^7*(2.63248931621764*sqrt_15*x*cu_y - 1.75499287747842*sqrt_15*x*y*sq_z + 10.1955872807799*x*y*sq_z + 2.77555756156289e-17*sqrt_15*x*y*sq_z - 10.1955872807799*x*y*sq_z + 1.75499287747842*sqrt_15*x*y*sq_z + 23.789703655153*x*cu_y - 20.3911745615597*cu_x*y - 3.50998575495685*sqrt_15*cu_x*y) + dY5^8*(-5.01530989404776*sqrt_15*x*sq_y*z - 2.08116553882674*x*cu_z + 0.358236421003411*sqrt_15*x*cu_z + 1.43294568401365*sqrt_15*cu_x*z - 10.0306197880955*sqrt_15*x*sq_y*z - 24.9739864659209*x*sq_y*z + 0.716472842006823*sqrt_15*x*cu_z - 0.358236421003411*sqrt_15*x*cu_z + 6.24349661648022*x*cu_z + 0.716472842006822*sqrt_15*cu_x*z + 12.4869932329604*cu_x*z) - sqrt_15*dY5^9*y*(3.79967103839267*x*sq_z + 15.1986841535707*x*sq_z + 3.79967103839267*x*sq_z - 7.59934207678533*cu_x)'" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(fifth_order_cartesian[\"x\"]), mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "d1eddfd6-2632-473c-98f1-144cbb8bbc45", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'-3.79967103839267*sqrt_15*dY5^1*(cu_x*z - x*cu_z - x*cu_z + cu_x*z) + dY5^2*(-8.32466215530697*cu_x*y + 24.9739864659209*x*y*sq_z + 10.0306197880955*sqrt_15*x*y*sq_z + 
5.01530989404776*sqrt_15*x*y*sq_z - 5.01530989404776*sqrt_15*cu_x*y) - dY5^3*(1.75499287747842*sqrt_15*cu_x*z - 7.89746794865291*sqrt_15*x*sq_y*z - 71.369110965459*x*sq_y*z + 1.75499287747842*sqrt_15*x*cu_z + 10.1955872807799*x*cu_z + 10.1955872807799*cu_x*z) - dY5^4*(11.5607093207986*cu_x*y + 11.5607093207986*x*y*sq_z + 3.97994974842648*sqrt_15*x*y*sq_z - 1.98997487421324*sqrt_15*x*y*sq_z + 19.2678488679977*x*y*sq_z - 51.3809303146605*x*cu_y + 1.98997487421324*sqrt_15*cu_x*y + 19.2678488679977*cu_x*y) + dY5^5*(-34.8245602987317*sq_x*sq_y + 1.28452325786651*sqrt_15*sq_x*sq_z - 0.321130814466628*sqrt_15*sq_x*sq_z + 4.9749371855331*sq_x*sq_z - 14.9248115565993*sq_x*sq_y - 0.321130814466628*sqrt_15*sq_x*sq_z + 4.9749371855331*sq_x*sq_z + 0.321130814466628*sqrt_15*sq_x * sq_x + 4.9749371855331*sq_x * sq_x - 14.9248115565993*sq_y*sq_z - 34.8245602987317*sq_y*sq_z + 16.583123951777*y**4.0 + 0.321130814466628*sqrt_15*sq_z * sq_z + 4.9749371855331*sq_z * sq_z) - dY5^6*(11.5607093207986*sq_x*y*z + 3.97994974842648*sqrt_15*sq_x*y*z - 1.98997487421324*sqrt_15*sq_x*y*z + 19.2678488679977*sq_x*y*z + 11.5607093207986*y*cu_z + 1.98997487421324*sqrt_15*y*cu_z + 19.2678488679977*y*cu_z - 51.3809303146605*cu_y*z) - dY5^7*(35.6845554827295*sq_x*sq_y - 5.09779364038993*sq_x*sq_z + 0.877496438739212*sqrt_15*sq_x*sq_z + 3.94873397432645*sqrt_15*sq_x*sq_y - 0.877496438739212*sqrt_15*sq_x*sq_z + 5.09779364038993*sq_x*sq_z - 5.09779364038993*sq_x * sq_x - 0.877496438739212*sqrt_15*sq_x * sq_x - 3.94873397432645*sqrt_15*sq_y*sq_z - 35.6845554827295*sq_y*sq_z + 0.877496438739212*sqrt_15*sq_z * sq_z + 5.09779364038993*sq_z * sq_z) - dY5^8*(24.9739864659209*sq_x*y*z + 10.0306197880955*sqrt_15*sq_x*y*z + 5.01530989404776*sqrt_15*sq_x*y*z - 8.32466215530696*y*cu_z - 5.01530989404776*sqrt_15*y*cu_z) - sqrt_15*dY5^9*(7.59934207678533*sq_x*sq_z + 1.89983551919633*sq_x*sq_z + 1.89983551919633*sq_x*sq_z - 1.89983551919633*sq_x * sq_x - 1.89983551919633*sq_z * sq_z)'" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(fifth_order_cartesian[\"y\"]), mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "088e5fce-408d-4b93-a566-c7421ad7966e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'-sqrt_15*dY5^0*(1.20156148406979*cu_x*z + 7.20936890441875*cu_x*z - 4.80624593627917*x*cu_z - 7.20936890441875*x*cu_z + 3.60468445220938*cu_x*z) - sqrt_15*dY5^1*y*(3.79967103839267*cu_x - 11.399013115178*x*sq_z - 11.399013115178*x*sq_z + 3.79967103839267*cu_x) + sqrt_15*dY5^10*(1.20156148406979*sq_x * sq_x - 10.8140533566281*sq_x*sq_z - 5.40702667831406*sq_x*sq_z - 1.80234222610469*sq_x*sq_z + 1.80234222610469*sq_x * sq_x + 3.00390371017448*sq_z * sq_z) - dY5^2*(-2.08116553882674*cu_x*z + 0.358236421003411*sqrt_15*cu_x*z + 0.716472842006822*sqrt_15*cu_x*z - 5.01530989404776*sqrt_15*x*sq_y*z - 10.0306197880955*sqrt_15*x*sq_y*z - 24.9739864659209*x*sq_y*z + 1.43294568401365*sqrt_15*x*cu_z + 0.716472842006823*sqrt_15*x*cu_z + 12.4869932329604*x*cu_z - 0.358236421003411*sqrt_15*cu_x*z + 6.24349661648022*cu_x*z) - dY5^3*(1.75499287747842*sqrt_15*cu_x*y - 2.63248931621764*sqrt_15*x*cu_y + 5.26497863243527*sqrt_15*x*y*sq_z + 2.77555756156289e-17*sqrt_15*x*y*sq_z + 30.5867618423396*x*y*sq_z - 23.789703655153*x*cu_y + 4.16333634234434e-17*sqrt_15*cu_x*y + 10.1955872807799*cu_x*y) + dY5^4*(-0.082915619758885*sqrt_15*cu_x*z + 2.89017733019965*cu_x*z + 0.16583123951777*sqrt_15*cu_x*z - 19.2678488679977*x*sq_y*z 
+ 1.98997487421324*sqrt_15*x*sq_y*z - 3.97994974842648*sqrt_15*x*sq_y*z - 11.5607093207986*x*sq_y*z + 0.33166247903554*sqrt_15*x*cu_z - 0.16583123951777*sqrt_15*x*cu_z + 5.78035466039931*x*cu_z + 0.082915619758885*sqrt_15*cu_x*z + 2.89017733019965*cu_x*z) + dY5^5*(-0.642261628933256*sqrt_15*sq_x*y*z + 9.9498743710662*sq_x*y*z + 2.56904651573303*sqrt_15*sq_x*y*z - 0.642261628933256*sqrt_15*sq_x*y*z + 9.9498743710662*sq_x*y*z - 9.9498743710662*cu_y*z + 1.28452325786651*sqrt_15*y*cu_z + 19.8997487421324*y*cu_z - 23.2163735324878*cu_y*z) + dY5^6*(0.082915619758885*sqrt_15*sq_x * sq_x - 1.98997487421324*sqrt_15*sq_x*sq_y - 5.78035466039931*sq_x*sq_y + 0.248746859276655*sqrt_15*sq_x*sq_z + 0.124373429638327*sqrt_15*sq_x*sq_z + 4.33526599529948*sq_x*sq_z - 9.63392443399885*sq_x*sq_y + 0.99498743710662*sqrt_15*sq_x*sq_y - 0.124373429638327*sqrt_15*sq_x*sq_z + 4.33526599529948*sq_x*sq_z - 0.0414578098794425*sqrt_15*sq_x * sq_x + 1.44508866509983*sq_x * sq_x - 28.9017733019965*sq_y*sq_z - 2.98496231131986*sqrt_15*sq_y*sq_z - 17.3410639811979*sq_y*sq_z + 12.8452325786651*y**4.0 + 0.207289049397212*sqrt_15*sq_z * sq_z + 7.22544332549914*sq_z * sq_z) + dY5^7*(-1.75499287747842*sqrt_15*sq_x*y*z + 10.1955872807799*sq_x*y*z + 5.55111512312578e-17*sqrt_15*sq_x*y*z - 10.1955872807799*sq_x*y*z + 1.75499287747842*sqrt_15*sq_x*y*z + 2.63248931621764*sqrt_15*cu_y*z - 20.3911745615597*y*cu_z - 3.50998575495685*sqrt_15*y*cu_z + 23.789703655153*cu_y*z) + dY5^8*(0.358236421003411*sqrt_15*sq_x * sq_x - 5.01530989404776*sqrt_15*sq_x*sq_y - 12.4869932329604*sq_x*sq_y + 1.07470926301023*sqrt_15*sq_x*sq_z - 0.537354631505117*sqrt_15*sq_x*sq_z + 9.36524492472034*sq_x*sq_z - 2.50765494702388*sqrt_15*sq_x*sq_y - 3.12174830824011*sq_x*sq_z + 0.537354631505117*sqrt_15*sq_x*sq_z + 0.179118210501706*sqrt_15*sq_x * sq_x + 3.12174830824011*sq_x * sq_x + 7.52296484107164*sqrt_15*sq_y*sq_z + 12.4869932329604*sq_y*sq_z - 5.20291384706685*sq_z * sq_z - 0.895591052508528*sqrt_15*sq_z * sq_z) - sqrt_15*dY5^9*y*(3.79967103839267*sq_x*z + 15.1986841535707*sq_x*z + 3.79967103839267*sq_x*z - 7.59934207678533*cu_z)'" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "replace_terms_for_implementation(str(fifth_order_cartesian[\"z\"]), mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9177cd1-675a-4845-b0f7-266f37d9ec06", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e2e5acf --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,88 @@ +[build-system] +build-backend = "setuptools.build_meta" +requires = ["setuptools <= 69.5.1"] + +[project] +authors = [ + {"name" = "Intel Corporation", "email" = "none@xyz.com"}, +] +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "psutil", + "triton", + "torch", + "e3nn", + "tqdm" +] +description = "Triton-lang implementations of kernels for equivariant 
neural networks." +dynamic = ["version", "readme"] +keywords = ["performance", "portability", "triton", "equivariance", "graph", "neural", "networks"] +license = {file = "LICENSE.md"} +name = "equitriton" +requires-python = ">=3.10" + +[project.optional-dependencies] +dev = [ + "pre-commit", + "pytest", + "pytest-pretty", + "jupyter" +] + +[tool.setuptools.dynamic] +readme = {file = ["README.md"]} +version = {attr = "equitriton.__version__"} + +[tool.ruff.lint] +# ignore E741 because simple symbols like l are appropriate +ignore = ["F403", "F405", "E741"] + +[tool.semantic_release] +assets = [] +build_command_env = [] +commit_message = "{version}\n\nAutomatically generated by python-semantic-release" +commit_parser = "angular" +logging_use_named_masks = false +major_on_zero = true +allow_zero_version = true +no_git_verify = false +tag_format = "v{version}" + +[tool.semantic_release.branches.main] +match = "main" +prerelease_token = "rc" +prerelease = false + +[tool.semantic_release.changelog] +template_dir = "templates" +changelog_file = "CHANGELOG.md" +exclude_commit_patterns = [] + +[tool.semantic_release.changelog.environment] +block_start_string = "{%" +block_end_string = "%}" +variable_start_string = "{{" +variable_end_string = "}}" +comment_start_string = "{#" +comment_end_string = "#}" +trim_blocks = false +lstrip_blocks = false +newline_sequence = "\n" +keep_trailing_newline = false +extensions = [] +autoescape = true + +[tool.semantic_release.commit_author] +env = "GIT_COMMIT_AUTHOR" +default = "semantic-release " + +[tool.semantic_release.commit_parser_options] +allowed_tags = ["build", "chore", "ci", "docs", "feat", "fix", "perf", "style", "refactor", "test"] +minor_tags = ["feat"] +patch_tags = ["fix", "perf"] +default_bump_level = 0 diff --git a/scripts/benchmark.py b/scripts/benchmark.py new file mode 100644 index 0000000..d3568f4 --- /dev/null +++ b/scripts/benchmark.py @@ -0,0 +1,134 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: MIT License +from __future__ import annotations + +from argparse import ArgumentParser +from logging import getLogger + +import torch +import numpy as np +import pandas as pd +from e3nn.o3._spherical_harmonics import _spherical_harmonics + +from equitriton.sph_harm.bindings import * +from equitriton.benchmark import benchmark + +""" +This script is used to benchmark the performance of the Triton spherical +harmonics against the original e3nn implementation. + +The script runs kernels a specified number warm up and recorded steps, +and uses them to calculate percentiles for the combined forward and +backward passes. The end result is a CSV file containing these statistics +as a function of the number of nodes. +""" + +logger = getLogger("equitriton.benchmark") +logger.setLevel("INFO") + +triton_bindings = [ + None, + FirstOrderSphericalHarmonics, + SecondOrderSphericalHarmonics, + ThirdOrderSphericalHarmonics, + FourthOrderSphericalHarmonics, +] + +parser = ArgumentParser() +parser.add_argument( + "device", type=str, choices=["xpu", "cuda"], help="Device to profile on." 
+)
+parser.add_argument("l_max", type=int, help="Maximum angular momentum to consider.")
+parser.add_argument(
+    "-n",
+    "--num_steps",
+    type=int,
+    default=100,
+    help="Total number of steps to profile over.",
+)
+parser.add_argument(
+    "-w",
+    "--warmup_fraction",
+    type=float,
+    default=0.1,
+    help="Fraction of `num_steps` to use as warmup.",
+)
+parser.add_argument(
+    "-j",
+    "--min_log_size",
+    type=float,
+    default=2.0,
+    help="Minimum (log10) number of nodes.",
+)
+parser.add_argument(
+    "-k",
+    "--max_log_size",
+    type=float,
+    default=9.0,
+    help="Maximum (log10) number of nodes.",
+)
+parser.add_argument(
+    "-i", "--matrix_samples", type=int, default=20, help="Number of experiments to run."
+)
+
+args = parser.parse_args()
+
+
+@benchmark(num_steps=args.num_steps, warmup_fraction=args.warmup_fraction)
+def e3nn_benchmark(tensor_shape: list[int], device: str | torch.device, l_max: int):
+    joint_tensor = torch.rand(tensor_shape, device=device, requires_grad=True)
+    x, y, z = (
+        joint_tensor[..., 0].contiguous(),
+        joint_tensor[..., 1].contiguous(),
+        joint_tensor[..., 2].contiguous(),
+    )
+    output = _spherical_harmonics(l_max, x, y, z)
+    output.backward(gradient=torch.ones_like(output))
+    # delete references to ensure memory gets cleared
+    del output
+    del joint_tensor
+
+
+@benchmark(num_steps=args.num_steps, warmup_fraction=args.warmup_fraction)
+def triton_benchmark(tensor_shape: list[int], device: str | torch.device, l_max: int):
+    joint_tensor = torch.rand(tensor_shape, device=device, requires_grad=True)
+    x, y, z = (
+        joint_tensor[..., 0].contiguous(),
+        joint_tensor[..., 1].contiguous(),
+        joint_tensor[..., 2].contiguous(),
+    )
+    kernel = triton_bindings[l_max]
+    output = kernel.apply(x, y, z)
+    output.backward(gradient=torch.ones_like(output))
+    # delete references to ensure memory gets cleared
+    del output
+    del joint_tensor
+
+
+n_values = np.linspace(args.min_log_size, args.max_log_size, args.matrix_samples)
+
+all_data = []
+for N in n_values:
+    joint_results = {"N": N}
+    try:
+        e3nn_prof = e3nn_benchmark(
+            (int(10**N), 3), device=args.device, l_max=args.l_max
+        )
+        # np.percentile expects values on the 0-100 scale
+        e3nn_stats = np.percentile(np.array(e3nn_prof), [5, 50, 95])
+        for key, value in zip(["e3nn 5%", "e3nn 50%", "e3nn 95%"], e3nn_stats):
+            joint_results[key] = value
+    except Exception as e:
+        logger.warning(f"e3nn benchmark failed for 10**{N} shape due to {e}")
+    try:
+        triton_prof = triton_benchmark(
+            (int(10**N), 3), device=args.device, l_max=args.l_max
+        )
+        triton_stats = np.percentile(np.array(triton_prof), [5, 50, 95])
+        for key, value in zip(["triton 5%", "triton 50%", "triton 95%"], triton_stats):
+            joint_results[key] = value
+    except Exception as e:
+        logger.warning(f"Triton benchmark failed for 10**{N} shape due to {e}")
+    all_data.append(joint_results)
+
+df = pd.DataFrame(all_data)
+df.to_csv(f"{args.device}_lmax{args.l_max}_results.csv", index=False)
diff --git a/scripts/dynamic_shapes.py b/scripts/dynamic_shapes.py
new file mode 100644
index 0000000..61b5d91
--- /dev/null
+++ b/scripts/dynamic_shapes.py
@@ -0,0 +1,98 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT License
+from __future__ import annotations
+
+from argparse import ArgumentParser
+from logging import getLogger
+from time import time_ns
+
+import torch
+import numpy as np
+import pandas as pd
+import triton
+
+from equitriton.sph_harm import SphericalHarmonics
+
+SEED = 215616
+torch.manual_seed(SEED)
+rng = np.random.default_rng(SEED)
+
+"""
+This script is used to benchmark the performance of the Triton spherical
+harmonics over a uniform random number of nodes. The idea behind this script
+is to look at the overhead associated with kernel compilation, which can
+impact training/inference performance if input shapes change wildly.
+
+The minimum and maximum number of nodes may need to be tweaked to match
+what you might expect based on the data you work with.
+"""
+
+logger = getLogger("equitriton.benchmark")
+logger.setLevel("INFO")
+
+
+parser = ArgumentParser()
+parser.add_argument(
+    "device", type=str, choices=["xpu", "cuda"], help="Device to profile on."
+)
+parser.add_argument("l_max", type=int, help="Maximum angular momentum to consider.")
+parser.add_argument(
+    "-n",
+    "--num_steps",
+    type=int,
+    default=100,
+    help="Total number of steps to profile over.",
+)
+parser.add_argument(
+    "-j",
+    "--min_log_size",
+    type=float,
+    default=2.0,
+    help="Minimum (log10) number of nodes.",
+)
+parser.add_argument(
+    "-k",
+    "--max_log_size",
+    type=float,
+    default=9.0,
+    help="Maximum (log10) number of nodes.",
+)
+parser.add_argument("-p", "--pad", action="store_true", help="Enable tensor padding.")
+
+args = parser.parse_args()
+
+
+def triton_benchmark(
+    tensor_shape: list[int], device: str | torch.device, l_max: int, pad_tensors: bool
+):
+    joint_tensor = torch.rand(tensor_shape, device=device, requires_grad=True)
+    sph_harm = SphericalHarmonics(l_max, pad_tensors)
+    output = sph_harm(joint_tensor)
+    output.backward(gradient=torch.ones_like(output))
+    # delete references to ensure memory gets cleared
+    del output
+    del joint_tensor
+
+
+all_data = []
+start_time = time_ns()
+last_time = start_time
+for _ in range(args.num_steps):
+    num_nodes = int(10 ** rng.uniform(args.min_log_size, args.max_log_size))
+    expect_pad = triton.next_power_of_2(num_nodes)
+    joint_results = {"N": num_nodes, "pad_size": expect_pad}
+    try:
+        triton_benchmark(
+            (num_nodes, 3), device=args.device, l_max=args.l_max, pad_tensors=args.pad
+        )
+    except Exception as e:
+        logger.warning(f"Triton benchmark failed for {num_nodes} nodes due to {e}")
+    end_time = time_ns()
+    timedelta = (end_time - last_time) * 1e-9
+    last_time = end_time
+    joint_results["timedelta"] = timedelta
+    all_data.append(joint_results)
+logger.info(f"All tests finished in {(last_time - start_time) * 1e-9} seconds.")
+
+df = pd.DataFrame(all_data)
+df.to_csv(f"{args.device}_lmax{args.l_max}_jit_results.csv", index=False)
diff --git a/scripts/measure_numerical_error.py b/scripts/measure_numerical_error.py
new file mode 100644
index 0000000..91d8929
--- /dev/null
+++ b/scripts/measure_numerical_error.py
@@ -0,0 +1,178 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT License
+from __future__ import annotations
+
+from argparse import ArgumentParser
+from logging import getLogger
+
+import torch
+import numpy as np
+from e3nn.o3._spherical_harmonics import _spherical_harmonics
+
+from equitriton.sph_harm.bindings import *
+
+"""
+This script is used to measure the numerical error between e3nn
+and Triton implementations.
+"""
+
+logger = getLogger("equitriton.benchmark")
+logger.setLevel("INFO")
+
+triton_bindings = [
+    None,
+    FirstOrderSphericalHarmonics,
+    SecondOrderSphericalHarmonics,
+    ThirdOrderSphericalHarmonics,
+    FourthOrderSphericalHarmonics,
+]
+
+parser = ArgumentParser()
+parser.add_argument(
+    "l", type=int, choices=[1, 2, 3, 4], help="Maximum number of terms to test."
+)
+parser.add_argument(
+    "device", type=str, choices=["xpu", "cuda"], help="Device to profile on."
+)
+parser.add_argument(
+    "-n",
+    "--num_iter",
+    type=int,
+    default=1000,
+    help="Total number of iterations to sample over.",
+)
+parser.add_argument(
+    "-i",
+    "--num_feats",
+    type=int,
+    default=5000,
+    help="Number of nodes/features to compute over.",
+)
+parser.add_argument(
+    "--relative",
+    action="store_true",
+    help="Flag to calculate relative percentage errors instead of absolute errors.",
+)
+parser.add_argument(
+    "-d",
+    "--dtype",
+    choices=["float", "float32", "float64"],
+    default="float",
+    help="Precision to perform the tests with.",
+)
+
+args = parser.parse_args()
+
+
+def compare_e3nn_triton(
+    joint_tensor: torch.Tensor, l_max: int, relative: bool = True
+) -> tuple[np.ndarray, np.ndarray]:
+    # clear gradients just in case
+    joint_tensor.grad = None
+    x, y, z = (
+        joint_tensor[..., 0].contiguous(),
+        joint_tensor[..., 1].contiguous(),
+        joint_tensor[..., 2].contiguous(),
+    )
+    e3nn_output = _spherical_harmonics(l_max, x, y, z)
+    e3nn_output.backward(gradient=torch.ones_like(e3nn_output))
+    e3nn_grad = joint_tensor.grad.detach().clone()
+    joint_tensor.grad = None
+    # now do the same with the Triton version
+    kernel = triton_bindings[l_max]
+    triton_output = kernel.apply(x, y, z)
+    triton_output.backward(gradient=torch.ones_like(triton_output))
+    triton_grad = joint_tensor.grad.detach().clone()
+    # overzealous with the detachs honestly :P
+    signed_fwd_error = (e3nn_output - triton_output).detach().cpu().numpy()
+    if relative:
+        # compute relative percentage error
+        signed_fwd_error /= e3nn_output.detach().cpu().numpy()
+        signed_fwd_error *= 100.0
+    signed_bwd_error = (e3nn_grad - triton_grad).detach().cpu().numpy()
+    if relative:
+        signed_bwd_error /= e3nn_grad.detach().cpu().numpy()
+        signed_bwd_error *= 100.0
+    # delete intermediate tensors to make sure we don't leak
+    del e3nn_output
+    del triton_output
+    return (signed_fwd_error, signed_bwd_error)
+
+
+def run_test(
+    num_iter: int,
+    num_feats: int,
+    device: str | torch.device,
+    l_max: int,
+    percentiles: list[float] | np.ndarray = [2, 50, 98],
+    relative: bool = True,
+    dtype: torch.dtype = torch.float,
+):
+    """
+    Run a set of numerical error tests comparing the e3nn and Triton forward
+    and backward results. This is used to quantify, for a given precision,
+    how far off the Triton result might be from e3nn.
+
+    It is recommended that this is run and understood before replacing
+    the e3nn kernels with the _EquiTriton_ ones.
+
+    Parameters
+    ----------
+    num_iter : int
+        Number of iterations to test. This is basically how many
+        random tensors are going to be initialized.
+    num_feats : int
+        Number of nodes/features per iteration.
+    device : str | torch.device
+        Device to execute on.
+    l_max : int
+        Maximum number of terms to consider.
+    percentiles : list[float] | np.ndarray
+        Percentiles to compute statistics with, on the 0-100 scale
+        expected by ``np.percentile``. The default values should
+        be reasonably descriptive. Keep in mind that, if ``relative`` is
+        True, then the values reported are in percentage error, as opposed
+        to absolute error.
+    relative : bool, default True
+        If True, computes the relative percentage error by dividing
+        by the e3nn result.
+    dtype : torch.dtype, default torch.float
+        Data type to compute with.
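+
+    Examples
+    --------
+    A minimal sketch of an invocation (the device string and values here
+    are hypothetical; any supported device works)::
+
+        run_test(10, 128, "xpu", l_max=2, relative=False, dtype=torch.float64)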
+ """ + fwd_errors = [] + bwd_errors = [] + for _ in range(num_iter): + joint_tensor = torch.rand( + [num_feats, 3], device=device, requires_grad=True, dtype=dtype + ) + fwd_error, bwd_error = compare_e3nn_triton(joint_tensor, l_max, relative) + fwd_errors.append(fwd_error) + bwd_errors.append(bwd_error) + # get back shape of [num_feats, 3] for binning + fwd_errors = np.vstack(fwd_errors) + bwd_errors = np.vstack(bwd_errors) + # calculate error percentiles along samples dimension; + # output array is [percentiles, xyz] + fwd_percentiles = np.percentile(fwd_errors, percentiles, axis=0) + bwd_percentiles = np.percentile(bwd_errors, percentiles, axis=0) + logger.info( + f"Numerical error analysis for l_max=${l_max} on {device}. {num_iter} iterations using {num_feats} random nodes." + ) + for index, axis in enumerate(["x", "y", "z"]): + logger.info(f"---------- Result for axis: {axis} ----------") + logger.info( + f"Forward signed percentile ({percentiles}) errors: {fwd_percentiles[:,index]}" + ) + logger.info( + f"Backward signed percentile ({percentiles}) errors: {bwd_percentiles[:,index]}" + ) + + +dtype = getattr(torch, args.dtype) + +run_test( + args.num_iter, + args.num_feats, + args.device, + args.l, + relative=args.relative, + dtype=dtype, +) diff --git a/scripts/profile_script.py b/scripts/profile_script.py new file mode 100644 index 0000000..9404177 --- /dev/null +++ b/scripts/profile_script.py @@ -0,0 +1,109 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: MIT License +from __future__ import annotations + +from argparse import ArgumentParser +from logging import getLogger + +import torch +from torch.profiler import record_function +from e3nn.o3._spherical_harmonics import _spherical_harmonics + +from equitriton.sph_harm.bindings import * +from equitriton.benchmark import profile + +""" +Runs the PyTorch profiler on either the Triton or ``e3nn`` kernels. + +This provides a more in-depth analysis into the relative performance +of the kernels, as it produces a timeline for operations performed. +""" + +logger = getLogger("equitriton.benchmark").setLevel("INFO") + +triton_bindings = [ + None, + FirstOrderSphericalHarmonics, + SecondOrderSphericalHarmonics, + ThirdOrderSphericalHarmonics, + FourthOrderSphericalHarmonics, +] + +parser = ArgumentParser() +parser.add_argument( + "device", type=str, choices=["xpu", "cuda"], help="Device to profile on." +) +parser.add_argument("l_max", type=int, help="Maximum angular momentum to consider.") +parser.add_argument( + "-n", + "--num_steps", + type=int, + default=100, + help="Total number of steps to profile over.", +) +parser.add_argument( + "-w", + "--warmup_fraction", + type=float, + default=0.1, + help="Fraction of `num_steps` to use as warmup.", +) +parser.add_argument( + "-p", + "--prefix", + type=str, + default="", + help="Prefix to use for naming this experiment.", +) +parser.add_argument( + "-s", "--size", type=int, default=10_000_000, help="Number of nodes to use." 
+) + +args = parser.parse_args() + + +@profile( + experiment_name=f"{args.prefix}e3nn_{args.device}", + num_steps=args.num_steps, + warmup_fraction=args.warmup_fraction, +) +def e3nn_benchmark(tensor_shape: list[int], device: str | torch.device, l_max: int): + joint_tensor = torch.rand(tensor_shape, device=device, requires_grad=True) + x, y, z = ( + joint_tensor[..., 0].contiguous(), + joint_tensor[..., 1].contiguous(), + joint_tensor[..., 2].contiguous(), + ) + with record_function("forward"): + output = _spherical_harmonics(l_max, x, y, z) + with record_function("backward"): + output.backward(gradient=torch.ones_like(output)) + # delete references to ensure memory gets cleared + del output + del joint_tensor + + +@profile( + experiment_name=f"{args.prefix}triton_{args.device}", + num_steps=args.num_steps, + warmup_fraction=args.warmup_fraction, +) +def triton_benchmark(tensor_shape: list[int], device: str | torch.device, l_max: int): + joint_tensor = torch.rand(tensor_shape, device=device, requires_grad=True) + x, y, z = ( + joint_tensor[..., 0].contiguous(), + joint_tensor[..., 1].contiguous(), + joint_tensor[..., 2].contiguous(), + ) + kernel = triton_bindings[l_max] + with record_function("forward"): + output = kernel.apply(x, y, z) + with record_function("backward"): + output.backward(gradient=torch.ones_like(output)) + # delete references to ensure memory gets cleared + del output + del joint_tensor + + +e3nn_benchmark(tensor_shape=(args.size, 3), device=args.device, l_max=args.l_max) +triton_benchmark(tensor_shape=(args.size, 3), device=args.device, l_max=args.l_max) diff --git a/src/equitriton/__init__.py b/src/equitriton/__init__.py new file mode 100644 index 0000000..de4cdee --- /dev/null +++ b/src/equitriton/__init__.py @@ -0,0 +1,29 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: MIT License +from __future__ import annotations + +from os import environ +from warnings import warn +from importlib.util import find_spec + +import torch + +__HAS_IPEX__ = True if find_spec("intel_extension_for_pytorch") else False +__HAS_CUDA__ = torch.cuda.is_available() +__HAS_XPU__ = False + +if __HAS_IPEX__: + try: + import intel_extension_for_pytorch # noqa: F401 + + __HAS_XPU__ = torch.xpu.device_count() != 0 + except ImportError as e: + warn(f"Unable to load IPEX due to {e}; XPU may not function.") + +if "PATCH_E3NN" in environ: + _will_patch = bool(environ.get("PATCH_E3NN", False)) + + if _will_patch: + from equitriton import patch # noqa: F401 + +__version__ = "0.1.0" diff --git a/src/equitriton/benchmark.py b/src/equitriton/benchmark.py new file mode 100644 index 0000000..40669a1 --- /dev/null +++ b/src/equitriton/benchmark.py @@ -0,0 +1,144 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: MIT License +from __future__ import annotations + +from importlib import import_module + +import torch +from torch.profiler import ProfilerActivity, schedule +from torch.profiler import profile as torch_profile +from typing import Callable +from functools import wraps +from time import perf_counter_ns +from logging import getLogger, INFO, basicConfig + +from tqdm import tqdm +import numpy as np + +basicConfig() + +__all__ = ["benchmark"] + + +def benchmark( + num_steps: int = 100, + warmup_fraction: float = 0.05, + percentiles: list[float] = [0.05, 0.1, 0.5, 0.9, 0.95], +): + def decorator(func: Callable): + @wraps(func) + def benchmark_func(*args, **kwargs): + logger = getLogger("equitriton.benchmark") + logger.setLevel(INFO) + times = [] + assert ( + 
warmup_fraction < 1.0 + ), f"Invalid warm up fraction: got {warmup_fraction}" + warmup_steps = int(warmup_fraction * num_steps) + # try and determine the device from kwargs + if "device" in kwargs: + device = kwargs["device"] + if isinstance(device, str) and "xpu" in device: + sync_func = torch.xpu.synchronize + elif isinstance(device, str) and "cuda" in device: + sync_func = torch.cuda.synchronize + elif isinstance(device, torch.device): + device_type = device.type + submodule = import_module(f"torch.{device_type}") + sync_func = getattr(submodule, "synchronize", None) + if not sync_func: + raise NotImplementedError( + f"Device {device} does not have a synchronize function in torch." + ) + else: + device = "unknown device" + sync_func = None + logger.info( + f"Benchmarking {func} on {device} with {num_steps} steps ({warmup_steps} warm up)." + ) + # clear cache + if sync_func: + cache = torch.empty(256_000_000, dtype=torch.int8, device=device) + sync_func() + total_start = perf_counter_ns() + for i in tqdm(range(num_steps), desc=f"{func} on {device}"): + if sync_func: + cache.zero_() + sync_func() + if i > warmup_steps: + start_time = perf_counter_ns() + _ = func(*args, **kwargs) + if sync_func: + sync_func() + if i > warmup_steps: + end_time = perf_counter_ns() + times.append(end_time - start_time) + total_end = perf_counter_ns() + times = np.array(times) / 1e6 # convert to milliseconds + benchmark_percentiles = np.percentile(times, q=percentiles) + end_to_end = (total_end - total_start) / 1e6 + logger.info( + f"{num_steps} took {end_to_end} milliseconds to complete (including warm up!)." + ) + logger.info("Reporting percentiles.") + for per, value in zip(percentiles, benchmark_percentiles): + logger.info(f"{per * 100} percentile - {value} milliseconds") + return times + + return benchmark_func + + return decorator + + +def profile( + experiment_name: str, + num_steps: int = 100, + warmup_fraction: float = 0.05, + repeat: int = 1, + **profile_kwargs, +): + profile_kwargs.setdefault("profile_memory", True) + profile_kwargs.setdefault("with_stack", True) + profile_kwargs.setdefault("record_shapes", True) + + def decorator(func: Callable): + @wraps(func) + def benchmark_func(*args, **kwargs): + logger = getLogger("equitriton.benchmark") + logger.setLevel(INFO) + activities = [ProfilerActivity.CPU] + if "device" in kwargs: + device = kwargs["device"] + if "cuda" in device: + activities.append(ProfilerActivity.CUDA) + profile_kwargs.setdefault("use_cuda", True) + if "xpu" in device: + activities.append(ProfilerActivity.XPU) + sch = schedule( + active=num_steps, + warmup=int(num_steps * warmup_fraction), + wait=0, + repeat=repeat, + ) + logger.info( + f"Profiling {activities} for {num_steps} steps ({int(num_steps * warmup_fraction)} warmup)." 
+            )
+            with torch_profile(
+                activities=activities, schedule=sch, **profile_kwargs
+            ) as prof_obj:
+                for _ in tqdm(range(num_steps)):
+                    _ = func(*args, **kwargs)
+            print(prof_obj.key_averages().table(row_limit=10))
+            try:
+                prof_obj.export_chrome_trace(f"{experiment_name}_trace.json")
+            except Exception as e:
+                logger.warning(f"Unable to export trace due to {e}.")
+            try:
+                prof_obj.export_memory_timeline(f"{experiment_name}_memory.json")
+            except Exception as e:
+                logger.warning(f"Unable to export memory profile due to {e}.")
+            return prof_obj
+
+        return benchmark_func
+
+    return decorator
diff --git a/src/equitriton/patch.py b/src/equitriton/patch.py
new file mode 100644
index 0000000..e1b93ec
--- /dev/null
+++ b/src/equitriton/patch.py
@@ -0,0 +1,65 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT License
+from __future__ import annotations
+
+from logging import getLogger
+import math
+
+import torch
+from e3nn.o3 import SphericalHarmonics
+from equitriton.sph_harm.main import SphericalHarmonics as TritonHarmonics
+
+"""
+This module will monkey patch ``e3nn``: in other words, when loaded,
+it will dynamically replace the ``forward`` call for the ``e3nn``
+``SphericalHarmonics`` class that is commonly used in equivariant models
+with the _EquiTriton_ version (i.e. replace the ``torchscript`` kernels
+with the ``triton`` ones).
+
+If this behavior is _not_ desired, do _not_ import this module at runtime.
+"""
+
+logger = getLogger("equitriton")
+
+
+def forward(self, x: torch.Tensor) -> torch.Tensor:
+    """
+    Patched version of the forward call, which instead relies on Triton
+    kernels for each value of l_max.
+    """
+    if self.normalize:
+        x = torch.nn.functional.normalize(
+            x, dim=-1
+        )  # forward 0's instead of nan for zero-radius
+
+    # initialize the spherical harmonics wrapper
+    if not hasattr(self, "_triton"):
+        self._triton = TritonHarmonics(self._lmax)
+    # do the spherical harmonic evaluation with triton kernels instead
+    sh = self._triton(x)
+
+    if not self._is_range_lmax:
+        sh = torch.cat(
+            [sh[..., l * l : (l + 1) * (l + 1)] for l in self._ls_list],  # noqa: E741
+            dim=-1,
+        )
+
+    if self.normalization == "integral":
+        sh.div_(math.sqrt(4 * math.pi))
+    elif self.normalization == "norm":
+        sh.div_(
+            torch.cat(
+                [
+                    math.sqrt(2 * l + 1)
+                    * torch.ones(2 * l + 1, dtype=sh.dtype, device=sh.device)
+                    for l in self._ls_list  # noqa: E741
+                ]
+            )
+        )
+
+    return sh
+
+
+# apply the monkey patch
+logger.info("Patching e3nn `SphericalHarmonics.forward` with Triton kernels.")
+SphericalHarmonics.forward = forward
diff --git a/src/equitriton/sph_harm/__init__.py b/src/equitriton/sph_harm/__init__.py
new file mode 100644
index 0000000..f9720dd
--- /dev/null
+++ b/src/equitriton/sph_harm/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT License
+from __future__ import annotations
+
+from equitriton.sph_harm.main import SphericalHarmonics
+
+__all__ = [
+    "SphericalHarmonics",
+]
diff --git a/src/equitriton/sph_harm/bindings.py b/src/equitriton/sph_harm/bindings.py
new file mode 100644
index 0000000..555f4e0
--- /dev/null
+++ b/src/equitriton/sph_harm/bindings.py
@@ -0,0 +1,340 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT License
+from __future__ import annotations
+
+import torch
+import triton
+import numpy as np
+
+from equitriton.sph_harm import triton_kernels as tk
+
+
+__all__ = [
+    "FirstOrderSphericalHarmonics",
+    "SecondOrderSphericalHarmonics",
+    "ThirdOrderSphericalHarmonics",
"FourthOrderSphericalHarmonics", +] + + +def _num_projections(l: int) -> int: # noqa: E741 + """Calculate the number of projections of m based on l""" + return 2 * l + 1 + + +def total_projections(l_max: int) -> int: + """Calculate the total number of projects for a given l_max""" + return sum([_num_projections(m) for m in range(l_max + 1)]) + + +def make_output_tensor(x: torch.Tensor, l_max: int) -> list[torch.Tensor]: + """Create a list of tensors with the correct size and mapping to be concatenated afterwards""" + total_num_projections = total_projections(l_max) + last_dim = x.size(-1) + remainder = x.shape[:-1] + # add an extra 1 dimension to the end to facilitate concatenation + output = [ + torch.empty((*remainder, last_dim, 1), dtype=x.dtype, device=x.device) + for _ in range(total_num_projections) + ] + return output + + +def split_tensor_by_l( + joint_tensor: torch.Tensor, l_max: int, dim: int = -1 +) -> list[torch.Tensor]: + """The reverse operation of the concatenate step""" + num_projections = [total_projections(l_value) for l_value in range(l_max + 1)] + proj_indices = list(np.cumsum(num_projections) - 1) + # the first output is empty, so we exclude it + return torch.tensor_split(joint_tensor, proj_indices, dim=dim)[1:] + + +def slice_and_dice_tensor(joint_tensor: torch.Tensor) -> list[torch.Tensor]: + """Completely slices up a tensor along the last dimension, returning N views of an N length dimension.""" + num_slices = joint_tensor.size(-1) + slice_indices = np.arange(num_slices).tolist() + # the first output is empty, so we exclude it + result = torch.tensor_split(joint_tensor, slice_indices, dim=-1)[1:] + return result + + +class FirstOrderSphericalHarmonics(torch.autograd.Function): + """ + First order spherical harmonics. This doesn't actually + even use Triton, but is implement for consistency in the + interface. + """ + + @staticmethod + def forward( + ctx, + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + # for the current parallel model to work, the pointers must be contiguous! 
+ # otherwise the result will be completely scrambled, as the output tensor + # indexing will be mismatched from xyz + # TODO: move this to the high level wrapper + x = x.contiguous() + y = y.contiguous() + z = z.contiguous() + output_tensors = make_output_tensor(x, 1) + output_tensors[0][:] = 1.0 + block_size = 256 + vector_length = x.numel() + # ceiling divide makes sure it works for block sizes larger than + # the total number of samples + num_blocks = triton.next_power_of_2(triton.cdiv(vector_length, block_size)) + tk._triton_first_order_fwd[num_blocks,]( + x, + y, + z, + *output_tensors[1:], + BLOCK_SIZE=block_size, + vector_length=vector_length, + ) + ctx.save_for_backward(x, y, z, mask) + # the expected shape is [..., num_projections] + output = torch.cat(output_tensors, dim=-1) + # remove contributions from padded nodes + if isinstance(mask, torch.Tensor): + output = output[mask] + return output + + @staticmethod + def backward(ctx, grad_output): + # derivative of projections of each spherical harmonic order + # zeroth order is constant and doesn't contribute derivatives + d_sph_0, d_sph_1_x, d_sph_1_y, d_sph_1_z = slice_and_dice_tensor(grad_output) + saved_tensors = ctx.saved_tensors + if len(saved_tensors) == 3: + x, y, z = saved_tensors + mask = None + else: + x, y, z, mask = saved_tensors + # factor of sqrt3 for all values + sqrt3 = 3**0.5 + # we expect three tensors back for xyz + x_grad = d_sph_1_x * sqrt3 + y_grad = d_sph_1_y * sqrt3 + z_grad = d_sph_1_z * sqrt3 + # intended gradients should be shape [num_nodes] per coordinate + return x_grad.squeeze(), y_grad.squeeze(), z_grad.squeeze(), mask + + +class SecondOrderSphericalHarmonics(torch.autograd.Function): + """ + Second order spherical harmonics. A little more involved than + the first order case, and actually gives something interesting + to look at. 
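+
+    A rough usage sketch (the device string is an assumption; the
+    coordinate views must be made contiguous, as the forward pass also
+    does internally)::
+
+        joint = torch.rand(128, 3, device="cuda", requires_grad=True)
+        x, y, z = (joint[..., i].contiguous() for i in range(3))
+        sph = SecondOrderSphericalHarmonics.apply(x, y, z)
+        # concatenated l = 0, 1, 2 projections -> shape (128, 9)
+        assert sph.shape == (128, 9)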
+ """ + + @staticmethod + def forward( + ctx, + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + x = x.contiguous() + y = y.contiguous() + z = z.contiguous() + output_tensors = make_output_tensor(x, 2) + output_tensors[0][:] = 1.0 + block_size = 256 + vector_length = x.numel() + # ceiling divide makes sure it works for block sizes larger than + # the total number of samples + num_blocks = triton.next_power_of_2(triton.cdiv(vector_length, block_size)) + tk._triton_second_order_fwd[num_blocks,]( + x, + y, + z, + *output_tensors[1:], # unpack pointers without verbosity + BLOCK_SIZE=block_size, + vector_length=vector_length, + ) + ctx.save_for_backward(x, y, z) + output = torch.cat(output_tensors, dim=-1) + # remove contributions from padded nodes + if isinstance(mask, torch.Tensor): + output = output[mask] + return output + + @staticmethod + def backward(ctx, grad_output): + # derivative of projections of each spherical harmonic order + # zeroth order is constant and doesn't contribute derivatives + gradient_collection = slice_and_dice_tensor(grad_output) + saved_tensors = ctx.saved_tensors + if len(saved_tensors) == 3: + x, y, z = saved_tensors + mask = None + else: + x, y, z, mask = saved_tensors + x_grad = torch.zeros_like(x) + y_grad = torch.zeros_like(y) + z_grad = torch.zeros_like(z) + block_size = 256 + vector_length = x.numel() + # ceiling divide makes sure it works for block sizes larger than + # the total number of samples + num_blocks = triton.next_power_of_2(triton.cdiv(vector_length, block_size)) + tk._triton_second_order_bwd[num_blocks,]( + x, + y, + z, + x_grad, + y_grad, + z_grad, + *gradient_collection[1:], + BLOCK_SIZE=block_size, + vector_length=vector_length, + ) + return x_grad.squeeze(), y_grad.squeeze(), z_grad.squeeze(), mask + + +class ThirdOrderSphericalHarmonics(torch.autograd.Function): + """ + Third order spherical harmonics. Starting to get more cookiecutter. 
+ """ + + @staticmethod + def forward( + ctx, + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + x = x.contiguous() + y = y.contiguous() + z = z.contiguous() + output_tensors = make_output_tensor(x, 3) + output_tensors[0][:] = 1.0 + block_size = 256 + vector_length = x.numel() + # ceiling divide makes sure it works for block sizes larger than + # the total number of samples + num_blocks = triton.next_power_of_2(triton.cdiv(vector_length, block_size)) + tk._triton_third_order_fwd[num_blocks,]( + x, + y, + z, + *output_tensors[1:], # unpack pointers without verbosity + BLOCK_SIZE=block_size, + vector_length=vector_length, + ) + ctx.save_for_backward(x, y, z, mask) + output = torch.cat(output_tensors, dim=-1) + # remove contributions from padded nodes + if isinstance(mask, torch.Tensor): + output = output[mask] + return output + + @staticmethod + def backward(ctx, grad_output): + # derivative of projections of each spherical harmonic order + # zeroth order is constant and doesn't contribute derivatives + gradient_collection = slice_and_dice_tensor(grad_output) + saved_tensors = ctx.saved_tensors + if len(saved_tensors) == 3: + x, y, z = saved_tensors + mask = None + else: + x, y, z, mask = saved_tensors + x_grad = torch.zeros_like(x) + y_grad = torch.zeros_like(y) + z_grad = torch.zeros_like(z) + block_size = 256 + vector_length = x.numel() + # ceiling divide makes sure it works for block sizes larger than + # the total number of samples + num_blocks = triton.next_power_of_2(triton.cdiv(vector_length, block_size)) + tk._triton_third_order_bwd[num_blocks,]( + x, + y, + z, + x_grad, + y_grad, + z_grad, + *gradient_collection[1:], + BLOCK_SIZE=block_size, + vector_length=vector_length, + ) + return x_grad.squeeze(), y_grad.squeeze(), z_grad.squeeze(), mask + + +class FourthOrderSphericalHarmonics(torch.autograd.Function): + """ + Fourth order spherical harmonics. Starting to get tediuous... 
+ """ + + @staticmethod + def forward( + ctx, + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + x = x.contiguous() + y = y.contiguous() + z = z.contiguous() + output_tensors = make_output_tensor(x, 4) + output_tensors[0][:] = 1.0 + block_size = 256 + vector_length = x.numel() + # ceiling divide makes sure it works for block sizes larger than + # the total number of samples + num_blocks = triton.next_power_of_2(triton.cdiv(vector_length, block_size)) + tk._triton_fourth_order_fwd[num_blocks,]( + x, + y, + z, + *output_tensors[1:], # unpack pointers without verbosity + BLOCK_SIZE=block_size, + vector_length=vector_length, + ) + ctx.save_for_backward(x, y, z, mask) + output = torch.cat(output_tensors, dim=-1) + # remove contributions from padded nodes + if isinstance(mask, torch.Tensor): + output = output[mask] + return output + + @staticmethod + def backward(ctx, grad_output): + # derivative of projections of each spherical harmonic order + # zeroth order is constant and doesn't contribute derivatives + gradient_collection = slice_and_dice_tensor(grad_output) + saved_tensors = ctx.saved_tensors + if len(saved_tensors) == 3: + x, y, z = saved_tensors + mask = None + else: + x, y, z, mask = saved_tensors + x_grad = torch.zeros_like(x) + y_grad = torch.zeros_like(y) + z_grad = torch.zeros_like(z) + block_size = 256 + vector_length = x.numel() + # ceiling divide makes sure it works for block sizes larger than + # the total number of samples + num_blocks = triton.next_power_of_2(triton.cdiv(vector_length, block_size)) + tk._triton_fourth_order_bwd[num_blocks,]( + x, + y, + z, + x_grad, + y_grad, + z_grad, + *gradient_collection[1:], + BLOCK_SIZE=block_size, + vector_length=vector_length, + ) + return x_grad.squeeze(), y_grad.squeeze(), z_grad.squeeze(), mask diff --git a/src/equitriton/sph_harm/main.py b/src/equitriton/sph_harm/main.py new file mode 100644 index 0000000..60edb05 --- /dev/null +++ b/src/equitriton/sph_harm/main.py @@ -0,0 +1,94 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: MIT License +from __future__ import annotations + +from typing import Callable + +import torch +from torch import nn + +from equitriton.sph_harm.bindings import * +from equitriton.utils import pad_tensor_to_power + + +class SphericalHarmonics(nn.Module): + # None is prepended to keep the indexing consistent with l_max + __fwd_kernel_mapping__ = [ + None, + FirstOrderSphericalHarmonics, + SecondOrderSphericalHarmonics, + ThirdOrderSphericalHarmonics, + FourthOrderSphericalHarmonics, + ] + + def __init__(self, lmax: int, pad_tensors: bool = True) -> None: + """ + Initialize a ``SphericalHarmonics`` object that computes + up to some maximum value of ``l``. + + Optionally, to minimize kernel JIT overhead, the option to + pad tensors under the hood is provided: by rounding the + number of nodes to the nearest power of two, we are able + to improve re-use of kernels compiled for specific shapes. + + Parameters + ---------- + lmax : int + Maximum value of ``l`` to use for embedding. + pad_tensors : bool, default True + If set to True, this will pad the number of nodes + up to the nearest power of two. This results in + higher memory usage during the forward pass, but + the tradeoff is minimizing overhead from needing + to recompile kernels for every single batch shape. + + In cases where shapes are expected to be static + (e.g. in MD simulations), this can be safely disabled. 
+ """ + super().__init__() + self.lmax = lmax + self.pad_tensors = pad_tensors + + def _preprocess_tensors( + self, input_tensor: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # last dimension should be xyz + assert ( + input_tensor.size(-1) == 3 + ), f"Expected last input dimension to be 3 (x,y,z). Got {input_tensor.size(-1)}" + # pad tensor if requested + if self.pad_tensors: + input_tensor, mask = pad_tensor_to_power(input_tensor) + self.mask = mask + else: + self.mask = None + # make tensors contiguous for better memory access + x, y, z = ( + input_tensor[..., 0].contiguous(), + input_tensor[..., 1].contiguous(), + input_tensor[..., 2].contiguous(), + ) + return (x, y, z) + + def _determine_kernel(self) -> Callable: + try: + kernel = self.__fwd_kernel_mapping__[self.lmax] + except IndexError as e: + raise NotImplementedError( + f"Kernels only implemented up to lmax = {len(self.__fwd_kernel_mapping__)}." + ) from e + if kernel is None: + raise NotImplementedError( + "Zeroth order kernel is not implemented; it's too trivial 😏" + ) + return kernel + + def _forward( + self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> torch.Tensor: + kernel = self._determine_kernel() + return kernel.apply(x, y, z, self.mask) + + def forward(self, input_tensor: torch.Tensor): + x, y, z = self._preprocess_tensors(input_tensor) + return self._forward(x, y, z) diff --git a/src/equitriton/sph_harm/tests/test_correctness.py b/src/equitriton/sph_harm/tests/test_correctness.py new file mode 100644 index 0000000..9ac1209 --- /dev/null +++ b/src/equitriton/sph_harm/tests/test_correctness.py @@ -0,0 +1,111 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: MIT License +from __future__ import annotations +import pytest + +import torch + +from equitriton import __HAS_XPU__, __HAS_CUDA__ +from equitriton.sph_harm import bindings +from e3nn.o3._spherical_harmonics import _spherical_harmonics + +# make sure values are the same every time +torch.manual_seed(3125161) + +""" +This test suite parametrizes over l, device, and tensor shapes to +test for functionality and correctness. + +TODO: expand to parametrize data types as well +""" + +RTOL = 1e-4 +ATOL = 1e-6 + + +@pytest.mark.parametrize("l_func_name", bindings.__all__) +@pytest.mark.parametrize( + "device", + [ + pytest.param( + "xpu", + marks=pytest.mark.skipif(not __HAS_XPU__, reason="No XPUs available."), + ), + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not __HAS_CUDA__, reason="No CUDA GPUs available." + ), + ), + ], +) +@pytest.mark.parametrize("tensor_shape", [(512, 3), (128, 16, 3), (256, 8, 8, 3)]) +def test_bound_kernel(l_func_name, device, tensor_shape): + """ + Iterate through exported autograd bindings, and make sure that + the forward application passes. + """ + l_func = getattr(bindings, l_func_name) + joint_tensor = torch.rand(tensor_shape, device=device, requires_grad=True) + x, y, z = joint_tensor[..., 0], joint_tensor[..., 1], joint_tensor[..., 2] + outputs = l_func.apply(x, y, z) + assert torch.isfinite(outputs).all() + + +@pytest.mark.parametrize("l", [1, 2, 3, 4]) +@pytest.mark.parametrize( + "device", + [ + pytest.param( + "xpu", + marks=pytest.mark.skipif(not __HAS_XPU__, reason="No XPUs available."), + ), + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not __HAS_CUDA__, reason="No CUDA GPUs available." 
+ ), + ), + ], +) +@pytest.mark.parametrize("tensor_shape", [(512, 3), (128, 16, 3), (256, 8, 8, 3)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_correctness_fwd_bwd(l, device, tensor_shape, dtype): + """Compare e3nn and triton results for the forward and backward passes.""" + joint = torch.rand(tensor_shape, device=device, dtype=dtype, requires_grad=True) + x, y, z = joint[..., 0], joint[..., 1], joint[..., 2] + # run the test with e3nn forward then backward + e3nn_result = _spherical_harmonics(l, x, y, z) + e3nn_result.backward(gradient=torch.ones_like(e3nn_result)) + e3nn_grad = joint.grad.clone().detach() + # reset grads for the next round + joint.grad = None + # index the exported bindings, which starts from 1 + l_func_name = bindings.__all__[l - 1] + l_func = getattr(bindings, l_func_name) + triton_result = l_func.apply(x, y, z) + assert triton_result.shape == e3nn_result.shape + # loop over spherical harmonics terms so we can get informative results + dim_mismatchs = [ + torch.allclose(triton_result[..., i], e3nn_result[..., i], atol=ATOL, rtol=RTOL) + for i in range(e3nn_result.size(-1)) + ] + if not all(dim_mismatchs): + bad_dims = [i for i, test in enumerate(dim_mismatchs) if not test] + raise AssertionError( + f"Forward call mismatch on l={l} for dimensions {bad_dims}" + ) + triton_result.backward(gradient=torch.ones_like(triton_result)) + triton_grad = joint.grad.clone().detach() + joint.grad = None + # check the tensor outputs + assert triton_grad.shape == e3nn_grad.shape + dim_mismatchs = [ + torch.allclose(triton_grad[..., i], e3nn_grad[..., i], atol=ATOL, rtol=RTOL) + for i in range(e3nn_grad.size(-1)) + ] + if not all(dim_mismatchs): + bad_dims = [i for i, test in enumerate(dim_mismatchs) if not test] + raise AssertionError( + f"Backward call mismatch on l={l} for dimensions {bad_dims}" + ) diff --git a/src/equitriton/sph_harm/tests/test_main.py b/src/equitriton/sph_harm/tests/test_main.py new file mode 100644 index 0000000..5a64d45 --- /dev/null +++ b/src/equitriton/sph_harm/tests/test_main.py @@ -0,0 +1,58 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: MIT License +from __future__ import annotations + +import pytest +import torch + +from equitriton.sph_harm import SphericalHarmonics +from equitriton import __HAS_XPU__, __HAS_CUDA__ + + +@pytest.mark.parametrize( + "l_max", + [ + pytest.param( + 0, marks=pytest.mark.xfail(reason="Zeroth order not implemented.") + ), + 1, + 2, + 3, + 4, + ], +) +@pytest.mark.parametrize( + "tensor_shape", + [ + (64, 3), + (256, 64, 3), + (512, 128, 8, 3), + pytest.param( + (10, 8, 40, 1), marks=pytest.mark.xfail(reason="Bad last dimension.") + ), + ], +) +@pytest.mark.parametrize( + "device", + [ + pytest.param( + "cuda:0", + marks=pytest.mark.skipif( + not __HAS_CUDA__, reason="No CUDA device available." + ), + ), + pytest.param( + "xpu:0", + marks=pytest.mark.skipif( + not __HAS_XPU__, reason="No XPU device available." 
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize("node_padding", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
+def test_main_interface(l_max, tensor_shape, device, node_padding, dtype):
+    joint_tensor = torch.rand(tensor_shape, device=device, dtype=dtype)
+    sph_harm = SphericalHarmonics(l_max, pad_tensors=node_padding)
+    output = sph_harm(joint_tensor)
+    assert torch.isfinite(output).all()
diff --git a/src/equitriton/sph_harm/tests/test_utils.py b/src/equitriton/sph_harm/tests/test_utils.py
new file mode 100644
index 0000000..c61fe99
--- /dev/null
+++ b/src/equitriton/sph_harm/tests/test_utils.py
@@ -0,0 +1,16 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT License
+import pytest
+import torch
+
+from equitriton.sph_harm.bindings import split_tensor_by_l, total_projections
+
+
+@pytest.mark.parametrize("l_max", [1, 2, 3])
+def test_split_tensor(l_max):
+    feat_dim = total_projections(l_max)
+    # this is equivalent to 128 nodes
+    expected_output = torch.rand(128, feat_dim)
+    split_tensors = split_tensor_by_l(expected_output, l_max)
+    # should be a tensor for each component
+    assert len(split_tensors) == l_max + 1
diff --git a/src/equitriton/sph_harm/triton_kernels.py b/src/equitriton/sph_harm/triton_kernels.py
new file mode 100644
index 0000000..1271127
--- /dev/null
+++ b/src/equitriton/sph_harm/triton_kernels.py
@@ -0,0 +1,962 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT License
+import triton
+from triton import language as tl
+
+__all__ = [
+    "_triton_first_order_fwd",
+    "_triton_second_order_fwd",
+    "_triton_second_order_bwd",
+    "_triton_third_order_fwd",
+    "_triton_third_order_bwd",
+    "_triton_fourth_order_fwd",
+    "_triton_fourth_order_bwd",
+]
+
+
+@triton.jit
+def _triton_first_order_fwd(
+    x_ptr: tl.tensor,
+    y_ptr: tl.tensor,
+    z_ptr: tl.tensor,
+    sph_1_0_ptr: tl.tensor,
+    sph_1_1_ptr: tl.tensor,
+    sph_1_2_ptr: tl.tensor,
+    BLOCK_SIZE: tl.constexpr,
+    vector_length: tl.constexpr,
+):
+    """
+    First order spherical harmonics in Triton.
+
+    Computationally not that intensive, as we're just applying
+    a factor of sqrt(3) to the coordinates, but it is also good
+    for validating that the kernel performs as intended.
+
+    Parameters
+    ----------
+    x_ptr, y_ptr, z_ptr : tl.tensor
+        Pointers to the coordinate tensors.
+    sph_1_0_ptr, sph_1_1_ptr, sph_1_2_ptr : tl.tensor
+        Pointers to tensors to write outputs to. Assumed to
+        be the same length as the input tensors.
+    BLOCK_SIZE : tl.constexpr
+        Vector length of contiguous elements to load into memory
+        within a given block.
+    vector_length : tl.constexpr
+        The maximum/total length of the vectors, assumed to
+        be the same for every one. This is used to calculate
+        the mask to keep operations within bounds.
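+
+    A sketch of how the bindings in this package launch this kernel,
+    where ``out_x``, ``out_y``, ``out_z`` stand in for the pre-allocated
+    output tensors::
+
+        num_blocks = triton.next_power_of_2(triton.cdiv(x.numel(), 256))
+        _triton_first_order_fwd[num_blocks,](
+            x, y, z, out_x, out_y, out_z,
+            BLOCK_SIZE=256, vector_length=x.numel(),
+        )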
+ """ + sqrt_3 = 3**0.5 + block_id = tl.program_id(0) + # calculate the offset for this particular thread + offset = tl.arange(0, BLOCK_SIZE) + (BLOCK_SIZE * block_id) + x_row_start = x_ptr + offset + y_row_start = y_ptr + offset + z_row_start = z_ptr + offset + # load in x,y,z to operate on + x = tl.load(x_row_start, mask=offset < vector_length) + y = tl.load(y_row_start, mask=offset < vector_length) + z = tl.load(z_row_start, mask=offset < vector_length) + # now multiply + sph_1_0 = sqrt_3 * x + sph_1_1 = sqrt_3 * y + sph_1_2 = sqrt_3 * z + # work out the pointers for the outputs + sph_1_0_start = sph_1_0_ptr + offset + sph_1_1_start = sph_1_1_ptr + offset + sph_1_2_start = sph_1_2_ptr + offset + tl.store(sph_1_0_start, sph_1_0, mask=offset < vector_length) + tl.store(sph_1_1_start, sph_1_1, mask=offset < vector_length) + tl.store(sph_1_2_start, sph_1_2, mask=offset < vector_length) + + +@triton.jit +def _triton_second_order_fwd( + x_ptr: tl.tensor, + y_ptr: tl.tensor, + z_ptr: tl.tensor, + sh_1_0_ptr: tl.tensor, + sh_1_1_ptr: tl.tensor, + sh_1_2_ptr: tl.tensor, + sh_2_0_ptr: tl.tensor, + sh_2_1_ptr: tl.tensor, + sh_2_2_ptr: tl.tensor, + sh_2_3_ptr: tl.tensor, + sh_2_4_ptr: tl.tensor, + BLOCK_SIZE: tl.constexpr, + vector_length: tl.constexpr, +): + sqrt_3 = 3**0.5 + block_id = tl.program_id(0) + # calculate the offset for this particular thread + offset = tl.arange(0, BLOCK_SIZE) + (BLOCK_SIZE * block_id) + x_row_start = x_ptr + offset + y_row_start = y_ptr + offset + z_row_start = z_ptr + offset + # load in x,y,z to operate on + x = tl.load(x_row_start, mask=offset < vector_length) + y = tl.load(y_row_start, mask=offset < vector_length) + z = tl.load(z_row_start, mask=offset < vector_length) + # compute first order terms + sh_1_0 = x * sqrt_3 + sh_1_1 = y * sqrt_3 + sh_1_2 = z * sqrt_3 + # now work on second order + sqrt_15 = 15**0.5 + sqrt_5 = 5**0.5 + sq_x = x * x + sq_y = y * y + sq_z = z * z + # compute each component + sh_2_0 = sqrt_15 * x * z + sh_2_1 = sqrt_15 * x * y + # these two appear swapped, but they are consistent with e3nn + sh_2_2 = sqrt_5 * (sq_y - 0.5 * (sq_x + sq_z)) + sh_2_3 = sqrt_15 * y * z + sh_2_4 = 0.5 * sqrt_15 * (sq_z - sq_x) + # write the results to memory + sh_1_0_start = sh_1_0_ptr + offset + sh_1_1_start = sh_1_1_ptr + offset + sh_1_2_start = sh_1_2_ptr + offset + sh_2_0_start = sh_2_0_ptr + offset + sh_2_1_start = sh_2_1_ptr + offset + sh_2_2_start = sh_2_2_ptr + offset + sh_2_3_start = sh_2_3_ptr + offset + sh_2_4_start = sh_2_4_ptr + offset + tl.store(sh_1_0_start, sh_1_0, mask=offset < vector_length) + tl.store(sh_1_1_start, sh_1_1, mask=offset < vector_length) + tl.store(sh_1_2_start, sh_1_2, mask=offset < vector_length) + tl.store(sh_2_0_start, sh_2_0, mask=offset < vector_length) + tl.store(sh_2_1_start, sh_2_1, mask=offset < vector_length) + tl.store(sh_2_2_start, sh_2_2, mask=offset < vector_length) + tl.store(sh_2_3_start, sh_2_3, mask=offset < vector_length) + tl.store(sh_2_4_start, sh_2_4, mask=offset < vector_length) + + +@triton.jit +def _triton_second_order_bwd( + x_ptr: tl.tensor, + y_ptr: tl.tensor, + z_ptr: tl.tensor, + g_x_ptr: tl.tensor, + g_y_ptr: tl.tensor, + g_z_ptr: tl.tensor, + g_1_0_ptr: tl.tensor, + g_1_1_ptr: tl.tensor, + g_1_2_ptr: tl.tensor, + g_2_0_ptr: tl.tensor, + g_2_1_ptr: tl.tensor, + g_2_2_ptr: tl.tensor, + g_2_3_ptr: tl.tensor, + g_2_4_ptr: tl.tensor, + BLOCK_SIZE: tl.constexpr, + vector_length: tl.constexpr, +): + # expect the xyz are the same as the forward pass, we have expected + # gradient output tensors as well 
as intermediate gradients + sqrt_3 = 3**0.5 + sqrt_5 = 5**0.5 + sqrt_15 = 15**0.5 + block_id = tl.program_id(0) + # calculate the offset for this particular thread + offset = tl.arange(0, BLOCK_SIZE) + (BLOCK_SIZE * block_id) + x_row_start = x_ptr + offset + y_row_start = y_ptr + offset + z_row_start = z_ptr + offset + # load in x,y,z to operate on + x = tl.load(x_row_start, mask=offset < vector_length) + y = tl.load(y_row_start, mask=offset < vector_length) + z = tl.load(z_row_start, mask=offset < vector_length) + # load the pre-allocated xyz gradients + g_x_start = g_x_ptr + offset + g_y_start = g_y_ptr + offset + g_z_start = g_z_ptr + offset + # NOTE: these are the gradient outputs and are assumed to be initially zeros + g_x = tl.load(g_x_start, mask=offset < vector_length) + g_y = tl.load(g_y_start, mask=offset < vector_length) + g_z = tl.load(g_z_start, mask=offset < vector_length) + # this is the first order derivative, which is just root 3 + g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length) + g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length) + g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length) + g_x += sqrt_3 * g_1_0 + g_y += sqrt_3 * g_1_1 + g_z += sqrt_3 * g_1_2 + # now work on the second order derivatives, grouped by m + g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length) + g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length) + g_2_2 = tl.load(g_2_2_ptr + offset, mask=offset < vector_length) + g_2_3 = tl.load(g_2_3_ptr + offset, mask=offset < vector_length) + g_2_4 = tl.load(g_2_4_ptr + offset, mask=offset < vector_length) + # Y_2^0 + g_x += sqrt_15 * z * g_2_0 + g_z += sqrt_15 * x * g_2_0 + # Y_2^1 + g_x += sqrt_15 * y * g_2_1 + g_y += sqrt_15 * x * g_2_1 + # Y_2^2 + g_y += sqrt_15 * z * g_2_2 + g_z += sqrt_15 * y * g_2_2 + # Y_2^3 + g_x += -1.0 * sqrt_5 * x * g_2_3 + g_y += 2.0 * sqrt_5 * y * g_2_3 + g_z += -1.0 * sqrt_5 * z * g_2_3 + # Y_2_4 + g_x += -1.0 * sqrt_15 * x * g_2_4 + g_z += sqrt_15 * z * g_2_4 + # after all the operations are done, write back to memory + tl.store(g_x_ptr + offset, g_x, mask=offset < vector_length) + tl.store(g_y_ptr + offset, g_y, mask=offset < vector_length) + tl.store(g_z_ptr + offset, g_z, mask=offset < vector_length) + + +@triton.jit +def _triton_third_order_fwd( + x_ptr: tl.tensor, + y_ptr: tl.tensor, + z_ptr: tl.tensor, + sh_1_0_ptr: tl.tensor, + sh_1_1_ptr: tl.tensor, + sh_1_2_ptr: tl.tensor, + sh_2_0_ptr: tl.tensor, + sh_2_1_ptr: tl.tensor, + sh_2_2_ptr: tl.tensor, + sh_2_3_ptr: tl.tensor, + sh_2_4_ptr: tl.tensor, + sh_3_0_ptr: tl.tensor, + sh_3_1_ptr: tl.tensor, + sh_3_2_ptr: tl.tensor, + sh_3_3_ptr: tl.tensor, + sh_3_4_ptr: tl.tensor, + sh_3_5_ptr: tl.tensor, + sh_3_6_ptr: tl.tensor, + BLOCK_SIZE: tl.constexpr, + vector_length: tl.constexpr, +): + sqrt_3 = 3**0.5 + block_id = tl.program_id(0) + # calculate the offset for this particular thread + offset = tl.arange(0, BLOCK_SIZE) + (BLOCK_SIZE * block_id) + x_row_start = x_ptr + offset + y_row_start = y_ptr + offset + z_row_start = z_ptr + offset + # load in x,y,z to operate on + x = tl.load(x_row_start, mask=offset < vector_length) + y = tl.load(y_row_start, mask=offset < vector_length) + z = tl.load(z_row_start, mask=offset < vector_length) + # compute first order terms + sh_1_0 = x * sqrt_3 + sh_1_1 = y * sqrt_3 + sh_1_2 = z * sqrt_3 + # now work on second order + sqrt_15 = 15**0.5 + sqrt_5 = 5**0.5 + sq_x = x * x + sq_y = y * y + sq_z = z * z + # compute each component + sh_2_0 = sqrt_15 * x * z + sh_2_1 = sqrt_15 * x 
* y + # these two appear swapped, but they are consistent with e3nn + sh_2_2 = sqrt_5 * (sq_y - 0.5 * (sq_x + sq_z)) + sh_2_3 = sqrt_15 * y * z + sh_2_4 = 0.5 * sqrt_15 * (sq_z - sq_x) + # now work on third order + sqrt_42 = 42**0.5 + sqrt_168 = 168**0.5 + sqrt_7 = 7**0.5 + sh_3_0 = (1 / 6) * sqrt_42 * (sh_2_0 * z + sh_2_4 * x) + sh_3_1 = sqrt_7 * sh_2_0 * y + sh_3_2 = (1 / 8) * sqrt_168 * (4 * sq_y - (sq_x + sq_z)) * x + sh_3_3 = 0.5 * sqrt_7 * y * (2 * sq_y - 3 * (sq_x + sq_z)) + sh_3_4 = (1 / 8) * sqrt_168 * z * (4 * sq_y - (sq_x + sq_z)) + sh_3_5 = sqrt_7 * (sh_2_4 * y) + sh_3_6 = (1 / 6) * sqrt_42 * (sh_2_4 * z - sh_2_0 * x) + # write the results to memory + sh_1_0_start = sh_1_0_ptr + offset + sh_1_1_start = sh_1_1_ptr + offset + sh_1_2_start = sh_1_2_ptr + offset + sh_2_0_start = sh_2_0_ptr + offset + sh_2_1_start = sh_2_1_ptr + offset + sh_2_2_start = sh_2_2_ptr + offset + sh_2_3_start = sh_2_3_ptr + offset + sh_2_4_start = sh_2_4_ptr + offset + sh_3_0_start = sh_3_0_ptr + offset + sh_3_1_start = sh_3_1_ptr + offset + sh_3_2_start = sh_3_2_ptr + offset + sh_3_3_start = sh_3_3_ptr + offset + sh_3_4_start = sh_3_4_ptr + offset + sh_3_5_start = sh_3_5_ptr + offset + sh_3_6_start = sh_3_6_ptr + offset + tl.store(sh_1_0_start, sh_1_0, mask=offset < vector_length) + tl.store(sh_1_1_start, sh_1_1, mask=offset < vector_length) + tl.store(sh_1_2_start, sh_1_2, mask=offset < vector_length) + tl.store(sh_2_0_start, sh_2_0, mask=offset < vector_length) + tl.store(sh_2_1_start, sh_2_1, mask=offset < vector_length) + tl.store(sh_2_2_start, sh_2_2, mask=offset < vector_length) + tl.store(sh_2_3_start, sh_2_3, mask=offset < vector_length) + tl.store(sh_2_4_start, sh_2_4, mask=offset < vector_length) + tl.store(sh_3_0_start, sh_3_0, mask=offset < vector_length) + tl.store(sh_3_1_start, sh_3_1, mask=offset < vector_length) + tl.store(sh_3_2_start, sh_3_2, mask=offset < vector_length) + tl.store(sh_3_3_start, sh_3_3, mask=offset < vector_length) + tl.store(sh_3_4_start, sh_3_4, mask=offset < vector_length) + tl.store(sh_3_5_start, sh_3_5, mask=offset < vector_length) + tl.store(sh_3_6_start, sh_3_6, mask=offset < vector_length) + + +@triton.jit +def _triton_third_order_bwd( + x_ptr: tl.tensor, + y_ptr: tl.tensor, + z_ptr: tl.tensor, + g_x_ptr: tl.tensor, + g_y_ptr: tl.tensor, + g_z_ptr: tl.tensor, + g_1_0_ptr: tl.tensor, + g_1_1_ptr: tl.tensor, + g_1_2_ptr: tl.tensor, + g_2_0_ptr: tl.tensor, + g_2_1_ptr: tl.tensor, + g_2_2_ptr: tl.tensor, + g_2_3_ptr: tl.tensor, + g_2_4_ptr: tl.tensor, + g_3_0_ptr: tl.tensor, + g_3_1_ptr: tl.tensor, + g_3_2_ptr: tl.tensor, + g_3_3_ptr: tl.tensor, + g_3_4_ptr: tl.tensor, + g_3_5_ptr: tl.tensor, + g_3_6_ptr: tl.tensor, + BLOCK_SIZE: tl.constexpr, + vector_length: tl.constexpr, +): + # expect the xyz are the same as the forward pass, we have expected + # gradient output tensors as well as intermediate gradients + sqrt_3 = 3**0.5 + sqrt_5 = 5**0.5 + sqrt_15 = 15**0.5 + block_id = tl.program_id(0) + # calculate the offset for this particular thread + offset = tl.arange(0, BLOCK_SIZE) + (BLOCK_SIZE * block_id) + x_row_start = x_ptr + offset + y_row_start = y_ptr + offset + z_row_start = z_ptr + offset + # load in x,y,z to operate on + x = tl.load(x_row_start, mask=offset < vector_length) + y = tl.load(y_row_start, mask=offset < vector_length) + z = tl.load(z_row_start, mask=offset < vector_length) + # load the pre-allocated xyz gradients + g_x_start = g_x_ptr + offset + g_y_start = g_y_ptr + offset + g_z_start = g_z_ptr + offset + # NOTE: these are the gradient outputs and 
are assumed to be initially zeros + g_x = tl.load(g_x_start, mask=offset < vector_length) + g_y = tl.load(g_y_start, mask=offset < vector_length) + g_z = tl.load(g_z_start, mask=offset < vector_length) + # this is the first order derivative, which is just root 3 + g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length) + g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length) + g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length) + g_x += sqrt_3 * g_1_0 + g_y += sqrt_3 * g_1_1 + g_z += sqrt_3 * g_1_2 + # now work on the second order derivatives, grouped by m + g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length) + g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length) + g_2_2 = tl.load(g_2_2_ptr + offset, mask=offset < vector_length) + g_2_3 = tl.load(g_2_3_ptr + offset, mask=offset < vector_length) + g_2_4 = tl.load(g_2_4_ptr + offset, mask=offset < vector_length) + # Y_2^0 + g_x += sqrt_15 * z * g_2_0 + g_z += sqrt_15 * x * g_2_0 + # Y_2^1 + g_x += sqrt_15 * y * g_2_1 + g_y += sqrt_15 * x * g_2_1 + # Y_2^2 + g_y += sqrt_15 * z * g_2_2 + g_z += sqrt_15 * y * g_2_2 + # Y_2^3 + g_x += -1.0 * sqrt_5 * x * g_2_3 + g_y += 2.0 * sqrt_5 * y * g_2_3 + g_z += -1.0 * sqrt_5 * z * g_2_3 + # Y_2_4 + g_x += -1.0 * sqrt_15 * x * g_2_4 + g_z += sqrt_15 * z * g_2_4 + # now work on third order, but we group by cartesian axis instead + g_3_0 = tl.load(g_3_0_ptr + offset, mask=offset < vector_length) + g_3_1 = tl.load(g_3_1_ptr + offset, mask=offset < vector_length) + g_3_2 = tl.load(g_3_2_ptr + offset, mask=offset < vector_length) + g_3_3 = tl.load(g_3_3_ptr + offset, mask=offset < vector_length) + g_3_4 = tl.load(g_3_4_ptr + offset, mask=offset < vector_length) + g_3_5 = tl.load(g_3_5_ptr + offset, mask=offset < vector_length) + g_3_6 = tl.load(g_3_6_ptr + offset, mask=offset < vector_length) + sq_x = x * x + sq_y = y * y + sq_z = z * z + # IMO this is a more readable grouping, components within an axis + # unfortunately this is the part where "magic constants" start appearing + # since they're simplified expressions + g_x += ( + sqrt_15 + * g_3_0 + * ( + -1.62018517460196 * sq_x + + 1.08012344973464 * sq_z + + 0.540061724867322 * sq_z + ) + ) + g_x += 2.64575131106459 * sqrt_15 * g_3_1 * y * z + g_x -= g_3_2 * ( + 4.8605555238059 * sq_x - 6.48074069840786 * sq_y + 1.62018517460197 * sq_z + ) + g_x -= 7.93725393319377 * g_3_3 * x * y + g_x -= 3.24037034920393 * g_3_4 * x * z + g_x -= 2.64575131106459 * sqrt_15 * g_3_5 * x * y + g_x -= sqrt_15 * g_3_6 * z * (1.08012344973464 * x + 2.16024689946929 * x) + # now calculate y contributions + g_y += 2.64575131106459 * sqrt_15 * g_3_1 * x * z + g_y += 12.9614813968157 * g_3_2 * x * y + g_y -= g_3_3 * ( + 3.96862696659689 * sq_x - 7.93725393319377 * sq_y + 3.96862696659689 * sq_z + ) + g_y += 12.9614813968157 * g_3_4 * y * z + g_y -= 1.3228756555323 * sqrt_15 * g_3_5 * (sq_x - sq_z) + # now calculate z contributions + g_z += sqrt_15 * g_3_0 * x * (1.08012344973464 * z + 2.16024689946929 * z) + g_z += 2.64575131106459 * sqrt_15 * g_3_1 * x * y + g_z -= 3.24037034920393 * g_3_2 * x * z + g_z -= 7.93725393319377 * g_3_3 * y * z + g_z -= g_3_4 * ( + 1.62018517460197 * sq_x - 6.48074069840786 * sq_y + 4.8605555238059 * sq_z + ) + g_z += 2.64575131106459 * sqrt_15 * g_3_5 * y * z + g_z -= ( + sqrt_15 + * g_3_6 + * (1.08012344973464 * sq_x + 0.540061724867322 * sq_x - 1.62018517460196 * sq_z) + ) + # after all the operations are done, write back to memory + tl.store(g_x_ptr + offset, g_x, mask=offset < vector_length) + 
tl.store(g_y_ptr + offset, g_y, mask=offset < vector_length) + tl.store(g_z_ptr + offset, g_z, mask=offset < vector_length) + + +@triton.jit +def _triton_fourth_order_fwd( + x_ptr: tl.tensor, + y_ptr: tl.tensor, + z_ptr: tl.tensor, + sh_1_0_ptr: tl.tensor, + sh_1_1_ptr: tl.tensor, + sh_1_2_ptr: tl.tensor, + sh_2_0_ptr: tl.tensor, + sh_2_1_ptr: tl.tensor, + sh_2_2_ptr: tl.tensor, + sh_2_3_ptr: tl.tensor, + sh_2_4_ptr: tl.tensor, + sh_3_0_ptr: tl.tensor, + sh_3_1_ptr: tl.tensor, + sh_3_2_ptr: tl.tensor, + sh_3_3_ptr: tl.tensor, + sh_3_4_ptr: tl.tensor, + sh_3_5_ptr: tl.tensor, + sh_3_6_ptr: tl.tensor, + sh_4_0_ptr: tl.tensor, + sh_4_1_ptr: tl.tensor, + sh_4_2_ptr: tl.tensor, + sh_4_3_ptr: tl.tensor, + sh_4_4_ptr: tl.tensor, + sh_4_5_ptr: tl.tensor, + sh_4_6_ptr: tl.tensor, + sh_4_7_ptr: tl.tensor, + sh_4_8_ptr: tl.tensor, + BLOCK_SIZE: tl.constexpr, + vector_length: tl.constexpr, +): + sqrt_3 = 3**0.5 + block_id = tl.program_id(0) + # calculate the offset for this particular thread + offset = tl.arange(0, BLOCK_SIZE) + (BLOCK_SIZE * block_id) + x_row_start = x_ptr + offset + y_row_start = y_ptr + offset + z_row_start = z_ptr + offset + # load in x,y,z to operate on + x = tl.load(x_row_start, mask=offset < vector_length) + y = tl.load(y_row_start, mask=offset < vector_length) + z = tl.load(z_row_start, mask=offset < vector_length) + # compute first order terms + sh_1_0 = x * sqrt_3 + sh_1_1 = y * sqrt_3 + sh_1_2 = z * sqrt_3 + # now work on second order + sqrt_15 = 15**0.5 + sqrt_5 = 5**0.5 + sq_x = x * x + sq_y = y * y + sq_z = z * z + # compute each component + sh_2_0 = sqrt_15 * x * z + sh_2_1 = sqrt_15 * x * y + # these two appear swapped, but they are consistent with e3nn + sh_2_2 = sqrt_5 * (sq_y - 0.5 * (sq_x + sq_z)) + sh_2_3 = sqrt_15 * y * z + sh_2_4 = 0.5 * sqrt_15 * (sq_z - sq_x) + # now work on third order + sqrt_42 = 42**0.5 + sqrt_168 = 168**0.5 + sqrt_7 = 7**0.5 + sh_3_0 = (1 / 6) * sqrt_42 * (sh_2_0 * z + sh_2_4 * x) + sh_3_1 = sqrt_7 * sh_2_0 * y + sh_3_2 = (1 / 8) * sqrt_168 * (4 * sq_y - (sq_x + sq_z)) * x + sh_3_3 = 0.5 * sqrt_7 * y * (2 * sq_y - 3 * (sq_x + sq_z)) + sh_3_4 = (1 / 8) * sqrt_168 * z * (4 * sq_y - (sq_x + sq_z)) + sh_3_5 = sqrt_7 * (sh_2_4 * y) + sh_3_6 = (1 / 6) * sqrt_42 * (sh_2_4 * z - sh_2_0 * x) + # now work on fourth order + sqrt_2 = 2**0.5 + sqrt_210 = 210**0.5 + sqrt_14 = 14**0.5 + sqrt_21 = 21**0.5 + sqrt_70 = 70**0.5 + sqrt_105 = 105**0.5 + sqrt_6 = 6**0.5 + sh_4_0 = (3 / 4) * sqrt_2 * (sh_3_0 * z + sh_3_6 * x) + sh_4_1 = ( + (3 / 4) * sh_3_0 * y + + (3 / 8) * sqrt_6 * sh_3_1 * z + + (3 / 8) * sqrt_6 * sh_3_5 * x + ) + sh_4_2 = ( + -3 / 56 * sqrt_14 * sh_3_0 * z + + (3 / 14) * sqrt_21 * sh_3_1 * y + + (3 / 56) * sqrt_210 * sh_3_2 * z + + (3 / 56) * sqrt_210 * sh_3_4 * x + + (3 / 56) * sqrt_14 * sh_3_6 * x + ) + sh_4_3 = ( + -3 / 56 * sqrt_42 * sh_3_1 * z + + (3 / 28) * sqrt_105 * sh_3_2 * y + + (3 / 28) * sqrt_70 * sh_3_3 * x + + (3 / 56) * sqrt_42 * sh_3_5 * x + ) + sh_4_4 = ( + (-3 / 28 * sqrt_42 * sh_3_2 * x) + + (3 / 7) * sqrt_7 * sh_3_3 * y + - (3 / 28 * sqrt_42 * sh_3_4 * z) + ) + sh_4_5 = ( + -3 / 56 * sqrt_42 * sh_3_1 * x + + (3 / 28) * sqrt_70 * sh_3_3 * z + + (3 / 28) * sqrt_105 * sh_3_4 * y + - 3 / 56 * sqrt_42 * sh_3_5 * z + ) + sh_4_6 = ( + -3 / 56 * sqrt_14 * sh_3_0 * x + - 3 / 56 * sqrt_210 * sh_3_2 * x + + (3 / 56) * sqrt_210 * sh_3_4 * z + + (3 / 14) * sqrt_21 * sh_3_5 * y + - 3 / 56 * sqrt_14 * sh_3_6 * z + ) + sh_4_7 = ( + -3 / 8 * sqrt_6 * sh_3_1 * x + + (3 / 8) * sqrt_6 * sh_3_5 * z + + (3 / 4) * sh_3_6 * y + ) + sh_4_8 = (3 / 
4) * sqrt_2 * (-sh_3_0 * x + sh_3_6 * z) + # write the results to memory + sh_1_0_start = sh_1_0_ptr + offset + sh_1_1_start = sh_1_1_ptr + offset + sh_1_2_start = sh_1_2_ptr + offset + sh_2_0_start = sh_2_0_ptr + offset + sh_2_1_start = sh_2_1_ptr + offset + sh_2_2_start = sh_2_2_ptr + offset + sh_2_3_start = sh_2_3_ptr + offset + sh_2_4_start = sh_2_4_ptr + offset + sh_3_0_start = sh_3_0_ptr + offset + sh_3_1_start = sh_3_1_ptr + offset + sh_3_2_start = sh_3_2_ptr + offset + sh_3_3_start = sh_3_3_ptr + offset + sh_3_4_start = sh_3_4_ptr + offset + sh_3_5_start = sh_3_5_ptr + offset + sh_3_6_start = sh_3_6_ptr + offset + sh_4_0_start = sh_4_0_ptr + offset + sh_4_1_start = sh_4_1_ptr + offset + sh_4_2_start = sh_4_2_ptr + offset + sh_4_3_start = sh_4_3_ptr + offset + sh_4_4_start = sh_4_4_ptr + offset + sh_4_5_start = sh_4_5_ptr + offset + sh_4_6_start = sh_4_6_ptr + offset + sh_4_7_start = sh_4_7_ptr + offset + sh_4_8_start = sh_4_8_ptr + offset + tl.store(sh_1_0_start, sh_1_0, mask=offset < vector_length) + tl.store(sh_1_1_start, sh_1_1, mask=offset < vector_length) + tl.store(sh_1_2_start, sh_1_2, mask=offset < vector_length) + tl.store(sh_2_0_start, sh_2_0, mask=offset < vector_length) + tl.store(sh_2_1_start, sh_2_1, mask=offset < vector_length) + tl.store(sh_2_2_start, sh_2_2, mask=offset < vector_length) + tl.store(sh_2_3_start, sh_2_3, mask=offset < vector_length) + tl.store(sh_2_4_start, sh_2_4, mask=offset < vector_length) + tl.store(sh_3_0_start, sh_3_0, mask=offset < vector_length) + tl.store(sh_3_1_start, sh_3_1, mask=offset < vector_length) + tl.store(sh_3_2_start, sh_3_2, mask=offset < vector_length) + tl.store(sh_3_3_start, sh_3_3, mask=offset < vector_length) + tl.store(sh_3_4_start, sh_3_4, mask=offset < vector_length) + tl.store(sh_3_5_start, sh_3_5, mask=offset < vector_length) + tl.store(sh_3_6_start, sh_3_6, mask=offset < vector_length) + tl.store(sh_4_0_start, sh_4_0, mask=offset < vector_length) + tl.store(sh_4_1_start, sh_4_1, mask=offset < vector_length) + tl.store(sh_4_2_start, sh_4_2, mask=offset < vector_length) + tl.store(sh_4_3_start, sh_4_3, mask=offset < vector_length) + tl.store(sh_4_4_start, sh_4_4, mask=offset < vector_length) + tl.store(sh_4_5_start, sh_4_5, mask=offset < vector_length) + tl.store(sh_4_6_start, sh_4_6, mask=offset < vector_length) + tl.store(sh_4_7_start, sh_4_7, mask=offset < vector_length) + tl.store(sh_4_8_start, sh_4_8, mask=offset < vector_length) + + +@triton.jit +def _triton_fourth_order_bwd( + x_ptr: tl.tensor, + y_ptr: tl.tensor, + z_ptr: tl.tensor, + g_x_ptr: tl.tensor, + g_y_ptr: tl.tensor, + g_z_ptr: tl.tensor, + g_1_0_ptr: tl.tensor, + g_1_1_ptr: tl.tensor, + g_1_2_ptr: tl.tensor, + g_2_0_ptr: tl.tensor, + g_2_1_ptr: tl.tensor, + g_2_2_ptr: tl.tensor, + g_2_3_ptr: tl.tensor, + g_2_4_ptr: tl.tensor, + g_3_0_ptr: tl.tensor, + g_3_1_ptr: tl.tensor, + g_3_2_ptr: tl.tensor, + g_3_3_ptr: tl.tensor, + g_3_4_ptr: tl.tensor, + g_3_5_ptr: tl.tensor, + g_3_6_ptr: tl.tensor, + g_4_0_ptr: tl.tensor, + g_4_1_ptr: tl.tensor, + g_4_2_ptr: tl.tensor, + g_4_3_ptr: tl.tensor, + g_4_4_ptr: tl.tensor, + g_4_5_ptr: tl.tensor, + g_4_6_ptr: tl.tensor, + g_4_7_ptr: tl.tensor, + g_4_8_ptr: tl.tensor, + BLOCK_SIZE: tl.constexpr, + vector_length: tl.constexpr, +): + # expect the xyz are the same as the forward pass, we have expected + # gradient output tensors as well as intermediate gradients + sqrt_3 = 3**0.5 + sqrt_5 = 5**0.5 + sqrt_15 = 15**0.5 + block_id = tl.program_id(0) + # calculate the offset for this particular thread + offset = 
tl.arange(0, BLOCK_SIZE) + (BLOCK_SIZE * block_id) + x_row_start = x_ptr + offset + y_row_start = y_ptr + offset + z_row_start = z_ptr + offset + # load in x,y,z to operate on + x = tl.load(x_row_start, mask=offset < vector_length) + y = tl.load(y_row_start, mask=offset < vector_length) + z = tl.load(z_row_start, mask=offset < vector_length) + # load the pre-allocated xyz gradients + g_x_start = g_x_ptr + offset + g_y_start = g_y_ptr + offset + g_z_start = g_z_ptr + offset + # NOTE: these are the gradient outputs and are assumed to be initially zeros + g_x = tl.load(g_x_start, mask=offset < vector_length) + g_y = tl.load(g_y_start, mask=offset < vector_length) + g_z = tl.load(g_z_start, mask=offset < vector_length) + # this is the first order derivative, which is just root 3 + g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length) + g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length) + g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length) + g_x += sqrt_3 * g_1_0 + g_y += sqrt_3 * g_1_1 + g_z += sqrt_3 * g_1_2 + # now work on the second order derivatives, grouped by m + g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length) + g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length) + g_2_2 = tl.load(g_2_2_ptr + offset, mask=offset < vector_length) + g_2_3 = tl.load(g_2_3_ptr + offset, mask=offset < vector_length) + g_2_4 = tl.load(g_2_4_ptr + offset, mask=offset < vector_length) + # Y_2^0 + g_x += sqrt_15 * z * g_2_0 + g_z += sqrt_15 * x * g_2_0 + # Y_2^1 + g_x += sqrt_15 * y * g_2_1 + g_y += sqrt_15 * x * g_2_1 + # Y_2^2 + g_y += sqrt_15 * z * g_2_2 + g_z += sqrt_15 * y * g_2_2 + # Y_2^3 + g_x += -1.0 * sqrt_5 * x * g_2_3 + g_y += 2.0 * sqrt_5 * y * g_2_3 + g_z += -1.0 * sqrt_5 * z * g_2_3 + # Y_2_4 + g_x += -1.0 * sqrt_15 * x * g_2_4 + g_z += sqrt_15 * z * g_2_4 + # now work on third order, but we group by cartesian axis instead + g_3_0 = tl.load(g_3_0_ptr + offset, mask=offset < vector_length) + g_3_1 = tl.load(g_3_1_ptr + offset, mask=offset < vector_length) + g_3_2 = tl.load(g_3_2_ptr + offset, mask=offset < vector_length) + g_3_3 = tl.load(g_3_3_ptr + offset, mask=offset < vector_length) + g_3_4 = tl.load(g_3_4_ptr + offset, mask=offset < vector_length) + g_3_5 = tl.load(g_3_5_ptr + offset, mask=offset < vector_length) + g_3_6 = tl.load(g_3_6_ptr + offset, mask=offset < vector_length) + sq_x = x * x + sq_y = y * y + sq_z = z * z + cu_z = sq_z * z + cu_x = sq_x * x + cu_y = sq_y * y + # IMO this is a more readable grouping, components within an axis + # unfortunately this is the part where "magic constants" start appearing + # since they're simplified expressions + g_x += ( + sqrt_15 + * g_3_0 + * ( + -1.62018517460196 * sq_x + + 1.08012344973464 * sq_z + + 0.540061724867322 * sq_z + ) + ) + g_x += 2.64575131106459 * sqrt_15 * g_3_1 * y * z + g_x -= g_3_2 * ( + 4.8605555238059 * sq_x - 6.48074069840786 * sq_y + 1.62018517460197 * sq_z + ) + g_x -= 7.93725393319377 * g_3_3 * x * y + g_x -= 3.24037034920393 * g_3_4 * x * z + g_x -= 2.64575131106459 * sqrt_15 * g_3_5 * x * y + g_x -= sqrt_15 * g_3_6 * z * (1.08012344973464 * x + 2.16024689946929 * x) + # now calculate y contributions + g_y += 2.64575131106459 * sqrt_15 * g_3_1 * x * z + g_y += 12.9614813968157 * g_3_2 * x * y + g_y -= g_3_3 * ( + 3.96862696659689 * sq_x - 7.93725393319377 * sq_y + 3.96862696659689 * sq_z + ) + g_y += 12.9614813968157 * g_3_4 * y * z + g_y -= 1.3228756555323 * sqrt_15 * g_3_5 * (sq_x - sq_z) + # now calculate z contributions + g_z += sqrt_15 * g_3_0 * x * 
(1.08012344973464 * z + 2.16024689946929 * z) + g_z += 2.64575131106459 * sqrt_15 * g_3_1 * x * y + g_z -= 3.24037034920393 * g_3_2 * x * z + g_z -= 7.93725393319377 * g_3_3 * y * z + g_z -= g_3_4 * ( + 1.62018517460197 * sq_x - 6.48074069840786 * sq_y + 4.8605555238059 * sq_z + ) + g_z += 2.64575131106459 * sqrt_15 * g_3_5 * y * z + g_z -= ( + sqrt_15 + * g_3_6 + * (1.08012344973464 * sq_x + 0.540061724867322 * sq_x - 1.62018517460196 * sq_z) + ) + # now work on fourth order, grouping by cartesian axis + g_4_0 = tl.load(g_4_0_ptr + offset, mask=offset < vector_length) + g_4_1 = tl.load(g_4_1_ptr + offset, mask=offset < vector_length) + g_4_2 = tl.load(g_4_2_ptr + offset, mask=offset < vector_length) + g_4_3 = tl.load(g_4_3_ptr + offset, mask=offset < vector_length) + g_4_4 = tl.load(g_4_4_ptr + offset, mask=offset < vector_length) + g_4_5 = tl.load(g_4_5_ptr + offset, mask=offset < vector_length) + g_4_6 = tl.load(g_4_6_ptr + offset, mask=offset < vector_length) + g_4_7 = tl.load(g_4_7_ptr + offset, mask=offset < vector_length) + g_4_8 = tl.load(g_4_8_ptr + offset, mask=offset < vector_length) + g_x -= ( + sqrt_15 + * g_4_0 + * ( + 3.43693177121688 * sq_x * z + + 3.43693177121688 * sq_x * z + - 1.14564392373896 * cu_z + - 1.14564392373896 * cu_z + ) + ) + g_x += ( + sqrt_15 + * g_4_1 + * y + * (-4.8605555238059 * sq_x + 3.24037034920393 * sq_z + 1.62018517460197 * sq_z) + ) + g_x -= g_4_2 * ( + 0.649519052838329 * sqrt_15 * sq_x * z + + 7.54672942406179 * sq_x * z + - 2.59807621135332 * sqrt_15 * sq_y * z + - 10.0623058987491 * sq_y * z + + 0.21650635094611 * sqrt_15 * cu_z + + 2.51557647468726 * cu_z + ) + g_x -= ( + g_4_3 + * y + * ( + 0.918558653543692 * sqrt_15 * sq_x + + 16.0090306546024 * sq_x + - 9.48683298050514 * sq_y + + 0.918558653543692 * sqrt_15 * sq_z + + 5.33634355153414 * sq_z + + 0.459279326771846 * sqrt_15 * (sq_x - sq_z) + ) + ) + g_x += g_4_4 * ( + -9.0 * x * sq_y + + 2.25 * x * sq_z + - 9.0 * x * sq_y + + 2.25 * x * sq_z + + 4.5 * cu_x + ) + g_x -= ( + g_4_5 + * y + * z + * ( + -0.918558653543692 * sqrt_15 * x + + 10.6726871030683 * x + + 1.83711730708738 * sqrt_15 * x + ) + ) + g_x -= g_4_6 * ( + 2.59807621135332 * sqrt_15 * x * sq_y + - 0.21650635094611 * sqrt_15 * x * sq_z + + 2.51557647468726 * x * sq_z + + 10.0623058987491 * x * sq_y + - 2.51557647468726 * x * sq_z + + 0.21650635094611 * sqrt_15 * x * sq_z + - 5.03115294937453 * cu_x + - 0.433012701892219 * sqrt_15 * cu_x + ) + g_x -= sqrt_15 * g_4_7 * y * z * (3.24037034920393 * x + 6.48074069840786 * x) + g_x -= ( + sqrt_15 + * g_4_8 + * ( + 1.14564392373896 * x * sq_z + + 4.58257569495584 * x * sq_z + + 1.14564392373896 * x * sq_z + - 2.29128784747792 * cu_x + ) + ) + g_y += ( + sqrt_15 + * g_4_1 + * x + * (-1.62018517460197 * sq_x + 3.24037034920393 * sq_z + 1.62018517460197 * sq_z) + ) + g_y += g_4_2 * x * z * (5.19615242270663 * sqrt_15 * y + 20.1246117974981 * y) + g_y -= ( + g_4_3 + * x + * ( + 5.33634355153414 * sq_x + - 28.4604989415154 * sq_y + + 0.918558653543692 * sqrt_15 * sq_z + + 5.33634355153414 * sq_z + + 0.459279326771846 * sqrt_15 * (sq_x - sq_z) + ) + ) + g_y -= g_4_4 * ( + 9.0 * sq_x * y + 9.0 * sq_x * y + 9.0 * y * sq_z + 9.0 * y * sq_z - 12.0 * cu_y + ) + g_y -= ( + g_4_5 + * z + * ( + 0.918558653543692 * sqrt_15 * sq_x + + 5.33634355153414 * sq_x + - 28.4604989415154 * sq_y + + 5.33634355153414 * sq_z + - 0.459279326771846 * sqrt_15 * (sq_x - sq_z) + ) + ) + g_y -= g_4_6 * ( + 10.0623058987491 * sq_x * y + + 2.59807621135332 * sqrt_15 * y * (sq_x - sq_z) + - 10.0623058987491 * y * sq_z + 
) + g_y -= ( + sqrt_15 + * g_4_7 + * z + * (3.24037034920393 * sq_x + 1.62018517460197 * sq_x - 1.62018517460197 * sq_z) + ) + g_z -= ( + sqrt_15 + * g_4_0 + * ( + 1.14564392373896 * cu_x + - 3.43693177121688 * x * sq_z + - 3.43693177121688 * x * sq_z + + 1.14564392373896 * cu_x + ) + ) + g_z += sqrt_15 * g_4_1 * x * y * (3.24037034920393 * z + 6.48074069840786 * z) + g_z -= g_4_2 * ( + 0.21650635094611 * sqrt_15 * cu_x + - 2.59807621135332 * sqrt_15 * x * sq_y + - 10.0623058987491 * x * sq_y + + 0.649519052838329 * sqrt_15 * x * sq_z + + 7.54672942406179 * x * sq_z + + 2.51557647468726 * cu_x + ) + g_z -= ( + g_4_3 + * x + * y + * ( + -0.918558653543692 * sqrt_15 * z + + 10.6726871030683 * z + + 1.83711730708738 * sqrt_15 * z + ) + ) + g_z += g_4_4 * ( + 2.25 * sq_x * z + 2.25 * sq_x * z - 9.0 * sq_y * z - 9.0 * sq_y * z + 4.5 * cu_z + ) + g_z -= ( + g_4_5 + * y + * ( + 0.918558653543692 * sqrt_15 * sq_x + + 5.33634355153414 * sq_x + - 9.48683298050514 * sq_y + + 0.918558653543692 * sqrt_15 * sq_z + + 16.0090306546024 * sq_z + - 0.459279326771846 * sqrt_15 * (sq_x - sq_z) + ) + ) + g_z += g_4_6 * ( + -0.21650635094611 * sqrt_15 * sq_x * z + + 2.51557647468726 * sq_x * z + - 2.51557647468726 * sq_x * z + + 0.21650635094611 * sqrt_15 * sq_x * z + + 2.59807621135332 * sqrt_15 * sq_y * z + + 10.0623058987491 * sq_y * z + - 5.03115294937453 * cu_z + - 0.433012701892219 * sqrt_15 * cu_z + ) + g_z -= ( + sqrt_15 + * g_4_7 + * y + * (3.24037034920393 * sq_x + 1.62018517460197 * sq_x - 4.8605555238059 * sq_z) + ) + g_z -= ( + sqrt_15 + * g_4_8 + * ( + 1.14564392373896 * sq_x * z + + 4.58257569495584 * sq_x * z + + 1.14564392373896 * sq_x * z + - 2.29128784747792 * cu_z + ) + ) + # after all the operations are done, write back to memory + tl.store(g_x_ptr + offset, g_x, mask=offset < vector_length) + tl.store(g_y_ptr + offset, g_y, mask=offset < vector_length) + tl.store(g_z_ptr + offset, g_z, mask=offset < vector_length) diff --git a/src/equitriton/tests/test_benchmark.py b/src/equitriton/tests/test_benchmark.py new file mode 100644 index 0000000..e4afa28 --- /dev/null +++ b/src/equitriton/tests/test_benchmark.py @@ -0,0 +1,16 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: MIT License +from __future__ import annotations + +from time import sleep +from random import random + +from equitriton.benchmark import benchmark + + +def test_benchmark_decorator(): + @benchmark(num_steps=50) + def dummy_func(): + sleep(random()) # nosec + + dummy_func() diff --git a/src/equitriton/utils.py b/src/equitriton/utils.py new file mode 100644 index 0000000..7760350 --- /dev/null +++ b/src/equitriton/utils.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import torch +import triton + +__all__ = ["pad_tensor_to_power"] + + +def pad_tensor_to_power( + input_tensor: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Pad a tensor to the nearest power of two. + + The goal of this is to minimize the number of compiled + kernels due to large variations in tensor shapes. By + padding to the nearest power of two, we hopefully only + encounter typical tensor shapes, with the cost of a bit + of memory overhead. + + Parameters + ---------- + input_tensor : torch.Tensor + Tensor to be padded. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + A 2-tuple of tensors: the first is the padded tensor, + and the second is a 1D mask to be applied along the + node dimension of a tensor. 
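+
+    Examples
+    --------
+    An illustrative round trip, assuming a 100-node input
+    (100 rounds up to 128, the next power of two)::
+
+        >>> padded, mask = pad_tensor_to_power(torch.rand(100, 3))
+        >>> padded.shape
+        torch.Size([128, 3])
+        >>> int(mask.sum())
+        100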
+ """ + num_nodes = input_tensor.size(0) + pad_size = triton.next_power_of_2(num_nodes) + num_pad = pad_size - num_nodes + # allocate tensor of zeros to pad with + zero_pad = torch.zeros( + (num_pad, *input_tensor.shape[1:]), + dtype=input_tensor.dtype, + device=input_tensor.device, + ) + joint_tensor = torch.cat([input_tensor, zero_pad], dim=0) + mask = torch.ones(pad_size, device=joint_tensor.device, dtype=torch.bool) + mask[num_nodes:] = False + return (joint_tensor, mask)