diff --git a/README.md b/README.md index 1de1534..06e7058 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,46 @@ nodes and masking in the forward pass. The script `scripts/dynamic_shapes.py` wi let you test the performance over a range of shapes; we encourage you to test it before performing full-scale training/inference. +## Decoupled spherical harmonics kernels + +We recently published a paper at the AI4Mat workshop at NeurIPS 2024. As part +of that work, we went back into ``sympy`` to refactor the spherical harmonics up to $l=10$, +such that computations of a particular order are _independent_ of the others. This allows +arbitrary orders to be freely composed without incurring a performance penalty, for +example when one wishes to calculate $l=8$ but not $l=7$. + +Functionally, these kernels are intended to behave in the same way as the original +implementation, i.e. they still provide equivariant properties when used to map +cartesian point clouds. However, because of the aggressive refactoring and heavy use +of hard-coded literals, they may (or will) differ numerically from even the initial _EquiTriton_ +kernels, particularly at higher orders. + +> [!IMPORTANT] +> For the above reason, while the kernels can be drop-in replacements, we do not recommend +> using them with already trained models, at least without some testing on the user's part, +> as the results may differ. We have also not yet attempted to use these kernels as part of +> simulation-based workflows (e.g. molecular dynamics); however, our training experiments do +> show that training converges.
+ +To use the new set of decoupled kernels, the main `torch.autograd` binding is +`equitriton.sph_harm.direct.TritonSphericalHarmonic`: + +```python +import torch +from equitriton.sph_harm.direct import TritonSphericalHarmonic + +coords = torch.rand(100, 3) +sph_harm = TritonSphericalHarmonic.apply( + l_values=[0, 1, 2, 6, 10], + coords=coords +) +``` + +The improvements to performance are expected to come from (1) decoupling of each spherical +harmonic order, and (2) pre-allocation of the output tensor, which avoids computing each order +separately and then copying the results together with `torch.cat`. See the "Direct spherical harmonics evaluation" +notebook in the notebooks folder for the derivation. + ### Development and usage on Intel XPU Development on Intel XPUs such as the Data Center GPU Max Series 1550 requires @@ -131,7 +171,9 @@ contributions will be licensed under this license. Citation -------- -If you find this repo useful, please consider citing the corresponding paper: +If you find this repo useful, please consider citing the respective papers. 
+ +For the original EquiTriton implementation, please use/read the following citation: ```bibtex @inproceedings{lee2024scaling, @@ -141,4 +183,16 @@ If you find this repo useful, please consider citing the corresponding paper: year={2024}, url={https://openreview.net/forum?id=ftK00FO5wq} } -``` \ No newline at end of file +``` + +For the refactored spherical harmonics up to $l=10$, and subsequent PHATE embedding analysis, see: + +```bibtex +@inproceedings{lee2024deconstructing, + title={Deconstructing equivariant representations in molecular systems}, + author={Kin Long Kelvin Lee and Mikhail Galkin and Santiago Miret}, + booktitle={AI for Accelerated Materials Design - NeurIPS 2024}, + year={2024}, + url={https://openreview.net/forum?id=pshyLoyzRn} +} +``` diff --git a/notebooks/.gitignore b/notebooks/.gitignore new file mode 100644 index 0000000..b75f584 --- /dev/null +++ b/notebooks/.gitignore @@ -0,0 +1,2 @@ +qm9_data/ +lightning_logs/ diff --git a/notebooks/Baseline model development.ipynb b/notebooks/Baseline model development.ipynb new file mode 100644 index 0000000..3a7c1c4 --- /dev/null +++ b/notebooks/Baseline model development.ipynb @@ -0,0 +1,1318 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "fd832b52-7df1-4d42-a2ea-a7139df2196b", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Literal, Callable, Any\n", + "from math import ceil\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch.utils.data import random_split\n", + "from torch.optim import AdamW\n", + "import e3nn\n", + "from e3nn import o3\n", + "from torch_scatter import scatter\n", + "from torch_geometric.data import Data as PyGGraph\n", + "from torch_geometric.datasets import QM9\n", + "from torch_geometric.loader import DataLoader\n", + "from torch_cluster import radius_graph\n", + "import pytorch_lightning as pl\n", + "from matplotlib import pyplot as plt\n", + "\n", + "from equitriton.sph_harm.direct import 
triton_spherical_harmonic\n", + "from equitriton.utils import spherical_harmonics_irreps" + ] + }, + { + "cell_type": "markdown", + "id": "1e68d7b9-c046-492c-8380-b801d6a0f209", + "metadata": {}, + "source": [ + "# Baseline model development\n", + "\n", + "This notebook was used to develop the simple graph convolution model used in the paper _Deconstructing equivariant representations in molecular systems_, presented at the AI4Mat workshop at NeurIPS 2024.\n", + "\n", + "If you are looking to use this architecture in your own testing, please look at the `equitriton.model.blocks` module for the \"production ready\" version (i.e. please don't copy-paste this code unless you're looking to modify things!)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a0307714-1036-4110-8b53-0c50062ee913", + "metadata": {}, + "outputs": [], + "source": [ + "seed = torch.manual_seed(215162)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "79f96b13-929a-4b2e-a4d9-991194fd98ba", + "metadata": {}, + "outputs": [], + "source": [ + "class AtomEmbedding(nn.Module):\n", + " \"\"\"\n", + " A PyTorch module for embedding atomic numbers into dense vectors.\n", + "\n", + " Parameters\n", + " ----------\n", + " num_atoms : int\n", + " The number of distinct atomic types in the dataset.\n", + " atom_dim : int\n", + " The dimensionality of the embedded atomic vectors.\n", + "\n", + " Example\n", + " --------\n", + " >>> # Create an instance of the AtomEmbedding module\n", + " >>> atom_embedding = AtomEmbedding(num_atoms=10, atom_dim=128)\n", + "\n", + " >>> # Embed a batch of atomic numbers\n", + " >>> embedded_vectors = atom_embedding(torch.tensor([1, 2, 3, 4]))\n", + "\n", + " >>> print(embedded_vectors.shape) # Output: torch.Size([4, 128])\n", + " \"\"\"\n", + "\n", + " def __init__(self, num_atoms: int, atom_dim: int):\n", + " \"\"\"\n", + " Initializes the AtomEmbedding module.\n", + "\n", + " Parameters\n", + " ----------\n", + " num_atoms : int\n", + 
" The number of distinct atomic types in the dataset.\n", + " atom_dim : int\n", + " The dimensionality of the embedded atomic vectors.\n", + " \"\"\"\n", + " super().__init__()\n", + " self.embedding = nn.Embedding(num_atoms, atom_dim, padding_idx=0)\n", + "\n", + " def forward(self, atomic_numbers: torch.LongTensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Embeds a batch of atomic numbers into a tensor of embedded vectors.\n", + "\n", + " Parameters\n", + " ----------\n", + " atomic_numbers : torch.LongTensor\n", + " A tensor of shape (batch_size,) containing the atomic numbers to embed.\n", + "\n", + " Returns\n", + " -------\n", + " torch.Tensor\n", + " A tensor of shape (batch_size, atom_dim) containing the embedded atomic vectors.\n", + " \"\"\"\n", + " return self.embedding(atomic_numbers)\n", + "\n", + "\n", + "class EdgeEmbedding(nn.Module):\n", + " def __init__(self, num_basis: int, radius_cutoff: float = 6.0, **kwargs):\n", + " \"\"\"\n", + " This module embeds edges in a graph with an EdgeEmbedding object.\n", + "\n", + " Parameters\n", + " ----------\n", + " num_basis : int, optional\n", + " The number of basis functions. Defaults to 1.\n", + " radius_cutoff : float, optional\n", + " The maximum radius up to which basis functions are defined. Defaults to 6.0.\n", + "\n", + " Optional kwargs\n", + " ---------------\n", + " basis : str, optional\n", + " The type of basis function to use. 
Defaults to 'bessel'.\n", + " start : float, optional\n", + " The starting point in the distance grid used in the radial basis.\n", + " cutoff : bool, optional\n", + " Whether or not to apply a cutoff to the basis functions.\n", + "\n", + " Returns\n", + " -------\n", + " torch.Tensor\n", + " A tensor representing the embedding of edges with shape (num_edges, num_basis).\n", + "\n", + " Examples\n", + " --------\n", + " >>> # Define an instance of EdgeEmbedding with 4 basis functions and a radius cutoff of 10.\n", + " >>> embedder = EdgeEmbedding(num_basis=4, radius_cutoff=10.0)\n", + " \"\"\"\n", + " super().__init__()\n", + " kwargs.setdefault(\"basis\", \"bessel\")\n", + " kwargs.setdefault(\"start\", 0.0)\n", + " kwargs.setdefault(\"cutoff\", True)\n", + " self.num_basis = num_basis\n", + " self.radius_cutoff = radius_cutoff\n", + " self.basis_kwargs = kwargs\n", + "\n", + " def forward(self, distances: torch.Tensor) -> torch.Tensor:\n", + " basis_funcs = e3nn.math.soft_one_hot_linspace(\n", + " distances,\n", + " number=self.num_basis,\n", + " end=self.radius_cutoff,\n", + " **self.basis_kwargs,\n", + " )\n", + " return basis_funcs * self.num_basis**0.5\n", + "\n", + "\n", + "class SphericalHarmonicEmbedding(nn.Module):\n", + " def __init__(\n", + " self,\n", + " l_values: list[int],\n", + " normalize: bool = True,\n", + " normalization: Literal[\"norm\", \"integral\", \"component\"] = \"integral\",\n", + " use_e3nn: bool = False,\n", + " ):\n", + " \"\"\"\n", + " Projects cartesian positions onto spherical harmonic functions.\n", + " \"\"\"\n", + " super().__init__()\n", + " self.l_values = list(sorted(l_values))\n", + " self.irreps = spherical_harmonics_irreps(self.l_values, num_feat=1)\n", + " self.normalize = normalize\n", + " self.normalization = normalization\n", + " self.use_e3nn = use_e3nn\n", + "\n", + " def forward(self, coords: torch.Tensor) -> torch.Tensor:\n", + " if not self.use_e3nn:\n", + " if self.normalize:\n", + " coords = 
torch.nn.functional.normalize(coords, dim=-1)\n", + " outputs = [triton_spherical_harmonic(l, coords) for l in self.l_values]\n", + " outputs = torch.cat(outputs, dim=-1)\n", + " if self.normalization == \"integral\":\n", + " outputs /= (4.0 * torch.pi) ** 0.5\n", + " return outputs\n", + " else:\n", + " return o3.spherical_harmonics(\n", + " self.irreps, coords, self.normalize, self.normalization\n", + " )\n", + "\n", + "\n", + "class InteractionBlock(nn.Module):\n", + " def __init__(\n", + " self,\n", + " atomic_dim: int | o3.Irreps,\n", + " l_values: list[int],\n", + " edge_dim: int,\n", + " hidden_dim: int,\n", + " radius_cutoff: float,\n", + " degree_norm: float,\n", + " edge_kwargs: dict[str, Any] = {},\n", + " sph_harm_kwargs: dict[str, Any] = {},\n", + " activation: Callable = nn.functional.silu,\n", + " ):\n", + " \"\"\"\n", + " A module that combines radial basis with spherical harmonics to\n", + " describe molecular interactions.\n", + "\n", + " Parameters\n", + " ----------\n", + " atomic_dim : int | o3.Irreps\n", + " Dimension of the atomic features. If int, the features are treated\n", + " as that many scalar (``0e``) irreducible representations.\n", + " l_values : list[int]\n", + " Spherical harmonic orders to evaluate.\n", + " edge_dim : int\n", + " Dimension of the edge features.\n", + " hidden_dim : int\n", + " Hidden dimension for the fully connected network.\n", + " radius_cutoff : float\n", + " Cutoff radius for the radial basis.\n", + " degree_norm : float\n", + " Normalization factor for the degree of the graph.\n", + " edge_kwargs : dict[str, Any], optional\n", + " Keyword arguments for the EdgeEmbedding module. Defaults to {}.\n", + " sph_harm_kwargs : dict[str, Any], optional\n", + " Keyword arguments for the SphericalHarmonicEmbedding module.\n", + " Defaults to {}.\n", + " activation : Callable, optional\n", + " Activation function for the fully connected network. 
Defaults to\n", + " nn.functional.silu.\n", + "\n", + " Notes\n", + " -----\n", + " The `degree_norm` attribute is stored on the module and effectively\n", + " represents the average number of neighbors in other models.\n", + "\n", + " Examples\n", + " --------\n", + " >>> block = InteractionBlock(atomic_dim=8, l_values=[0, 1], edge_dim=16,\n", + " hidden_dim=32, radius_cutoff=6.0, degree_norm=1.0)\n", + " >>> block.sph_irreps\n", + " 1x0e+1x1o\n", + " \"\"\"\n", + " sph_harm_kwargs.setdefault(\"use_e3nn\", False)\n", + "\n", + " super().__init__()\n", + " # this is effectively the average number of neighbors in other models\n", + " self.degree_norm = degree_norm\n", + " # treat atom features as invariant\n", + " if isinstance(atomic_dim, int):\n", + " atomic_irreps = f\"{atomic_dim}x0e\"\n", + " else:\n", + " atomic_irreps = atomic_dim\n", + " self.atomic_irreps = atomic_irreps\n", + " self.l_values = list(sorted(l_values))\n", + " # these two attributes are similar but different: the former is used for describing\n", + " # the basis itself, and the latter is for actually specifying the weights\n", + " self.sph_irreps = spherical_harmonics_irreps(self.l_values, num_feat=1)\n", + " self.output_irreps = spherical_harmonics_irreps(\n", + " self.l_values, num_feat=hidden_dim\n", + " )\n", + " # tensor product is the final bit that combines the radial basis with the spherical\n", + " # harmonics\n", + " self.tensor_product = o3.FullyConnectedTensorProduct(\n", + " self.atomic_irreps,\n", + " self.sph_irreps,\n", + " self.output_irreps,\n", + " shared_weights=False,\n", + " )\n", + " self.edge_basis = EdgeEmbedding(edge_dim, radius_cutoff, **edge_kwargs)\n", + " self.spherical_harmonics = SphericalHarmonicEmbedding(\n", + " l_values, **sph_harm_kwargs\n", + " )\n", + " self.fc = e3nn.nn.FullyConnectedNet(\n", + " [edge_dim, hidden_dim, self.tensor_product.weight_numel], activation\n", + " )\n", + "\n", + " @property\n", + " def num_projections(self) -> int:\n", + " \"\"\"Returns the expected number of 
projections.\"\"\"\n", + " return sum([2 * l + 1 for l in self.l_values])\n", + "\n", + " @property\n", + " def output_dim(self) -> int:\n", + " \"\"\"Returns the dimensionality of the output.\"\"\"\n", + " return self.output_irreps.dim\n", + "\n", + " def forward(\n", + " self,\n", + " atomic_features: torch.Tensor,\n", + " coords: torch.Tensor,\n", + " edge_index: torch.LongTensor,\n", + " ) -> torch.Tensor:\n", + " \"\"\"\n", + " High-level description:\n", + "\n", + " 1. Project cartesian coordinates onto spherical harmonic basis\n", + " 2. Project interatomic distances onto radial (bessel) basis\n", + " 3. Transform radial basis functions with learnable weights\n", + " 4. Compute tensor product between scalar atom features and spherical harmonic basis\n", + " 5. Update node features\n", + " \"\"\"\n", + " edge_dist = coords[edge_index[0]] - coords[edge_index[1]]\n", + " sph_harm = self.spherical_harmonics(edge_dist)\n", + " # calculate atomic distances, embed, and transform them\n", + " edge_basis = self.edge_basis(edge_dist.norm(dim=-1))\n", + " edge_z = self.fc(edge_basis)\n", + " # compute tensor product\n", + " messages = self.tensor_product(atomic_features[edge_index[0]], sph_harm, edge_z)\n", + " # update node features\n", + " hidden_feats = (\n", + " scatter(messages, edge_index[1], dim=0, dim_size=atomic_features.size(0))\n", + " / self.degree_norm\n", + " )\n", + " return hidden_feats\n", + "\n", + "\n", + "class ScalarReadoutLayer(nn.Module):\n", + " def __init__(self, hidden_irreps: o3.Irreps, output_dim: int):\n", + " super().__init__()\n", + " self.hidden_irreps = hidden_irreps\n", + " self.output_irreps = o3.Irreps(f\"{output_dim}x0e\")\n", + " self.output_layer = o3.Linear(\n", + " irreps_in=hidden_irreps, irreps_out=self.output_irreps\n", + " )\n", + "\n", + " def forward(self, node_feats: torch.Tensor) -> torch.Tensor:\n", + " return self.output_layer(node_feats)\n", + "\n", + "\n", + "class EquiTritonModel(nn.Module):\n", + " def 
__init__(\n", + " self,\n", + " initial_atom_dim: int,\n", + " num_layers: int,\n", + " output_dim: int,\n", + " l_values: list[int],\n", + " edge_dim: int,\n", + " hidden_dim: int,\n", + " radius_cutoff: float,\n", + " degree_norm: float,\n", + " edge_kwargs: dict[str, Any] = {},\n", + " sph_harm_kwargs: dict[str, Any] = {},\n", + " activation: Callable = nn.functional.silu,\n", + " num_atoms: int = 100,\n", + " skip_connections: bool = True,\n", + " ):\n", + " \"\"\"\n", + " End-to-end simple model that uses the EquiTriton kernels.\n", + "\n", + " Parameters\n", + " ----------\n", + " initial_atom_dim : int\n", + " The dimensionality of the atomic embeddings.\n", + " num_layers : int\n", + " The number of convolutional layers in the model.\n", + " output_dim : int\n", + " The dimensionality of the graph-level scalar features.\n", + " l_values : list[int]\n", + " The spherical harmonic orders used by every interaction block.\n", + " edge_dim : int\n", + " The dimensionality of the edge features.\n", + " hidden_dim : int\n", + " The hidden dimensionality of the interaction blocks.\n", + " radius_cutoff : float\n", + " The cutoff distance for the radial basis functions.\n", + " degree_norm : float\n", + " The normalization factor for the degree of each node.\n", + " edge_kwargs : dict, optional\n", + " Additional keyword arguments for the edge embedding layer. Defaults to {}.\n", + " sph_harm_kwargs : dict, optional\n", + " Additional keyword arguments for the spherical harmonics layer. Defaults to {}.\n", + " activation : Callable, optional\n", + " The activation function used in the model. Defaults to nn.functional.silu.\n", + " num_atoms : int, optional\n", + " The number of distinct atomic types in the embedding table. Defaults to 100.\n", + " skip_connections : bool, optional\n", + " Whether to use skip connections between layers. 
Defaults to True.\n", + " \"\"\"\n", + " sph_harm_kwargs.setdefault(\"use_e3nn\", False)\n", + "\n", + " super().__init__()\n", + " self.atomic_embedding = AtomEmbedding(num_atoms, initial_atom_dim)\n", + " self.initial_layer = InteractionBlock(\n", + " initial_atom_dim,\n", + " l_values,\n", + " edge_dim,\n", + " hidden_dim,\n", + " radius_cutoff,\n", + " degree_norm,\n", + " edge_kwargs,\n", + " sph_harm_kwargs,\n", + " activation,\n", + " )\n", + " self.conv_layers = nn.ModuleDict()\n", + " for layer_index in range(num_layers + 1):\n", + " self.conv_layers[f\"conv_{layer_index}\"] = InteractionBlock(\n", + " self.initial_layer.output_irreps,\n", + " l_values,\n", + " edge_dim,\n", + " hidden_dim,\n", + " radius_cutoff,\n", + " degree_norm,\n", + " edge_kwargs,\n", + " sph_harm_kwargs,\n", + " activation,\n", + " )\n", + " self.scalar_readout = ScalarReadoutLayer(\n", + " self.initial_layer.output_irreps, output_dim\n", + " )\n", + " self.skip_connections = skip_connections\n", + " self.output_dim = output_dim\n", + "\n", + " def visualize(self, **kwargs):\n", + " num_plots = len(self.conv_layers) + 1\n", + " fig, axarray = plt.subplots(num_plots, 1, figsize=(3, 12))\n", + " # make indexing easier\n", + " axarray = axarray.flatten()\n", + "\n", + " self.initial_layer.tensor_product.visualize(ax=axarray[0], **kwargs)\n", + " axarray[0].set_title(\"Input layer\", loc=\"right\")\n", + " index = 1\n", + " for layer_name, layer in self.conv_layers.items():\n", + " ax = axarray[index]\n", + " layer.tensor_product.visualize(ax=ax, **kwargs)\n", + " ax.set_title(layer_name, loc=\"right\")\n", + " index += 1\n", + " fig.tight_layout()\n", + " return fig, axarray\n", + "\n", + " def forward(self, graph: PyGGraph) -> tuple[torch.Tensor, torch.Tensor]:\n", + " # determine if the graph is batched or not\n", + " is_batched = hasattr(graph, \"ptr\")\n", + " # get atom embeddings\n", + " atom_z = self.atomic_embedding(graph.z) # [nodes, initial_atom_dim]\n", + " # first message 
passing step\n", + " z = self.initial_layer(atom_z, graph.pos, graph.edge_index)\n", + " outputs = {}\n", + " for layer_name, layer in self.conv_layers.items():\n", + " new_z = layer(z, graph.pos, graph.edge_index)\n", + " # add residual connections\n", + " if self.skip_connections and new_z.shape == z.shape:\n", + " new_z += z\n", + " z = new_z\n", + " outputs[layer_name] = z\n", + " # map final output as scalars\n", + " z = self.scalar_readout(z)\n", + " # latest node features are in z; we generate graph-level scalar features\n", + " # by doing a scatter add\n", + " if is_batched:\n", + " graph_z = scatter(z, graph.batch, dim=0, dim_size=graph.batch_size)\n", + " else:\n", + " # for a single graph, just sum up the node features\n", + " graph_z = z.sum(dim=0, keepdims=True)\n", + " return graph_z, z" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4e568b72-3990-4652-8a32-db375c65a81b", + "metadata": {}, + "outputs": [], + "source": [ + "def make_fake_graph(\n", + " num_nodes: int,\n", + " num_edges: int,\n", + " max_radius: float = 1.5,\n", + " coord_scale: float = 1.0,\n", + " max_atomic_number: int = 100,\n", + " device=\"xpu\",\n", + "):\n", + " \"\"\"\n", + " Generate a fake graph with the specified number of nodes and edges.\n", + "\n", + " Parameters\n", + " ----------\n", + " num_nodes : int\n", + " The number of nodes in the graph.\n", + " num_edges : int\n", + " The number of edges in the graph.\n", + " max_radius : float, optional\n", + " The maximum radius for node connections. Defaults to 1.5.\n", + " coord_scale : float, optional\n", + " The scaling factor for node coordinates. Defaults to 1.0.\n", + " max_atomic_number : int, optional\n", + " The maximum atomic number for nodes. 
Defaults to 100.\n", + " device : str or torch.device, optional\n", + " The device to use for computations.\n", + "\n", + " Returns\n", + " -------\n", + " tuple[torch.Tensor, torch.Tensor, torch.Tensor]\n", + " A tuple containing:\n", + " - coords (torch.Tensor): Node coordinates with shape ``(num_nodes, 3)``.\n", + " - edge_index (torch.Tensor): Edge indices with shape ``(2, num_edges)``.\n", + " - atomic_numbers (torch.Tensor): Atomic numbers for nodes with shape ``(num_nodes,)``.\n", + "\n", + " Examples\n", + " --------\n", + " >>> coords, edge_index, atomic_numbers = make_fake_graph(10, 20)\n", + " >>> print(coords.shape) # (10, 3)\n", + " >>> print(edge_index.shape) # (2, 20)\n", + " >>> print(atomic_numbers.shape) # (10,)\n", + " \"\"\"\n", + " coords = torch.rand(num_nodes, 3, device=device) * coord_scale\n", + " edge_src, edge_dst = radius_graph(\n", + " coords, max_radius, max_num_neighbors=num_nodes - 1\n", + " )\n", + " edge_index = torch.vstack([edge_src, edge_dst]).to(device)\n", + " atomic_numbers = torch.randint(\n", + " 0, max_atomic_number, size=(num_nodes,), device=device\n", + " )\n", + " return coords, edge_index, atomic_numbers" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8671f0e8-0da8-4250-9d02-fb30ecb977fe", + "metadata": {}, + "outputs": [], + "source": [ + "edge_embedder = EdgeEmbedding(num_basis=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f4971d48-3574-44d0-b2fb-27306bc3aeac", + "metadata": {}, + "outputs": [], + "source": [ + "coords, edge_index, atomic_numbers = make_fake_graph(\n", + " 16,\n", + " 12,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "55f14f94-51ca-4e5c-ad63-ca830404eaf9", + "metadata": {}, + "outputs": [], + "source": [ + "# coords = torch.ones_like(coords, requires_grad=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "67f6e836-f4e0-48cf-981d-4419dda0159b", + "metadata": {}, + "outputs": [], + "source": [ + 
"atom_embedder = AtomEmbedding(100, 64).to(\"cuda\")\n", + "atom_z = atom_embedder(atomic_numbers)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "46778ea6-92f2-4d7a-9bd8-aa58013cd1c6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 240])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "edge_index.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0c82e80f-c4bd-43fe-8846-dab203757992", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "layer = InteractionBlock(\n", + " 64,\n", + " [\n", + " 0,\n", + " 1,\n", + " 2,\n", + " ],\n", + " 10,\n", + " 32,\n", + " radius_cutoff=6.0,\n", + " degree_norm=17**0.5,\n", + " sph_harm_kwargs={\"use_e3nn\": True},\n", + ").to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a0913f1e-b000-4a3b-b723-84a92d16425d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. 
Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "next_layer = InteractionBlock(\n", + " layer.output_irreps,\n", + " [\n", + " 0,\n", + " 1,\n", + " 2,\n", + " ],\n", + " 10,\n", + " 32,\n", + " radius_cutoff=6.0,\n", + " degree_norm=17**0.5,\n", + ").to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "e893b896-f7b9-4b62-95a1-fe7fb45cf2ed", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "InteractionBlock(\n", + " (tensor_product): FullyConnectedTensorProduct(64x0e x 1x0e+1x1o+1x2e -> 32x0e+32x1o+32x2e | 6144 paths | 6144 weights)\n", + " (edge_basis): EdgeEmbedding()\n", + " (spherical_harmonics): SphericalHarmonicEmbedding()\n", + " (fc): FullyConnectedNet[10, 32, 6144]\n", + ")" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "34c3578d-0017-4350-99d2-373ad2e542cb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(
, )" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAGuCAYAAAANsQX6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABQeUlEQVR4nO3deViU9f4+8HvYUUABFUFBdsQNFRVxF5Cl0vLUsWOLZXU8uZVlplmpqYlaapiWWp1osbLdMtlFcAVBRGWdAQRRQQERZWfm+f3B1/kdExUU5pnhuV/Xda5zBTPDjR/R2888z+ctEwRBABEREUmWntgBiIiISFwsA0RERBLHMkBERCRxLANEREQSxzJAREQkcSwDREREEscyQEREJHEsA0RERBLHMkBERCRxLANEREQSxzJAREQkcSwDREREEscyQEREJHEsA0RERBLHMkBERCRxLANEREQSxzJAREQkcSwDREREEscyQEREJHEsA0RERBLHMkBERCRxLANEREQSxzJAREQkcSwDREREEscyQEREJHEsAySaxsZGPPTQQ0hISBA7CpGoEhIS8NBDD6GxsVHsKCRRLAMkmu3btyMqKgqWlpZiRyESVffu3REVFYVPPvlE7CgkUTJBEASxQ5D0XLlyBW5ubpg5cyY+/fRTseMQie7ll1/Gnj17IJfL0aNHD7HjkMRwZ4BEsWLFCshkMqxevVrsKERaYc2aNRAEAStWrBA7CkkQywBpXHp6Onbt2oVVq1ahZ8+eYsch0go9e/bEqlWrsHPnTpw+fVrsOCQxfJuANEoQBPj5+aGkpASnT5+GoaGh2JGItEZjYyOGDBkCW1tbxMXFQSaTiR2JJII7A6RRv/76Kw4ePIgtW7awCBD9jaGhITZv3oz4+Hj89ttvYschCeHOAGlMXV0dPD09MXDgQOzbt0/sOERa6+GHH0ZmZiaysrJgYmIidhySAO4MkMZs3rwZxcXF2Lx5s9hRiLTazZ+VLVu2iB2FJII7A6QRFy5cgIeHB15++WV8+OGHYsch0nqLFy/Gzp07kZubCzs7O7HjUCfHMkAaMWvWLERGRkIul6Nbt25ixyHSepWVlXB3d0dISAi++uorseNQJ8e3CajDHT9+HN988w3ef/99FgGiVurevTvWrl2Lr7/+GklJSWLHoU6OOwPUoVQqFXx9fdHQ0ICUlBTo6+uLHYlIZyiVSnh7e8PExARHjx6Fnh7//UYdg7+zqEPt3r0bycnJCAsLYxEgaiN9fX2EhYUhKSkJ3333ndhxqBPjzgB1mBs3bsDd3R3jx4/Hnj17xI5DpLNmzJiBI0eOICcnB2ZmZmLHoU6IOwPUYUJDQ3H16lVs3LhR7ChEOm3jxo2oqKjA+vXrxY5CnRTLAHWI/Px8bNq0CUuWLEG/fv3EjkOk0xwdHfHGG2/gww8/REFBgdhxqBPi2wTUIR5//HEkJSUhJycHXbt2FTsOkc6rrq6Gh4cHRo8ejZ9//lnsONTJcGeA2l18fDx+/fVXbNy4kUWAqJ107doVGzZswC+//IKDBw+KHYc6Ge4MULtqamqCt7c3zMzMcPjwYU5dI2pHgiBg7NixqK6uxsmTJ3mHDrUb7gxQu/r8889x+vRphIWFsQgQtTOZTIawsDCcPn0an3/+udhxqBPhzgC1m6tXr8LNzQ1Tp07Fl19+KXYcok7r+eefx759+yCXy2FpaSl2HOoEuDNA7ea9995DfX091q1bJ3YUok4tNDQU9fX1WL16tdhRqJNgGaB2kZWVhe3bt+Odd96Bra2t2HGIOjVbW1u8/fbb2LZtG7Kzs8WOQ50A3yagByYIAkJCQqBQKJCRkQFjY2OxIxF1enV1dRg4cCA8PDywf/9+
seOQjuPOAD2w/fv3IyoqCps2bWIRINIQExMTbNq0CRERESwD9MC4M0APpKGhAYMGDYKDgwNiYmJ4BwGRBgmCgICAABQXF+PMmTMwMjISOxLpKO4M0AP5+OOPkZeXh48++ohFgEjDZDIZPvroIygUCmzbtk3sOKTDuDNA9+3y5ctwc3PDs88+yz+IiEQ0f/587N69G7m5uejVq5fYcUgHsQzQfZszZw5+/vlnyOVyWFtbix2HSLLKy8vh5uaGf/7zn9i5c6fYcUgH8W0Cui9paWn4/PPPsXr1ahYBIpFZW1vjvffew2effYZTp06JHYd0EHcGqM0EQcDEiRNRXl6O9PR0GBgYiB2JSPIaGxvh5eWFnj174uDBg7yGh9qEOwPUZj/99BMOHTqEjz76iEWASEsYGhpiy5YtSExM5IhjajOWAWqT2tpaLFmyBNOmTcOUKVPEjkNE/yMoKAhTp07FkiVLUFtb2+rnbd++HY6OjjAxMYGPjw+Sk5M7MCVpI5YBapMPP/wQly5dwqZNm8SOQkQt2LRpEy5evNjqn9E9e/bg9ddfx8qVK3Hy5El4eXkhKCgIly9f7uCkpE1YBqjVzp8/j9DQULz22mtwdXUVOw4RtcDNzQ2LFi1CaGgoiouL7/n4zZs349///jdmz56NAQMGYMeOHejSpQv++9//qh9TVFSEp556CpaWlrCyssLTTz+Nq1evduS3QRrGMkCttmzZMlhYWODtt98WOwoR3cU777wDMzMzLFu27K6Pa2hoQGpqKgICAtQf09PTQ0BAAI4dOwYAUCgU8Pb2hqurK44fP46YmBgoFAosWbKkQ78H0iyWAWqVo0eP4rvvvsO6detgYWEhdhwiugsLCwusW7cOu3fvVv+l3pKysjIolUrY2Njc8nEbGxuUlJQAAObNm4d58+Zh9erV8PDwgLe3N958800cOHCgQ78H0izeWkj3pFKp4OPjA0EQkJycDD09dkgibadUKjFq1Cjo6+vj+PHjLf7cXrx4EX369MHRo0fh6+ur/vibb76JhIQE/Pjjj3B0dISpqektz1cqlbC3t0dubq5GvhfqeLwvjO7p66+/RkpKCg4fPswiQKQj9PX1ERYWhvHjx+Obb77Bc889d9tjevToAX19fZSWlt7y8dLSUvTu3Rvp6emwsrJCUlLSbc81NTXtsOykedwZoLuqqqqCu7s7/Pz88N1334kdh4jaaObMmTh48CByc3Nhbm5+2+d9fHwwatQofPzxxwCadwIdHBywYMECeHl54dFHH0VlZSW6dOmi6eikQfxnHt3VunXrUFVVhQ0bNogdhYjuw/r161F5rRLvv/9+i59//fXX8dlnn+Grr75CVlYW5s6di+rqasyePRs+Pj6wsLDArFmzkJ6eDoVCgcjISCxatEiz3wR1OL5NQHeUl5eHLVu2YPny5bC3txc7DhG10flr5xFVFgWfGT7YvGUz5syZA2dn51se8+STT+LKlStYsWIFSkpKMHToUERGRqovKty/fz+WLl2KCRMmQBAEuLm5tfiWA+k2vk1AdzR9+nSkpqYiOzubW4REOuRa3TXE5sfizOUzAABZowyfPv8pxviMwa+//ipyOtJG3BmgFsXGxuL333/HDz/8wCJApCMalY04ev4oDhcdRqOqETLIMNx2OPyc/OCxyQMzZ85EXFwc/P39xY5KWoY7A3SbpqYmDB06FJaWlkhMTOT0MyItJwgCMq5kICYvBtfqrwEA+nXrh2DXYNia26ofM378eFy7dg1paWkcMka34O8Gus3OnTuRmZmJEydOsAgQablL1y8hQhGBomtFAIBuxt0Q6BKIAT0H3PLzK5PJEBYWhpEjR2LXrl2YN2+eWJFJC3FngG5RUVEBNzc3TJ8+HZ9//rnYcYjoDqobqhFXEIe0S2kQIMBQzxDjHMZhjP0YGOob3vF5L774Ivbu3Yvc3FxYWVlpMDFpM5YBusUrr7yC8PBwyOXy244oJSLxKVVKJF1IQsK5
BNQr6wEAg3sNRoBzALqZdLvn80tKSuDu7o7Zs2cjLCyso+OSjmAZILWMjAx4eXlh/fr1eOONN8SOQ0T/QxAEyCvkiFJEoby2HABgZ26HENcQ2Hdr262/H3zwAd566y2cPn0aAwYM6Ii4pGNYBghA8x80QUFBKCgoQEZGBoyMjMSORET/50r1FUTlRUFRoQAAmBmZIcA5AF42Xvd1XU99fT0GDhwIZ2dnREVF8dog4gWE1OzPP/9ETEwM/vjjDxYBIi1R21iLhMIEJF9IhkpQQV+mD197X4x3GA9jA+P7fl1jY2Ns3rwZjz76KPbt24epU6e2Y2rSRdwZINTX12PQoEFwdnZGZGQk/5VAJDKVoELqxVTEn4tHTWMNAKB/j/4IdAmElWn7XPTH3UD6X9wZIGzduhUFBQXYu3cviwCRyAquFiBSEYnS6uZJgr269kKwazCcLZ3v8cy2kclk2LJlC7y8vLB161ZeJyRx3BmQOF5ZTKQdrtZeRXReNLLKsgAApgammOw0GSPsRkBP1nEz5XgHEQEsA5L34osv4vfff4dcLuc9x0QiqG+qx+GiwzhWfAxNqiboyfQw0m4kJjlOgqmhaYd/fZ4tQgDLgKSlpqZi5MiR2LZtG08jI9IwQRBwuvQ0YvNjcb3hOgDAxdIFQa5B6NW1l0azbN++HQsXLkRKSgqGDx+u0a9N2oFlQKJ4TjmReIqrihEhj8CF6xcAAFamVghyCYK7tbso1+1wHgnxbwCJ2rNnD44cOYLY2FgWASINqaqvQmx+LE6XngYAGOsbY0K/CfDp6wMDPfF+Dg0MDBAWFoaAgAD8+OOPePLJJ0XLQuLgzoAE1dTUwMPDAyNHjuRscyINaFQ24ljxMRwqPKQeLTy091D4O/vDzMhM7Hhq06dPR2pqKrKzszm6XGJYBiRo1apVCA0NRVZWFpyd2/d2JSL6/wRBQFZZFqLzolFZVwkAsLewR4hbCOzM7cQN14K8vDwMGDAAy5cvx8qVK8WOQxrEMiAxRUVF8PDwwKJFixAaGip2HKJOq+RGCSIVkThXeQ4AYGFsgUCXQAzsOVCr35NftmwZtm7dipycHNjbt23mAekulgGJmTlzJg4ePIjc3FyYm5uLHYeo06luqEb8uXikXkyFAAEGegYY5zAOY+3H3nW0sLa4fv063N3dMXnyZHz33XdixyENYRmQkEOHDmHChAkIDw/Hc889J3Ycok5FqVIi+UIyEgoTUNdUBwAY1GsQpjhPadVoYW0SHh6O2bNn49ChQxg3bpzYcUgDWAYkQqlUYuTIkTAwMMDx48ehp9dxJ5oRSY28XI6ovCiU1ZQBAGzNbBHsGox+3fuJnOz+qFQq+Pj4QKVS4cSJE/zzQgJ4T5lEhIeHIy0tDUePHuUPNlE7KaspQ5QiCvIKOQCgq2FX+Dv7Y2jvoR16hHBH09PTQ1hYGMaOHYvw8HC88MILYkeiDsadAQmoqqqCm5sbpkyZgm+//VbsOEQ6r66pDgnnEpB0IUk9Wtinrw8m9JsAEwMTseO1m6effhqxsbGQy+WwsLAQOw51IJYBCXjzzTexfft25OTkoG/fvmLHIdJZKkGFtEtpiCuIU48W9rD2QKBLIKy7WIucrv0VFxfDw8MDCxYswIYNG8SOQx2IZaCTk8vlGDhwIFasWIF33nlH7DhEOutc5TlEKiJRcqMEANCzS08EuQbB1cpV5GQda82aNVizZg0yMjLg5uYmdhzqICwDndy0adNw+vRpZGVlwdS04yegEXU2lXWViMmLQcaVDACAiYEJJjs2jxbW19MXOV3Hq62tRf/+/TF06FDs3btX7DjUQXgBYScWFRWFP//8Ez/++COLAFEbNSgbcLjoMI6eP4omVRNkkGGE3QhMdpqMLobSOarX1NQUH3zwAZ588klER0cjMDBQ7EjUAbgz0Ek1NjbCy8sLPXv2xMGDB7X6xDMibSIIAs5cPoPY/FhU1VcBAJy6OyHYNRg2ZjYipxOHIAiY
OHEiysvLkZ6ezuFmnRBXtJPasWMHsrOz8d1337EIELXShaoLiFBEoLiqGABgaWKJINcgeFh7SPrnSCaTISwsDN7e3tixYwcWLFggdiRqZ9wZ6ITKysrg5uaGGTNmYOfOnWLHIdJ61+uvI64gDqdKTgEAjPSNMKHfBIzuO1rU0cLaZs6cOfj5558hl8thbd357p6QMpaBTmj+/PnYvXs35HI5evbsKXYcIq3VpGrC8eLjSCxMRIOyAQCaRws7+cPcmLM7/u7y5ctwc3PDs88+i23btokdh9qR7h6R1UYXLlzAM888A2tra5iammLw4MFISUm57XHr16+HTCbDokWLOiTH9u3b4ejoCBMTE/j4+CA5ObldX//MmTPYsWMHVqxYwSJAdAeCICDrSha2J29HbH4sGpQN6GvRF/8e/m881v8xFoE76NWrF1asWIFPP/0UZ86cafXzEhMTMXXqVNjZ2UEmk+H333/vuJB0XyRRBq5evYqxY8fC0NAQERERyMzMxKZNm2BpaXnL406cOIGdO3diyJAh93zNI0eOoLGx8baPZ2ZmorS0tMXn7NmzB6+//jpWrlyJkydPwsvLC0FBQbh8+fL9fWN/IwgCFi1aBFdXV76nR3QHpTdK8XX619iTsQdX667C3Mgc//D8B14c9iL6WPQRO57WW7hwIVxcXPDaa6+htRvL1dXV8PLywvbt2zs4Hd03QQKWLl0qjBs37q6PuX79uuDm5ibExMQIEydOFF599dU7PlapVApeXl7CE088ITQ1Nak/np2dLdjY2AgbNmxo8XmjRo0S5s+ff8vr2NnZCaGhobc8rrCwUJg5c6bQvXt3wdLSUnjqqaeEioqKe36fv/32mwBA+Ouvv+75WCKpqW6oFvbl7BNWxa8SVsavFNYkrBHi8uOE+qZ6saPpnH379gkAhN9//73NzwUg/Pbbb7d9/MyZM0JISIhgbm4u2NjYCK+//rpQX8+10RRJ7Az88ccfGDFiBP75z3+iV69eGDZsGD777LNbHjN//nw8/PDDCAgIuOfr6enpYf/+/UhLS8OsWbOgUqmQl5cHPz8/PPbYY3jzzTdve05DQwNSU1NveX09PT0EBATg2LFj6o8pFAp4e3vD1dUVx48fR0xMDBQKBZYsWXLXTHV1dVi8eDFCQkLw0EMP3fN7IJIKpUqJpOIkbE3aihMXT0CAgIE9B2LBqAXwc/KDkb6R2BF1zkMPPYTg4GC8/vrrqK+vf+DXS0tLw5gxYzB8+HCcPHkSP/zwA77//nsegaxJYrcRTTA2NhaMjY2Ft956Szh58qSwc+dOwcTERAgPDxcEQRC+//57YdCgQUJtba0gCMI9dwZuKiwsFBwcHIQnn3xScHBwEGbNmiWoVKoWH3vhwgUBgHD06NFbPr5kyRJh1KhR6v+eMmWKsGLFilse8/PPPwtOTk53zRIaGioYGBgIWVlZ98xNJBWKcoWwLWmbsDJ+pbAyfqXw6YlPhYKrBWLH6hQyMzMFAwMDYf369W16HlrYGfD29hbmzZt3y8eWL19+y5+N1LEkcc+MSqXCiBEjsG7dOgDAsGHDcPbsWezYsQN+fn549dVXERMTAxOTtk0bc3BwwDfffIOJEyfC2dkZX3zxxQPdi1xYWIiYmBgcPnwYmzZtUn9cqVTC3t7+js+7ePEi1q5diwULFqB///73/fWJOovymnJE50UjpzwHANDFsAv8nfwxzHaYTo8W1iaenp6YP38+1q5di1mzZsHW1va+Xic7Oxupqam3TVQ1MjJql10Hah1JlAFbW1sMGDDglo95enril19+QWpqKi5fvozhw4erP6dUKpGYmIht27ahvr4e+votnz9eWlqKOXPmYOrUqThx4gRee+01fPzxxy0+tkePHtDX17/t4sLS0lL07t0bAJCeng4rKyskJSXd9vy7HSe8fPlymJqaYsWKFXd8DJEU1DXVIbEwEUnFSVAKSujJ9ODTxwcTHSd2qtHC2mLlypX49ttv
sXz5cnz55Zf39RoZGRkwNDSEu7v7LR/PzMzE4MGD2yMmtYIkysDYsWORk5Nzy8dyc3PRr18/+Pv733aLzOzZs9G/f38sXbr0jkWgrKwM/v7+8PT0xE8//YTc3FxMmjQJxsbG+PDDD297vJGREby9vREXF4fHHnsMQPOORVxcnPrKf0NDQ1y/fh12dnbo0qV1Z5+fOHECX331FXbs2HHb3RFEUqESVDhVcgpx+XGobqwGALhZuSHINQg9uvQQOV3nZWlpibVr12Lu3LmYN28eRo4c2ebXMDc3h1KpRGNjI4yNjQEABQUF+O233/DHH3+0d2S6E7Hfp9CE5ORkwcDAQHj//fcFuVwu7N69W+jSpYvw7bfftvj41txNMGLECOGhhx665WrXU6dOCVZWVsLmzZtbfN4PP/wgGBsbC+Hh4UJmZqYwZ84coXv37kJJSYkgCIJQXl4uWFtbC48//rhw6tQpQS6XCxEREXfMolKpBO+R3kI/j3633NVAJCWFlYXCjhM71NcFfJz0sZBblit2LMloamoSPAd6Cj6jfe54zdT169eFtLQ0IS0tTQAgbN68WUhLSxMKCwuFyspKwcrKSli0aJGQl5cnxMXFCZ6ensKzzz6r4e9E2iRRBgRBEP78809h0KBBgrGxsdC/f39h165dd3xsay4gjI6OVl9w+L9OnjwpnD9//o7P+/jjjwUHBwfByMhIGDVqlHD8+PFbPp+UlCRMmjRJsLCwEMzNzYXhw4cLYWFhLb7Wt99+KwAQbObbCI9894hwpvTMXTMTdSaVtZXCTxk/qUtA6KFQ4dj5Y0KTksVYU26uwazNswQAwu7du1t8XHx8vADgtv8999xzgiAIQmJiojB8+HDBxMREcHZ2FkJDQ/kPHA3jccQ66saNG/Dw8ICBgwH0ntSDAAEGegYIdg3G+5PfRzfTbmJHJOoQjcpGHDl/BEeKjqBR1QgZZPC288Zkx8noatRV7HiS0KBswJGiIzhy/oh6vHN0aDQKMwqRk5ODrl25DrqGZUBHvfvuu/jggw+QlZWFAhRg1cFVKLpWBAAwMzLDS8NfwsKRC+94zQORrhEEARlXMhCdF60eLezY3RHBrsHobdZb5HTSIAgCzl4+i5j8mNvWoOZyDQYMGIA333wTq1evFjkptRXLgA46d+4cPD09sXjxYqxduxZA8x0Qn6Z+ip0pO3G94ToAoK9FX7w74V0EuQaJGZfogV28fhGRikh14e1u0h2BLoHw7OEp6dHCmtTSGgS5BKF/j/7qNXj77bexefNmZGdno1+/fmLGpTZiGdBBM2bMwJEjR5CTkwMzM7NbPnet9hpWHFyBv+R/qbfvvO28sSFgA1ysXERKTHR/bjTcQFx+82hhAQIM9Qwxvt94+Pb1haG+odjxJOHva2Ckb4TxDuPha+9723jnGzduwN3dHePHj8eePXtESkz3g2VAxyQkJGDSpEn45ptv8Mwzz9zxcVlXsrA8bjnSStIAAIZ6hnjU41GsmrwKZkZmd3wekTZoUjUhqTgJiYWJqFc2HzzjZeMFf2d/WBhbiJxOGu53Db755hvMmjULCQkJmDBhgqbi0gNiGdAhSqUS3t7eMDExwdGjR6Gnd++T1Pbl7sP7ie/j0o1LAIBuxt2wYNQCvDD0BV5PQFpHEATkluciKi8KFbUVAIA+5n0Q4haCvhZ9RU4nDYIgIKc8B9F50fe1BiqVCr6+vmhoaEBKSgr/nNERLAM6ZNeuXfjPf/6D48ePw8fHp9XPUyqV2HR8E8JPhaOmsQYA4NzdGasmrcIERzZ30g6Xqy8jShGFvKt5AABzI3MEOAdgiM0QXhegIZerLyNSEYn8q/kA7n8Njh8/Dl9fX+zatQv//ve/OyoutSOWAR1RWVkJd3d3hISE4Kuvvrqv16iorcCy2GWIy4+DUlBCJpNhTN8x2BiwEX26cY47iaO2sRbx5+KRcjEFKkEFfZk+xtiPwTiHcTA2MBY7niT8fQ0M9AzUa3C/
Ux1nzZqFyMhIyOVydOvGW521HcuAjli8eDF27tyJ3Nxc2NnZPdBrnbp0Cm/FvYWssiwAgLG+MWYMnIF3J77Lca6kMSpBhZSLKYgviEdtUy0AwLOHJwJdAmFpyqO1NaGlNRjQcwCmOE954DW4cOECPDw88PLLL7d4RDtpF5YBHZCTk4NBgwZh9erVeOutt9rtdfec2YMPjn2AspoyAIB1F2ss9l2MpwY/1W5fg6gl+VfzEamIxOXqywAAm642CHYNhpOlk8jJpEMTa7Bu3TqsXLkSZ8+ehYeHR7u9LrU/lgEd8PDDDyMrKwuZmZltHrN8Lw3KBoQeCsX3Z79HXVMdAKB/j/543+99eNt5t+vXIqqorUB0XjSyy7IBNI8Wnuw4Gd523hwtrCEtrYGfkx+G2w5v9zWoq6uDp6cnBg4ciH379rXra1P7YhnQchEREXjooYfwyy+/4B//+EeHfZ0L1y5g+YHlSCxKhCAI0JPpwd/JH+v81qGnWc8O+7okDfVN9ThUdAjHzh9TjxYe1WcUJvabCFPDO4/npvYj1hr88ssveOKJJxAREYHg4OAO+zr0YFgGtFhjYyMGDx4MOzs7xMXFaeSK6iNFR7AifoX6im5TQ1M87/U83vB9g7cIUZsJgoD00nTE5sfiRsMNAICLpQuCXYPRsytLpiYIgtA83rkgTr0GrlauCHIJ0sgaCIIAPz8/lJSU4PTp0zA05GFR2ohlQIt99NFHWLx4MdLS0jBkyBCNfV2lUonw9HB8nPwxKusqAQA2ZjZ4e9zbmNZ/msZykG47f+08IhQRuHj9IgDA2tQaQa5BcLNy462CGqIta5Ceno7hw4dj8+bNePXVVzX2dan1WAa01JUrV+Dm5oannnoKn3zyiSgZahtqsfLgSvyW/RsaVY0Amk8gWx+wHp49PUXJRNqvqr4KMXkxOHP5DIDmu1UmOk6ETx8f6Otxd0kTWlqDSY6TMKrPKNHWYO7cufj+++8hl8vRsyd3hbQNy4CWevnll7Fnzx7I5XL06NFD1CwFVwuwLHYZki8kq0clh7iGYO3ktRyVTGqNykYcPX8Uh4sOq0cLD7MdBj8nPx6BrSEtrcFw2+Hwc/ITfbzzzX/gzJw5E59++qmoWeh2LANaSFu31GLyYrA6YTXOV50H0Hw62RzvOZg3Yh6vJ5AwQRCQeSUT0XnRuFZ/DQDQr1s/BLsGw9bcVuR00qArayDWW590bywDWkbbL7ZRKpX4JOUT7ErdpR6VbG9hjxUTV2CKyxSR05GmXbp+CZGKSBReKwTQPPsi0CUQA3oO4HUBGqJLa9DY2IghQ4bA1tZWYxdFU+uwDGiZm7fhREZGIigoSOw4d3St9hreiX8HEYoI9ajkUX1GYX3Aeh4cIwHVDdU4UHAAJy+dVI8WHucwDmPsx3C0sIZUN1QjriAOaZfSdGoNIiMjERIS0uG3S1PbsAxokdraWgwYMACDBg3Cn3/+KXacVjl7+Szejnsb6aXpAJpHJf/D8x9YNXEVTI14/3hno1QpkXQhCQnnEtRjbQf3GowA5wB0M+H1I5rQ0hoMsRmCAOcAnRnv/MgjjyAjIwNZWVntfpAa3R+WAS3y/vvv47333sPZs2fh7u4udpw2+S3rN6w/sh6lN0oBAN1NumPhqIV43ut5Xk/QCQiCAHmFHFGKKJTXlgMA7MztEOwaDIduDiKnk4Y7rUGIawjsu9mLnK5tbh6x/t5772H58uVixyGwDGiNCxcuwN3dHXPnztXZoR5KpRIbj27E16e/Rm1j89ATF0sXrJ68GmMdxoqcju7XleoriMqLgqJCAQAwMzKDv5M/hvYeyvd8NaSlNQhwDoCXjZfOrkF7Dl+jB8cyoCU607jPKzeuYPmB5YgriINKUEEmk2GCwwSs81vHUck6pLaxFgmFCUi+kKweLexr74vxDuM5WlhDOvMatMdYdmo/LANa4Pjx4/D19cVnn32Gl156Sew47Sb1YirePvC2eiCKiYEJZg6a
ibfGv8VRyVpMJahw8tJJHCg4gJrGGgDNw6sCXQJhZWolcjppUAkqpF5MRfy5+E69Bp999hnmzJmD48ePw8fHR+w4ksYyIDKVSgVfX180NjbixIkTnfL99e/OfIdNxzahvKb5fc4eXXpgie8SPDn4SZGT0d8VXC1ApCISpdXN13707NITwa7BcLFyETmZdPx9DXp17YVg12A4WzqLnKz9KZVKjBgxAsbGxjh69Cj09Di5UiwsAyL7+uuv8dxzzyExMRHjx48XO06HaVA2YE3CGvyY8aP6CmjPHp4I9Q/FUNuh4oYjXK29iui8aGSVZQEATA1MMdlpMkbYjeBoYQ1paQ38nPw6/XjnxMRETJw4EV9//TWeffZZseNIFsuAiG7cuAF3d3eMHz8ee/bsETuORhRdK8JbsW/haPFRCIIAfZk+/J39sT5gfafa/tQVDcoGHCo8hGPFx9CkaoKeTA8j7EZgkuMkdDHsInY8SWhpDUbajcQkx0mSGe88Y8YMHD58GLm5uTAz49HVYmAZENHbb7+NzZs3Izs7G/369RM7jkYlnkvEioMrcK7yHACgi2EXPD/0eSwevbhTvlWibQRBwOnS04jNj1WfJOls6Yxg12D06tpL5HTS0NIauFi6IMg1SHJrcO7cOXh6emLx4sVYu3at2HEkiWVAJPn5+RgwYACWLl2K9957T+w4olAqlfjvqf9iW/I29Xnqtma2eHvC23jE/RGR03VexVXFiJBH4ML1CwAAK1MrBLoEwsPaQ2dvU9M1La1BkEsQ3K3dJbsGK1aswMaNG5GVlQUnJ55iqmksAyJ5/PHHkZycjOzsbHTtKu40MbHdaLiBVfGrsDdnr3pU8rDewxDqH4r+PfuLnK7zqKqvQlx+nPq0SCN9I0zsNxE+fX1goGcgcjpp+PsaGOsbY0K/CVwDANXV1fDw8MDo0aPx888/ix1HclgGRHDgwAH4+/tj9+7deOqpp8SOozXyKvKwNHYpUi+mqs9af9j9Yaz1W8sRuA+gSdWEY+eP4VDRITQoGyCDDEN7D4W/sz9/XTWkUdmIY8XHcKjwkHq0MNfgdrt378YzzzyDAwcOYPLkyWLHkRSWAQ1ramrC8OHDYW5ujsOHD0t2S/BuohRRWJ2wWr2Fam5kjv+M+A/mes/l9QRtIAgCssqyEJ0Xjcq6SgDNEyZD3EJgZ84T3zShpTVw6OaAYNdgrkELBEHA2LFjUV1djZMnT/LnXYNYBjRsx44dmDt3Lk6cOIERI0aIHUdrKZVKhCWH4b9p/8WNhhsAmv8QXTVpFfyc/EROp/1KbpQgUhGpvkDTwtgCU5ynYFCvQSygGtLSGgS6BGJgz4Fcg7s4ceIERo0ahR07duA///lPq54TGhqKX3/9FdnZ2TA1NcWYMWOwYcMGeHh4dHDazoNlQIOuXr0KNzc3TJs2Df/973/FjqMTKmorsCJ+BSLkEVAKSsggg09fH4T6h3JUcguqG6oRfy5e/VaLgZ4BxtqPxViHsTz1UUP+vgaGeoYY6zAWY+3HavVoYW0ye/Zs/Pnnn5DL5bC0tLzn44ODg/Gvf/0LI0eORFNTE5YvX46zZ88iMzNT8tdktRbLgAYtWrQIX3zxBeRyOXr37i12HJ1y5vIZLI9bjjOlZwA0X/z2+IDHsWL8Co5KRvNY2xMXT+DguYOoa6oDAAzqNQgBzgHobtJd3HASoVQpkXwhGQmFCbeswRTnKRzv3EaXLl2Cu7s7XnrpJWzZsqXNz79y5Qp69eqFhIQETJgwAQBQVFSEZcuWISIiAjKZDCEhIdi2bVuryoYUsAxoSGZmJoYMGYL3338fS5cuFTuOzvol8xdsOLIBl6svAwAsTS3xms9rmDV0lsjJxKOoUCBSEYmymjIAzbdnBrsGo193aZ1dISZ5uRxReVFcg3a0fv16vPvuuzh9+jQ8PT3b9FyFQgE3NzecOXMGgwYNgkKhgK+vL+bOnYunn34aN27cwLx58zB48GB8/vnnHfQd
6BaWAQ0QBAEhISFQKBTIyMiAsbFuTxsTW4OyARuPbMTu07tR29Q8KtnNyg1r/NZgdN/RIqfTnLKaMkQpoiCvkAMAuhp2hb9z82jhznx8rTbhGnScuro6DBw4EO7u7oiIiGj181QqFaZNm4bKykocPnwYABAYGAhfX99bznT55ZdfsGTJEuTn57d7dl3EMqABf/31Fx555BH8/vvvePTRR8WO02mU3CjB8tjlOFh4UD0qeaLDRKyfsh69zTrv2zB1TXVIOJeApAtJ6rG2Pn19MKHfBJgYmIgdTxJaWoPRfUdjQr8JOj9aWJv8/vvvmD59Ov766y889NBDrXrO3LlzERERgcOHD6Nv374oLCyEo6MjTE1NbxmEpFQqYW9vj9zc3I6Kr1NYBjpYQ0MDBg0ahH79+iE6OppXEXeA5OJkvBP/DnLLm3+oTQ1M8dTgp7B03NJOddGcSlAh7VIaDhQcQHVjNQDA3dodQS5BsO5iLXI6abi5BnEFcerRwh7WHgh0CeQadABBEDBlyhScP38eZ86cgZHR3X+eFyxYgL179yIxMVF9iuEff/yB2bNnIykp6bbHm5qaok+fPh2SXdewDHSwTZs2YenSpTh16hQGDRokdpxO7etTX+OjpI9QUVsBAOjZtSeWjV2Gxwc8LnKyB1dYWYgIRQRKbpQAaB4DHewaDFcrV5GTSce5ynOIVESq14DjnTXj7Nmz8PLywgcffIDXX3+9xccIgoCFCxfit99+w8GDB+Hm5qb+XEREBB599FFUVlaiSxcO37oTloEOdPnyZbi5ueHZZ5/Ftm3bxI4jCbUNtVh7aC1+yvwJDcoGAM1XdK/zX4chNkNETtd2lXWViMmLQcaVDACAiYEJJjs2jxbW1+OBLJrANRDfSy+/hB+//xEKuQK9et0+xGnevHn47rvvsHfv3lvOFujWrRtqa2vh7u6OSZMm4d1330XXrl2hUCgQGRmJjz76SIPfhXZjGehAc+bMwc8//wy5XA5ra24halLRtSIsi1mGY8XHIKB5VHKQaxDW+q3ViVHJDcoGHCk6giPnj6BJ1QQZZBhhNwKTnSZztLCGNCgbcLjoMI6eP8o1EMnNNYg9G4uwZ8Lw9JNPY9euXbc97k5vv3755Zd4/vnnkZycjKVLl+LkyZMQBAFubm547rnn8Morr3T0t6AzWAY6SFpaGry9vbF161YsWLBA7DiSdbDgIFYeXInCa4UAADMjM8weOhuLfBZp5VGngiDgzOUziM2PRVV9FQDAqbsTgl2DYWNmI3I6aeAaiO/mGsTkxajHOysiFPjug+9w8uRJDB06VNyAnRDLQAcQBAETJ05ERUUFTp06BQMDaU8jE5tSqcSu1F34NPVT9R/ufcz74O0Jb+Mht9ZdoawJF6ouIFIRifNV5wEAliaWCHQJRP8e/XnhqYa0tAZBrkEc76xBF6ouIEIRgeKqYgD/fw2cLZwxbNgw9OjRAwcPHuR6tDOWgQ7w448/4sknn0R0dDSmTJkidhz6PzcabmDFgRX4M/dP9ahkb1tvhPqHwr2Hu2i5rtdfR1xBHE6VnALQfLrieIfx8LX3lfxYW01paQ0m9JuA0X1Hcw005Hr9dcTmx94yYvvvaxAdHY2goCD8+OOP+Oc//ylm3E6HZaCd1dbWon///hg6dCj27t0rdhxqQW5ZLt6Kewupl1IBAIZ6hpjqMRWrJ6/W6DjZJlUTjhcfR2JhovpiRy8bLwQ4B8Dc2FxjOaSspTUY2nso/J38uQYa8vcR28Dd12DatGlIT09XDyWi9sEy0M7WrFmDNWvWIDMzE66uvO1Lm+2X78fahLW4eOMigOapcvNHzsdLw17q0OsJBEFATnkOohRRuFp3FQDQ16IvQlxD0MeC9zxrAtdAfIIgILssG9F50eo1sLewR7Br8F3XQC6XY+DAgVixYgXeeecdTcXt9FgG2tH58+fh4eGBhQsXYsOGDWLHoVZQKpX4KOkjfJH2hfoQGcfujlg9aTUm
OE5o9693ufoyIhWRyL/afASquZE5prhMweBeg/keqIaU3ihFpCISBZUFALgGYvj7GrR1xPabb76J7du3IycnB3379u3ouJLAMtCOnn76acTFxSE3NxcWFhZix6E2qKitwPK45YjJi1GPSh5jPwahAaFw6ObwwK9f01iDg+cO4sSFE+rRwmPsx2Ccw7hOdUqiNqtprEF8QTxSLqZwDUTS0hrcz4jtqqoquLm5YcqUKfj22287MLF0sAy0kyNHjmDcuHH44osv8MILL4gdh+7T6ZLTeCvuLfUBM8b6xnhiwBN4Z/w79zUqWalSIuViCg6eO6geqjSg5wBMcZ4CS1OOTtWEltZgYM+BmOIyheOdNaSlEdsPugZffPEFXnrpJRw5cgRjxoxpx7TSxDLQDlQqFUaNGgUASE5OvmUYBummnzJ+wsajG3Gl+goAwMrUCq/7vo5nhjzT6tfIq8hDpCISV2qaX8Omqw1C3ELg2N2xIyJTC/6+Br3NeiPYNZhroEF/H7Hd26w3QlxDHni8s1KpxKhRo6Cnp4ekpCT+ufuAWAbaQXh4OGbPno3Dhw9j7NixYsehdtKgbMCGwxuw+8xu9b9mPKw9sHbyWozsO/KOzyuvKUd0XjRyynMAAF0Mu8DfyR/DbIdxrK2GcA3EV15Tjqi8KPUAsa6GXeHn5Neua3D48GGMHz8e4eHheO6559rlNaWKZeABVVVVwd3dHX5+fvjuu+/EjkMdoORGCZbFLENCUQIEQYCeTA+THSdjvf969DTrqX5cfVM9EgsTcbz4OJSCEnoyPfj08cFEx4kcLawhXAPx1TXVIbEwEUnFSRpZg5kzZ+LgwYPIzc2FuTlvB71fLAMPaNmyZdi6dStycnJgb28vdhzqQMfPH8c78e9AUaEA0Dwq+Zkhz2DJmCU4e+Us4vLj1KOF3azcEOQahB5deogZWTJUggqnSk5xDUTU0hq4W7sj0CWwQ9egqKgI/fv3x6uvvorQ0NAO+zqdHcvAA1AoFBg4cCCWL1+OlStXih2HNCQ8LRxhSWHqe6PNjc0xym4U7LvZw9rUGsGuwXCzdrvHq1B7KbpWhAh5BC7duASgebxzkEsQ10CDCisLEamIFG0NVq1ahdDQUGRmZsLFhSOl7wfLwAN47LHHcPLkSWRnZ3NOtsTUNtRiVeIq/JL5C5pUTQAA376+CPUPhaOlo6jZpOJa3TXE5Mfg7OWzAJpHC09ynISRdiM5WlhDWhrvLMYa1NTUwMPDAyNGjMBvv/2msa/bmbAM3KfY2FhMmTIFP/zwA5588kmx45BIwk+F48u0L1FeW44eXXrAUM8QD7s/jDnD56CLEQtiR2hUNuLI+SM4UnQEjapGyCCDt503JjtORlejrmLHk4SWRmyLvQY//PADZs6cidjYWPj7+4uSQZexDNyHpqYmDB06FJaWlkhMTOSpZRIWkxeDI+ePoLdZb6RcTFFfT9DdpDue83oOj3o8ylue2okgCMi4koGYvBhcq78GoPm0yGDXYPQ26y1yOmkQBAFnL59FTH6MegKotqyBIAgYP348rl27hrS0NE6LbSP+at2HnTt3IjMzEykpKSwCBABwtnTGnOFzsE++D1+mfYmrdVcRlhSGP3L/wMJRCzHcdrjYEXXaxesXEamIRNG1IgDNZSvQJRCePTz5M6ghF69fRIQ8Qj3eWdvWQCaTISwsDCNHjsSuXbswb948sSPpFO4MtFF5eTnc3Nzwj3/8A59//rnYcUhkN3cGxtiPQaBLIACgpqEGX6R9gT9y/lBvY4+xH4OFPgtF/9eTrrnRcANx+c2jhQUIMNQzxPh+4+Hb1xeG+oZix5OEv6+Bto/YfvHFF/H7779DLpfDyspK7Dg6g2WgjRYuXIivvvoKcrkcNjY2YschkbVUBm4qrirG1qStSL6QDKD5aOPH+j+G2cNm8573e2hSNSGpOAmJhYmoV9YDaB7v7O/sDwtjzv3QhJtrkFCYcMuIbW1fg5KSEri7u+P5
55/H1q1bxY6jM1gG2iAjIwNeXl5Yv3493njjDbHjkBa4Wxm4KflCMrYlbUNRVfMWt7WpNV4a/hKCXIJ4PcHfCIKA3PJcROVFoaK2AgDQx7wPQtxC0NeC0+k04eZ45+i8aJ1dgw8++ABvvfUWTp8+jQEDBogdRyewDLSSIAgICgrCuXPncPbsWRgZccoZta4MAM3zK37O+hlfp3+NGw03AAD9rfvj1dGvwrOnp6biarXL1ZcRpYhC3tU8AM2jhQOcAzDEZohWvCctBS2N2NbFNaivr8egQYPg5OSEqKgoncouFu17w0dL/fnnn4iJicGff/7JIkBtpqenhxkDZyDYJRg7U3ciUhGJ7PJszN8/H5McJ2HByAWw6iLN9zdrG2ubxztfPAGVoIKBngF8+/pifL/xHC2sIbWNtYg/1zxauDOsgbGxMTZv3oxp06Zh3759mDp1qtiRtB53Blqhvr4eAwcOhIuLCyIjI9kySa21OwN/V3C1AB8d/wjppekAmo82/tegf+HpIU9r5UVZHUElqJByMQXxBfHq0cKePTwR6BLI8c4a0tIadJYR2zd3cwsKCnD27FkYGxuLHUmrSeNPnQcUFhaGc+fO4Y8//mARoHbhZOmEsJAwJJ5LxCcpn6DkRgm+PPUlIhQReHnEy5jkOEnsiB0q/2o+IhWRuFx9GUDzeOdg12A4WTqJnEw6OvsayGQybNmyBV5eXti6dSuWLFkidiStxp2Be7h5Zers2bMRFhYmdhzSMve7M/C/mlRN2H16N/Zk7EFNYw2A5qu2X/V5Fc5Wzu0ZV3QVtRWIzotGdlk2gObRwn5OfhhuO5yjhTVEamvwyiuvIDw8nHeA3QPLwD28+OKL2Lt3L+RyOSwtdXvbjNpfe5SBmypqKrDtxDYcPHcQKkEFfZk+glyD8LL3y7Aw0d5buVqjvqkeh4oO4dj5Y+qxtqP6jMLEfhNhamgqdjxJaGm8sxTWoKKiAm5ubpg+fTrPhrkLloG7SElJwahRo7B9+3bMnTtX7DikhdqzDNyUdSULYcfDkF3e/C83MyMzPDvkWfxzwD917lZEQRCQXpqO2PxY9V0UrlauCHIJQs+uPUVOJw2CIDSPFi6Ik+wafPLJJ1iwYAFOnDgBb29vseNoJZaBOxAEAePGjUNVVRXPuaY76ogyADTfihiTH4NdqbtQXlsOALC3sMeCUQvg09en3b5ORzp/7TwiFBG4eP0igObzFYJcg+Bm5cZrbzSk6FoRIhWRkl+DpqYmDBs2DN26dcOhQ4ck9b23Fv+Gu4MffvgBR48eRWxsLIsAaZyenh6CXIMw2XEyvjz1JX7N+hXnq85jaexSjOozCq/4vKK1B8BU1VchJi8GZy6fAdB88uJEx4nw6ePD0cIacq3uGmLzY7kG/8fAwAAfffQRAgICsGfPHvzrX/8SO5LW4c5AC27Oxh45ciR+/fVXseOQFuuonYG/K7lRgo+TP8bRoqPqM/qnekzFS8Ne0ppRyY3KRhw9fxSHiw6rZzIMtx0OPyc/jhbWEK7B3U2fPh2pqanIzs5Gly7a8XOjLVgGWrBq1SqEhoYiKysLzs6d62pual+aKgM3nbp0CmHJYSi4WgAAsDSxxPNDn8dU96miXU8gCAIyr2QiJj8GlXWVAIB+3foh2DUYtua2omSSmpbGO3MNbpeXl4cBAwZg+fLlWLlypdhxtArLwN8UFRXBw8MDr732GtatWyd2HNJymi4DQPP1BHtz9uKr9K/Uf/k6Wzpj0ehFGGIzRCMZbrp0/RIiFZEovFYIAOhm3A2BLoEY0HMA35fVkEvXLyFCEaEe78w1uLu33noLYWFhyM7OhoODg9hxtAbLwN/861//QkJCAnJzc2Fubi52HNJyYpSBm2oaavDZyc+wL3efekt4nMM4LBy1EL3MenXo165uqMaBggM4eemk+m2LcQ7jMMZ+DEcLa8iNhhs4UHAAaZfSuAZtcP36dbi7u2PSpEn4/vvvxY6jNVgG/seh
Q4cwYcIEhIeH47nnnhM7DukAMcvATUXXivBx0sc4cfEEgOaLxR4f8Die93oeRgbte668UqVE8oVkHDx3UD1aeHCvwQhwDkA3k27t+rWoZUqVEkkXkpBwLoFrcJ/Cw8Mxe/ZsHDp0COPGjRM7jlZgGfg/SqUSI0eOhIGBAY4fP65z93OTOLShDNx07PwxbD+xHcVVxQCAHl164D/e/8EUlykP/NqCIEBeIUeUIkp9q6OduR2CXYPh0I1brZpwpzUIcQ2BfTd7kdPpFpVKBR8fH6hUKpw4cYJ/3oO3FqqFh4cjLS0NR48e5W8M0km+9r7w6eODPRl7sPvMbpTVlOH9Q+/j16xfsWj0Inj08Liv171SfQVReVFQVCgANB+C5O/kj6G9h/I9aQ1paQ0CnAPgZePFNbgPenp62Lp1K8aMGYPw8HC88MILYkcSHXcGAFy7dg3u7u4IDAzEN998I3Yc0iHatDPwv6rqqvBJyieIyYtRHz3r5+SH+SPnt3oaXW1jLRIKE5B8IVl9PLKvvS/GO4yHsQEnwGkC16BjPfPMM4iJiYFcLoeFhW4f+f2gWAYALFmyBJ988glyc3PRp08fseOQDtHWMnBTXkUewpLCcLr0NIDmoTT/GvQvPDX4qTuOSlYJKpy8dBIHCg6oByf179EfgS6BsDK10lh2KVMJKqReTEX8uXiuQQcqLi6Gh4cH5s+fj40bN4odR1SSLwNyuRwDBw7EihUr8M4774gdh3SMtpeBmw4UHMCOlB3qcbW2ZraYN3Iexvcbf8vjCq4WIFIRidLqUgBAr669EOwaDGdLnrehKVwDzVqzZg3WrFmDjIwMuLm5iR1HNJIvA9OmTcPp06eRlZUFU9POO7mLOoaulAGgeVTyN+nf4MeMH1HbVAsAGNZ7GF7xeQXdTbojOi8aWWVZAABTA1NMdpqMEXYjOuVYW210tfYq10AEtbW16N+/P4YOHYq9e/eKHUc0ki4DUVFRCA4Oxk8//YQnnnhC7Dikg3SpDNxUVlOG7cnbkVCYAJWggp5MD/YW9nDo5gBDfUOMsBuBSY6T0MWQx7VqQn1TPQ4XHcax4mNoUjVBT6aHkXYjMclxUqceLaxNfvrpJ8yYMQNRUVEIDNSNn+P2Jum7Cdzd3bFixQo8/vjjYkch0pgeXXpg5aSVyLicgS3Ht0BRoUDhtULUNdXh5REvY5LjJF6hrgGCIOB06WnE5sfiesN1AICLpQuCXIPQq2vHHhpFt3riiSfw7rvv8m0CsUMQ6Spd3Bn4X/VN9VgcvRiKCgVG2o2Evp4++lr0RbBrsNZORewMiquKESGPwIXrFwAAVqZWCHIJgru1O4sYiULSOwNEUieTydCjSw9YmVphsuNkHCs+huKqYnx+8nN42XghwDkA5sY8lru9VNVXITY/Vn13h7G+MSb0mwCfvj53vLuDSBP4u4+IoCfTw1iHsRhhN6L5vPuSNKSXpiOrLAvjHcbD196Xf1k9gEZlI44VH8OhwkPqORJDew+Fv7M/zIzMxI5HxDJARP+fubE5Hu3/KEbYjUCkIhLnq84jriAOqZdSEeQShP49+nMbuw0EQUBWWRai86LVEybtLewR4hYCO3M7ccMR/Q+WASK6TR+LPnhh2As4e/ksYvJjUFlXiT0Ze+DU3QnBrsGwMbMRO6LWK7lRgkhFJM5VngMAWBhbINAlEAN7DmShIq3DMkBELZLJZBhsMxgePTxwpOgIjpw/goLKAuxI2QFvO2/4Ofnx9sMWVDdUI/5cPFIvpkKAAAM9A4xzGIex9mM5Wpi0FssAEd2Vkb4RJjtNxjDbYYjJi0HGlQykXEzB2ctnMclxkvouBKm7Od45oTABdU11AIBBvQZhivMUjhYmrccyQESt0t2kO/458J8YVTkKkYpIXLpxCZGKSKRcTEGwazBcrVzFjigaebkcUXlRKKspA9B83HOwazD6de8ncjKi1mEZIKI26de9H/7t/W+cKjmFuPw4lNWU4dvT38Ld
2h1BLkGw7mItdkSNKaspQ5QiCvIKOQCgq2FX+Ds3j3fmEcKkS1gGiKjN9GR6GG47HAN6DkBiYSKSipOQW54LRYUCo/uOxoR+E2BiYCJ2zA5T11SHhHMJSLqQpB4t7NPXp9N/39R5sQwQ0X0zMTBBoEsgvG29EZUXhdzyXBw9fxTpJenwc/LDMNthnepfyCpBhbRLaYgriFOPFvaw9kCgS6CkdkSo82EZIKIHZt3FGk8NfgqKCgUiFZEoqynDn7l/4sTFEwhxDekU752fqzyHSEUkSm6UAAB6dumJINcgSV8rQZ0HywARtRtXK1fMHTEXJy6ewMFzB1FyowRfnvoSA3sOxBSXKehu0l3siG1WWVepvosCaN4NmezYPFqYd1FQZ8EyQETtSl9PH6P7jsYQmyGIL4hHysUUZFzJQE55Dsbaj8VYh7Ew0jcSO+Y9NSgbcLjoMI6eP4omVRNkkGGE3QhMdprM8xWo02EZIKIO0cWwCx52f1h9tHFBZQESChOQVpKGKc5TMKjXIK08iU8QBJy5fAax+bGoqq8CAJ68SJ0eywARdSgbMxvM8pqF7LJsROVFobKuEr9k/YLkC8lad0b/haoLiFBEoLiqGABgaWKJINcgeFh7aGVxIWovLANE1OFkMhk8e3rCzdoNx84fw6GiQzhfdR67Unc1T+9z8hd1VPL1+uuIK4jDqZJTAJpPXZzQbwJG9x3NaY0kCfxdTkQaY6BngPH9xmNo76GIzY9Femk6TpWcQuaVTFH+8m1SNeF48XEkFiaiQdkAAFpRTog0jWWAiDTO3Ngc0z2nY2SfkYhURKK4qhix+bE4eekkAl0CO3xbXhAEZJdlIzovGlfrrgIA+lr0RYhrCPpY9Omwr0ukrVgGiEg0fS364sVhL+J06WnE5seiorYCP5z9Ac6Wzgh2DUavrr3a/WuW3ihVX9AIAOZG5pjiMgWDew3mdQEkWSwDRCQqmUwGr95e8Ozpqb6VL/9qPj498SlG9hmJSY6T2uVWvprGGvWtjjdHC4+xH4NxDuN04lZHoo7EMkBEWsFI36j5COPewxCTH4PMK5lIvpCMM6VnMMlx0n0f8qNUKZFyMQXx5+LVo4UH9ByAQJdAnTwEiagjsAwQkVaxNLXEjIEzcK7yHCLkESitLkWEIkI9KtnFyqXVr5VXkYdIRSSu1FwBAPQ2641g12A4dnfsoPREuollgIi0kmN3R/xnxH9w8tJJHCg4gCs1V/DN6W/gYe2BINcgWJla3fG55TXliM6LRk55DoDmA5D8nfw73eAkovbCMkBEWktPpocRdiMwsOdAJBQmIPlCMnLKc24ZlWxsYKx+fF1TnXqkslJQQk+mB58+PpjoOJGjhYnugmWAiLSeqaEpgl2D1UcbKyoUOHL+CNJL0+Hv5I8hNkOQXpqOuPw4VDdWAwDcrNwQ5BqEHl16iJyeSPuxDBCRzujRpQeeGfIM5OVyRCoiUV5bjr05e7E3Z+8tjwlyCYKbtZuISYl0C8sAEekcN2s3OFs6I/lCMqLyotQfD3IJwqg+ozhamKiNeCUNEekkfT19+Nr7wsvGC0DzMcK+9r4sAkT3gWWAiHRaV6Ouzf9v2FXkJES6i2WAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJY
BoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCSOZYCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgIiISOJYBoiIiCTOQOwARPv27RM7wn1Lu5SG3LJc6Mn10JDVIHacNmtUNiI3MxcA8Ne1v2CobyhyorY7eekkcstyoS/XR31Wvdhx7tsjjzwidgSSMJkgCILYIUjaZDKZ2BGIRMc/iklM3Bkg0ZWUlIgdgYhI0rgzQEREJHG8gJCIiEjiWAaIiIgkjmWAiIhI4lgGiIiIJI5lgKiNrl27hjlz5sDV1RWenp64dOmS2JHapKmpCe+//z58fX0xfPhwPPfcc4iJiRE7Vqvp+q8/kTZiGSBqo/nz5+PMmTPYuHEjCgsLUVtbCwB47bXXsG3bNpHT3duyZcvwySefwN/fH4899hjq6+vxyCOPYPbs2Tpxr7uu//oTaSWBiNrEyspKOHnypCAIgmBmZibk5eUJgiAIERERwogRI8SM1iq2trZCQkLCLR/Lz88XBgwYIGzcuFGkVK2n67/+RNqIOwNEbSQIAszNzW/7uJubG+RyuQiJ2qa6uhp9+/a95WNOTk74+OOPsWvXLpFStZ6u//oTaSOWAaI2CgkJwe7du2/7eHV1tU4crTxu3Dh89dVXt33cyckJFy9eFCFR2+j6rz+RNuJxxERtFBoaihEjRgBo/leqTCZDXV0d1qxZg+HDh4uc7t42bNiAsWPH4urVq1i4cCHc3NzQ2NiIjz/+GAMGDBA73j3p+q8/kTbiccRE90GhUGD+/PmIiYmBtbU1rl+/DgsLC+zfv1/9F5U2S0tLw5w5c5CamgojIyMolUp0794dv//+O8aOHSt2vHvS9V9/Im3DMkD0AIqKipCeng5DQ0P4+PjA0tJS7EhtkpOTg4yMDJibm8PHxwcWFhZiR2oTXf/1J9IWLANEREQSxwsIidqgrKwMGzduxPTp0+Hr6wtfX19Mnz4dH3zwAa5cuSJ2vAdy/vx5vPDCC2LHuKva2locPnwYmZmZt32urq4OX3/9tQipiHQfdwaIWunEiRMICgpCly5dEBAQABsbGwBAaWkp4uLiUFNTg6ioKJ19zzo9PR3Dhw+HUqkUO0qLcnNzERgYiKKiIshkMowbNw4//PADbG1tATSvg52dndbmJ9JmLANErTR69Gh4eXlhx44dt93CJggCXn75ZZw+fRrHjh0TKeHd/fHHH3f9fH5+PhYvXqy1f5lOnz4djY2NCA8PR2VlJRYtWoTMzEwcPHgQDg4OLANED4BlgKiVTE1NkZaWhv79+7f4+ezs
bAwbNkx9PK620dPTg0wmu+uRwzKZTGv/MrWxsUFsbCwGDx4MoLmAzZs3D/v370d8fDy6du3KMkB0n3jNAFEr9e7dG8nJyXf8fHJysvqtA21ka2uLX3/9FSqVqsX/nTx5UuyId1VbWwsDg/9/NIpMJsOnn36KqVOnYuLEicjNzRUxHZFu46FDRK30xhtvqO/N9/f3v+2agc8++wwffvihyCnvzNvbG6mpqXj00Udb/Py9dg3E1r9/f6SkpMDT0/OWj98cTjRt2jQxYhF1CnybgKgN9uzZgy1btiA1NVW9Ha2vrw9vb2+8/vrrmDFjhsgJ7+zQoUOorq5GcHBwi5+vrq5GSkoKJk6cqOFkrRMaGopDhw5h//79LX5+3rx52LFjB1QqlYaTEek+lgGi+9DY2IiysjIAQI8ePWBoaChyIiKi+8cyQEREJHG8gJCIiEjiWAaIiIgkjmWAiIhI4lgGiNrg8uXLd7x9MCwsDBcvXtRworZhfiJqCcsAURuUl5dj06ZNmD9//i0fX7JkCdauXav1w4qYn4haJBBRm2RnZwt9+vQRZs+eLSiVSmHhwoWCjY2NkJ6eLna0VmF+Ivo73lpIdB/y8vLg7+8PQ0ND1NTUIDY29raT8bQZ8xPR/+LbBET3wcXFBb6+vsjLy8PIkSPh4eEhdqQ2YX4i+l8sA0RtJAgCnnnmGRw/fhwJCQnIycnBjBkz0NTUJHa0VmF+Ivo7vk1A1AZNTU146qmnkJaWhgMHDsDe3h6lpaUICAiAk5MTfv75ZxgZGYkd846Yn4hawp0BojZITk6GXC7HoUOHYG9vDwCwsbFBfHw8SkpKcOjQIZET3h3zE1FLuDNA1EaCIEAmk7X649qG+Yno71gGiIiIJI5vExAREUkcywAREZHEsQwQERFJHMsAERGRxLEMELVBbW0tDh8+jMzMzNs+V1dXh6+//lqEVK3H/ETUEt5NQNRKubm5CAwMRFFREWQyGcaNG4cffvgBtra2AIDS0lLY2dlBqVSKnLRlzE9Ed8KdAaJWWrp0KQYNGoTLly8jJycH5ubmGDt2LIqKisSO1irMT0R3wp0BolaysbFBbGwsBg8eDKD5kJt58+Zh//79iI+PR9euXbX6X6bMT0R3wp0Bolaqra2FgYGB+r9lMhk+/fRTTJ06FRMnTkRubq6I6e6N+YnoTgzu/RAiAoD+/fsjJSUFnp6et3x827ZtAIBp06aJEavVmJ+I7oQ7A0StNH36dHz//fctfm7btm2YOXMmtPldN+YnojvhNQNEREQSx50BojbIysrCl19+iezsbABAdnY25s6dixdeeAEHDhwQOd29MT8RtYQ7A0StFBkZiUcffRRmZmaoqanBb7/9hlmzZsHLywsqlQoJCQmIjo6Gn5+f2FFbxPxEdEcCEbWKr6+v8PbbbwuCIAjff/+9YGlpKSxfvlz9+WXLlglTpkwRK949MT8R3Ql3BohaqVu3bkhNTYWrqytUKhWMjY2RnJyMYcOGAQDOnj2LgIAAlJSUiJy0ZcxPRHfCawaI2kAmkwEA9PT0YGJigm7duqk/Z25ujmvXrokVrVWYn4hawjJA1EqOjo6Qy+Xq/z527BgcHBzU/11UVKQ+J18bMT8R3QkPHSJqpblz595y1O2gQYNu+XxERIRWX7zG/ER0J7xmgIiISOL4NgEREZHEsQwQERFJHMsAERGRxLEMEBERSRzLABERkcSxDBAREUkcywAREZHEsQwQERFJHMsAERGRxLEMEBERSdz/A1Y/4u/cVRGmAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "layer.tensor_product.visualize()" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "f2785a28-4d2b-4cf0-972d-60df70595580", + "metadata": {}, + "outputs": [], + "source": [ + "o = layer(atom_z, coords, edge_index)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "4b80d729-b5e4-4cb5-802d-c27f2af32a70", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-0.3004, -0.0939, -0.3610, ..., 0.5229, -0.7874, -1.0619],\n", + " [-0.3039, -0.4817, -0.9658, ..., 0.2480, -0.3691, -0.5077],\n", + " [-0.6431, -0.0786, -0.3443, ..., 0.2488, -0.2363, -0.1527],\n", + " ...,\n", + " [-0.2051, 0.1712, 0.3823, ..., 0.6626, -0.1357, -0.1973],\n", + " [ 0.0641, 0.1162, -1.2586, ..., 0.2460, 0.7797, 0.2671],\n", + " [ 0.5735, -0.6425, -0.1460, ..., -0.0062, 0.0575, 0.3319]],\n", + " device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "o" + ] + }, + { + "cell_type": "markdown", + "id": "35dd937d-2581-4d09-954b-ac3ace509c33", + "metadata": {}, + "source": [ + "## Equivariance check\n", + "\n", + "Uses `e3nn` tooling to generate the random rotation matrix, and output as a function of rotation permutation: rotating the coordinates before passing into the layer, and rotating the transformed embeddings with the same rotation matrix." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "c4ff86c7-e8e5-4560-83b9-fc1885fd8d13", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. 
Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "# [0, 1, 2] are necessary at the minimum, but all orders should work\n", + "layer = InteractionBlock(\n", + " 64,\n", + " [0, 1, 2, 3, 4, 5, 7],\n", + " 10,\n", + " 32,\n", + " radius_cutoff=6.0,\n", + " degree_norm=17**0.5,\n", + " sph_harm_kwargs={\"use_e3nn\": True}, # this can be toggled for comparison\n", + ").to(\"cuda\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2633995-0e41-4bf2-8093-5f68a92242ae", + "metadata": {}, + "source": [ + "### Rotation check\n", + "\n", + "This applies a random rotation to the coordinates and checks that rotating the inputs before the layer gives the same embeddings as rotating the layer output afterwards." + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "0a3a0a02-8037-4784-9771-3abf8dad7209", + "metadata": {}, + "outputs": [], + "source": [ + "rot_matrix = o3.rand_matrix()\n", + "\n", + "# dims_in doesn't actually do anything in the case where the features are scalar\n", + "dims_in = o3.Irreps(layer.atomic_irreps).D_from_matrix(rot_matrix).to(\"cuda\")\n", + "# but dims_out actually does something, since the output features/embeddings need\n", + "# to be rotated based on the same rotation matrix\n", + "dims_out = layer.output_irreps.D_from_matrix(rot_matrix).to(\"cuda\")\n", + "\n", + "# rotate coordinates before passing into the layer\n", + "rot_before = layer(atom_z @ dims_in.T, coords @ rot_matrix.T.to(\"cuda\"), edge_index)\n", + "# rotate layer output by the same rotation matrix\n", + "rot_after = layer(atom_z, coords, edge_index) @ dims_out.T\n", + "\n", + "assert torch.allclose(rot_before, rot_after, rtol=1e-7, atol=1e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "bc060469-723a-4ece-9716-1b1fbc936806", + "metadata": {}, + "source": [ + "### Rotation + translation\n", + "\n", + "If all atoms are shifted by the same vector, the Bessel embedding should be unchanged, as it works solely on 
interatomic distances, i.e. we do not need to shift the output embedding." + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "fea72969-5f5a-410e-9a28-75346701498b", + "metadata": {}, + "outputs": [], + "source": [ + "rot_matrix = o3.rand_matrix()\n", + "# shift all coordinates by the same amount in space\n", + "trans_matrix = torch.randn(size=(1, 3), device=\"cuda\")\n", + "\n", + "# dims_in doesn't actually do anything in the case where the features are scalar\n", + "dims_in = o3.Irreps(layer.atomic_irreps).D_from_matrix(rot_matrix).to(\"cuda\")\n", + "dims_out = layer.output_irreps.D_from_matrix(rot_matrix).to(\"cuda\")\n", + "\n", + "# rotate and translate coordinates before passing into the layer\n", + "rot_before = layer(\n", + " atom_z @ dims_in.T, coords @ rot_matrix.T.to(\"cuda\") + trans_matrix, edge_index\n", + ")\n", + "# rotate layer output by the same rotation matrix\n", + "rot_after = layer(atom_z, coords, edge_index) @ dims_out.T\n", + "\n", + "assert torch.allclose(rot_before, rot_after, rtol=1e-7, atol=1e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "090ad5de-1172-4bdf-83c9-92eda962a398", + "metadata": {}, + "source": [ + "## Dataset definition" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8e448746-66f8-484f-8971-1c89c8f75135", + "metadata": {}, + "outputs": [], + "source": [ + "class LightningQM9(pl.LightningDataModule):\n", + " def __init__(\n", + " self,\n", + " root_path: str = \"./qm9_data\",\n", + " batch_size: int = 16,\n", + " train_frac: float = 0.8,\n", + " val_frac: float = 0.1,\n", + " num_workers: int = 0,\n", + " ):\n", + " \"\"\"\n", + " Custom data module for the QM9 dataset.\n", + "\n", + " Parameters\n", + " ----------\n", + " root_path : str, optional (default: \"./qm9_data\")\n", + " Path to the QM9 dataset.\n", + " batch_size : int, optional (default: 16)\n", + " Number of samples in each mini-batch.\n", + " train_frac : float, optional (default: 0.8)\n", + " Fraction of data used 
for training.\n", + " val_frac : float, optional (default: 0.1)\n", + " Fraction of data used for validation.\n", + " num_workers : int, optional (default: 0)\n", + " Number of worker processes to use for loading data.\n", + "\n", + " Examples\n", + " --------\n", + " >>> dm = LightningQM9(root_path=\"/path/to/qm9_data\", batch_size=32)\n", + "\n", + " Attributes\n", + " ----------\n", + " dataset : QM9\n", + " Loaded QM9 dataset.\n", + " hparams : dict\n", + " Hyperparameters of the data module.\n", + "\n", + " Methods\n", + " -------\n", + " setup(stage: str)\n", + " Set up data splits for training, validation and testing.\n", + " train_dataloader()\n", + " Returns a DataLoader instance for training data.\n", + " val_dataloader()\n", + " Returns a DataLoader instance for validation data.\n", + " test_dataloader()\n", + " Returns a DataLoader instance for testing data.\n", + " \"\"\"\n", + " super().__init__()\n", + " self.dataset = QM9(root_path)\n", + " self.save_hyperparameters()\n", + "\n", + " def setup(self, stage: str):\n", + " hparams = self.hparams\n", + " num_samples = len(self.dataset)\n", + " num_train = int(num_samples * hparams[\"train_frac\"])\n", + " num_val = int(num_samples * hparams[\"val_frac\"])\n", + " # the remainder goes to the test split, so the three lengths always\n", + " # sum to num_samples as random_split requires\n", + " num_test = num_samples - num_train - num_val\n", + " # generate random splits\n", + " train_split, val_split, test_split = random_split(\n", + " self.dataset, lengths=[num_train, num_val, num_test]\n", + " )\n", + " self.splits = {\"train\": train_split, \"val\": val_split, \"test\": test_split}\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(\n", + " self.splits[\"train\"],\n", + " batch_size=self.hparams[\"batch_size\"],\n", + " shuffle=True,\n", + " num_workers=self.hparams[\"num_workers\"],\n", + " )\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(\n", + " self.splits[\"val\"],\n", + " batch_size=self.hparams[\"batch_size\"],\n", + "
shuffle=False,\n", + " num_workers=self.hparams[\"num_workers\"],\n", + " )\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(\n", + " self.splits[\"test\"],\n", + " batch_size=self.hparams[\"batch_size\"],\n", + " shuffle=False,\n", + " num_workers=self.hparams[\"num_workers\"],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "3c15f669-04ca-418b-8f22-81993eafbce0", + "metadata": {}, + "source": [ + "## Loss and Lightning module\n", + "\n", + "The model optionally trains with the loss used by NequIP and MACE, the atom-weighted MSE. For now we only use a single target, but this can be expanded to the full set of QM9 targets." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6affd332-fe3a-49dd-b0c2-e20fcb974a04", + "metadata": {}, + "outputs": [], + "source": [ + "class AtomWeightedMSE(nn.Module):\n", + " \"\"\"\n", + " Calculates the mean-squared-error between predictions and targets,\n", + " weighted by the number of atoms within each graph.\n", + "\n", + " From matsciml\n", + " \"\"\"\n", + "\n", + " def forward(\n", + " self,\n", + " input: torch.Tensor,\n", + " target: torch.Tensor,\n", + " atoms_per_graph: torch.Tensor,\n", + " ) -> torch.Tensor:\n", + " if atoms_per_graph.size(0) != target.size(0):\n", + " raise RuntimeError(\n", + " \"Dimensions for atom-weighted loss do not match:\"\n", + " f\" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}.\"\n", + " \" This loss is intended to be applied to scalar targets only.\"\n", + " )\n", + " # check to make sure we are broadcasting correctly\n", + " if (input.ndim != target.ndim) and target.size(-1) == 1:\n", + " input.unsqueeze_(-1)\n", + " # for N-d targets, we might want to keep unsqueezing\n", + " while atoms_per_graph.ndim < target.ndim:\n", + " atoms_per_graph.unsqueeze_(-1)\n", + " # ensures that atoms_per_graph is type cast correctly\n", + " squared_error = ((input - target) / 
atoms_per_graph.to(input.dtype)) ** 2.0\n", + " return squared_error.mean()\n", + "\n", + "\n", + "class EquiTritonLitModule(pl.LightningModule):\n", + " def __init__(\n", + " self,\n", + " model_class: type,\n", + " model_kwargs,\n", + " e_mean: float,\n", + " e_std: float,\n", + " lr: float = 1e-3,\n", + " weight_decay: float = 0.0,\n", + " atom_weighted_loss: bool = True,\n", + " ):\n", + " \"\"\"\n", + " Initializes the EquiTritonLitModule class.\n", + "\n", + " Parameters\n", + " ----------\n", + " model_class : type\n", + " The class of the model to be used.\n", + " model_kwargs : dict\n", + " Keyword arguments for the model initialization.\n", + " e_mean : float\n", + " The mean of the energy values.\n", + " e_std : float\n", + " The standard deviation of the energy values.\n", + " lr : float, optional\n", + " The learning rate (default is 1e-3) for AdamW.\n", + " weight_decay : float, optional\n", + " Weight decay value (default is 0.0).\n", + " atom_weighted_loss : bool, optional\n", + " Whether to use atom-weighted loss or not (default is True).\n", + " \"\"\"\n", + " super().__init__()\n", + " self.model = model_class(**model_kwargs)\n", + " if atom_weighted_loss:\n", + " self.loss = AtomWeightedMSE()\n", + " else:\n", + " self.loss = nn.MSELoss()\n", + " self.output_head = nn.Linear(self.model.output_dim, 1)\n", + " self.save_hyperparameters()\n", + "\n", + " def configure_optimizers(self):\n", + " return AdamW(\n", + " self.parameters(),\n", + " lr=self.hparams[\"lr\"],\n", + " weight_decay=self.hparams[\"weight_decay\"],\n", + " )\n", + "\n", + " def step(self, graph: PyGGraph, stage: Literal[\"train\", \"test\", \"val\"]):\n", + " \"\"\"\n", + " Performs a single step of the training, validation or testing\n", + " process.\n", + "\n", + " Parameters\n", + " ----------\n", + " graph : PyGGraph\n", + " The input graph.\n", + " stage : Literal[\"train\", \"test\", \"val\"]\n", + " The current stage (training, testing or validation).\n", + "\n", + "
Returns\n", + " -------\n", + " loss : float\n", + " The calculated loss value.\n", + " \"\"\"\n", + " g_z, z = self.model(graph)\n", + " pred_energy = self.output_head(g_z)\n", + " target_energy = graph.y[:, 12].unsqueeze(-1)\n", + " norm_energy = (target_energy - self.hparams[\"e_mean\"]) / self.hparams[\"e_std\"]\n", + " if self.hparams[\"atom_weighted_loss\"]:\n", + " loss = self.loss(pred_energy, norm_energy, torch.diff(graph.ptr))\n", + " else:\n", + " loss = self.loss(pred_energy, norm_energy)\n", + " batch_size = getattr(graph, \"batch_size\", 1)\n", + " self.log(\n", + " f\"{stage}_loss\", loss, prog_bar=True, on_step=True, batch_size=batch_size\n", + " )\n", + " return loss\n", + "\n", + " def training_step(self, batch):\n", + " loss = self.step(batch, \"train\")\n", + " return loss\n", + "\n", + " def validation_step(self, batch):\n", + " loss = self.step(batch, \"val\")\n", + " return loss\n", + "\n", + " def test_step(self, batch):\n", + " loss = self.step(batch, \"test\")\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "2d7ce968-f33c-46dd-88e3-8ad47480a9e7", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch_geometric/data/dataset.py:238: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):\n", + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch_geometric/data/dataset.py:246: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):\n", + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch_geometric/io/fs.py:215: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). 
In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " return torch.load(f, map_location)\n" + ] + } + ], + "source": [ + "dm = LightningQM9(\"./qm9_data/\", batch_size=64)\n", + "dm.setup(\"fit\")\n", + "\n", + "train_loader = dm.train_dataloader()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "9c89a710-1e68-485c-9325-0b06c7b52b75", + "metadata": {}, + "outputs": [], + "source": [ + "values = torch.cat([sample.y[:, 12] for sample in dm.dataset])\n", + "e_mean = values.mean()\n", + "e_std = values.std()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1738835c-76e4-4ef9-acf7-df3d5f20d449", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor(-76.1160), tensor(10.3238))" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "e_mean, e_std" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "ef3c9902-3673-4c88-a98e-1a4553a5990f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. 
Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n", + " warnings.warn(\n", + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n", + " warnings.warn(\n", + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n", + " warnings.warn(\n", + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n", + " warnings.warn(\n", + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n", + " warnings.warn(\n", + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. 
Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "lit_mod = EquiTritonLitModule(\n", + " EquiTritonModel,\n", + " model_kwargs={\n", + " \"initial_atom_dim\": 64,\n", + " \"num_layers\": 3,\n", + " \"output_dim\": 48,\n", + " \"l_values\": [0, 1, 2, 5, 6],\n", + " \"edge_dim\": 10,\n", + " \"hidden_dim\": 16,\n", + " \"radius_cutoff\": 6.0,\n", + " \"degree_norm\": 37.5**0.5,\n", + " },\n", + " e_mean=e_mean,\n", + " e_std=e_std,\n", + " atom_weighted_loss=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "1104fe2c-ada6-448b-9e3d-cab1bc491c95", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], + "source": [ + "trainer = pl.Trainer(max_epochs=30, accelerator=\"gpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "e30a0671-efa3-403b-9d8b-cc2cb63bff20", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params\n", + "------------------------------------------------\n", + "0 | model | EquiTritonModel | 630 K \n", + "1 | loss | MSELoss | 0 \n", + "2 | output_head | Linear | 49 \n", + "------------------------------------------------\n", + "630 K Trainable params\n", + "0 Non-trainable params\n", + "630 K Total params\n", + "2.522 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.\n", + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "592377c70fe5446c8474681ceb2ef7c0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...\n" + ] + } + ], + "source": [ + "trainer.fit(lit_mod, datamodule=dm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0663a128-606c-4944-8877-73deafc30305", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/Direct evaluation of spherical harmonics.ipynb b/notebooks/Direct evaluation of spherical 
harmonics.ipynb new file mode 100644 index 0000000..e76694e --- /dev/null +++ b/notebooks/Direct evaluation of spherical harmonics.ipynb @@ -0,0 +1,1515 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "a575e30b-f424-46fb-a5ba-1099c5c112db", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "from functools import cache, partial\n", + "from itertools import combinations, chain\n", + "from pathlib import Path\n", + "\n", + "import sympy\n", + "from sympy import symbols, sqrt, diff, Symbol, acos, simplify\n", + "from sympy.functions.special.spherical_harmonics import Ynm\n", + "from sympy.functions.elementary.complexes import sign\n", + "from sympy.simplify.radsimp import collect_const, collect_sqrt\n", + "from e3nn.o3._spherical_harmonics import _spherical_harmonics\n", + "import torch\n", + "\n", + "x, y, z = symbols(\"x y z\", real=True)\n", + "\n", + "# express in terms of spherical coordinates\n", + "r = symbols(\"r\", nonnegative=True)\n", + "phi, theta = symbols(\"phi theta\")\n", + "\n", + "# conversion mapping\n", + "sph_to_cart = {\n", + " \"theta\": acos(z / sqrt(x**2.0 + y**2.0 + z**2)),\n", + " \"phi\": sign(y) * acos(x / sqrt(x**2 + y**2.0)),\n", + " \"r\": sqrt(x**2.0 + y**2 + z**2),\n", + "}\n", + "\n", + "# this is used to round floats: ~1e-7 corresponds to single precision\n", + "# decimal limits, and ~1e-15 is double precision based on the number\n", + "# of bits in the fraction. More precision = potentially more terms\n", + "# in the expression!\n", + "PRECISION_TOL = 1e-8" + ] + }, + { + "cell_type": "markdown", + "id": "ed9f0d48-ce14-40e5-80b3-865453916815", + "metadata": {}, + "source": [ + "This notebook combines the `e3nn` spherical harmonic implementations with `sympy` to simultaneously refactor each order to be purely functions of $x,y,z$ - rather than recursively (or autoregressively, I guess?) 
through order $l$ - as well as yield pseudo-Python code that can be used to implement the Triton kernel in a relatively straightforward manner. I was unable to get the formatting to behave exactly the way I wanted, so the code in the `equitriton.sph_harm.direct` submodules differs slightly, since redundant literals were pruned to some extent by hand." + ] + }, + { + "cell_type": "markdown", + "id": "4feef347-fa61-4ea3-afce-2c004c12be51", + "metadata": {}, + "source": [ + "## `e3nn` equations\n", + "\n", + "The following equations are copied over from the `e3nn` implementation, and are used as the basis for the symbolic manipulations." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "bbc12b0d-3974-4cdf-b706-6e31eaa9c016", + "metadata": {}, + "outputs": [], + "source": [ + "x, y, z = symbols(\"x y z\")\n", + "\n", + "y_1_0 = math.sqrt(3) * x\n", + "y_1_1 = math.sqrt(3) * y\n", + "y_1_2 = math.sqrt(3) * z\n", + "\n", + "y_2_0 = math.sqrt(15) * x * z\n", + "y_2_1 = math.sqrt(15) * x * y\n", + "y2 = y**2.0\n", + "x2z2 = x**2.0 + z**2.0\n", + "y_2_2 = math.sqrt(5) * (y2 - (1 / 2) * x2z2)\n", + "y_2_3 = math.sqrt(15) * y * z\n", + "y_2_4 = (1 / 2) * math.sqrt(15) * (z**2.0 - x**2.0)\n", + "\n", + "y_3_0 = (1 / 6) * math.sqrt(42) * (y_2_0 * z + y_2_4 * x)\n", + "y_3_1 = math.sqrt(7) * y_2_0 * y\n", + "y_3_2 = (1 / 8) * math.sqrt(168) * (4.0 * y2 - x2z2) * x\n", + "y_3_3 = (1 / 2) * math.sqrt(7) * y * (2.0 * y2 - 3.0 * x2z2)\n", + "y_3_4 = (1 / 8) * math.sqrt(168) * z * (4.0 * y2 - x2z2)\n", + "y_3_5 = math.sqrt(7) * y_2_4 * y\n", + "y_3_6 = (1 / 6) * math.sqrt(42) * (y_2_4 * z - y_2_0 * x)\n", + "\n", + "y_4_0 = (3 / 4) * math.sqrt(2) * (y_3_0 * z + y_3_6 * x)\n", + "y_4_1 = (\n", + " (3 / 4) * y_3_0 * y\n", + " + (3 / 8) * math.sqrt(6) * y_3_1 * z\n", + " + (3 / 8) * math.sqrt(6) * y_3_5 * x\n", + ")\n", + "y_4_2 = (\n", + " -3 / 56 * math.sqrt(14) * y_3_0 * z\n", + " + (3 / 14) * math.sqrt(21) * y_3_1 * y\n", + " + (3 / 56) * 
math.sqrt(210) * y_3_2 * z\n", + " + (3 / 56) * math.sqrt(210) * y_3_4 * x\n", + " + (3 / 56) * math.sqrt(14) * y_3_6 * x\n", + ")\n", + "y_4_3 = (\n", + " -3 / 56 * math.sqrt(42) * y_3_1 * z\n", + " + (3 / 28) * math.sqrt(105) * y_3_2 * y\n", + " + (3 / 28) * math.sqrt(70) * y_3_3 * x\n", + " + (3 / 56) * math.sqrt(42) * y_3_5 * x\n", + ")\n", + "y_4_4 = (\n", + " -3 / 28 * math.sqrt(42) * y_3_2 * x\n", + " + (3 / 7) * math.sqrt(7) * y_3_3 * y\n", + " - 3 / 28 * math.sqrt(42) * y_3_4 * z\n", + ")\n", + "y_4_5 = (\n", + " -3 / 56 * math.sqrt(42) * y_3_1 * x\n", + " + (3 / 28) * math.sqrt(70) * y_3_3 * z\n", + " + (3 / 28) * math.sqrt(105) * y_3_4 * y\n", + " - 3 / 56 * math.sqrt(42) * y_3_5 * z\n", + ")\n", + "y_4_6 = (\n", + " -3 / 56 * math.sqrt(14) * y_3_0 * x\n", + " - 3 / 56 * math.sqrt(210) * y_3_2 * x\n", + " + (3 / 56) * math.sqrt(210) * y_3_4 * z\n", + " + (3 / 14) * math.sqrt(21) * y_3_5 * y\n", + " - 3 / 56 * math.sqrt(14) * y_3_6 * z\n", + ")\n", + "y_4_7 = (\n", + " -3 / 8 * math.sqrt(6) * y_3_1 * x\n", + " + (3 / 8) * math.sqrt(6) * y_3_5 * z\n", + " + (3 / 4) * y_3_6 * y\n", + ")\n", + "y_4_8 = (3 / 4) * math.sqrt(2) * (-y_3_0 * x + y_3_6 * z)\n", + "\n", + "y_5_0 = (1 / 10) * math.sqrt(110) * (y_4_0 * z + y_4_8 * x)\n", + "y_5_1 = (\n", + " (1 / 5) * math.sqrt(11) * y_4_0 * y\n", + " + (1 / 5) * math.sqrt(22) * y_4_1 * z\n", + " + (1 / 5) * math.sqrt(22) * y_4_7 * x\n", + ")\n", + "y_5_2 = (\n", + " -1 / 30 * math.sqrt(22) * y_4_0 * z\n", + " + (4 / 15) * math.sqrt(11) * y_4_1 * y\n", + " + (1 / 15) * math.sqrt(154) * y_4_2 * z\n", + " + (1 / 15) * math.sqrt(154) * y_4_6 * x\n", + " + (1 / 30) * math.sqrt(22) * y_4_8 * x\n", + ")\n", + "y_5_3 = (\n", + " -1 / 30 * math.sqrt(66) * y_4_1 * z\n", + " + (1 / 15) * math.sqrt(231) * y_4_2 * y\n", + " + (1 / 30) * math.sqrt(462) * y_4_3 * z\n", + " + (1 / 30) * math.sqrt(462) * y_4_5 * x\n", + " + (1 / 30) * math.sqrt(66) * y_4_7 * x\n", + ")\n", + "y_5_4 = (\n", + " -1 / 15 * math.sqrt(33) * y_4_2 * z\n", 
+ " + (2 / 15) * math.sqrt(66) * y_4_3 * y\n", + " + (1 / 15) * math.sqrt(165) * y_4_4 * x\n", + " + (1 / 15) * math.sqrt(33) * y_4_6 * x\n", + ")\n", + "y_5_5 = (\n", + " -1 / 15 * math.sqrt(110) * y_4_3 * x\n", + " + (1 / 3) * math.sqrt(11) * y_4_4 * y\n", + " - 1 / 15 * math.sqrt(110) * y_4_5 * z\n", + ")\n", + "y_5_6 = (\n", + " -1 / 15 * math.sqrt(33) * y_4_2 * x\n", + " + (1 / 15) * math.sqrt(165) * y_4_4 * z\n", + " + (2 / 15) * math.sqrt(66) * y_4_5 * y\n", + " - 1 / 15 * math.sqrt(33) * y_4_6 * z\n", + ")\n", + "y_5_7 = (\n", + " -1 / 30 * math.sqrt(66) * y_4_1 * x\n", + " - 1 / 30 * math.sqrt(462) * y_4_3 * x\n", + " + (1 / 30) * math.sqrt(462) * y_4_5 * z\n", + " + (1 / 15) * math.sqrt(231) * y_4_6 * y\n", + " - 1 / 30 * math.sqrt(66) * y_4_7 * z\n", + ")\n", + "y_5_8 = (\n", + " -1 / 30 * math.sqrt(22) * y_4_0 * x\n", + " - 1 / 15 * math.sqrt(154) * y_4_2 * x\n", + " + (1 / 15) * math.sqrt(154) * y_4_6 * z\n", + " + (4 / 15) * math.sqrt(11) * y_4_7 * y\n", + " - 1 / 30 * math.sqrt(22) * y_4_8 * z\n", + ")\n", + "y_5_9 = (\n", + " -1 / 5 * math.sqrt(22) * y_4_1 * x\n", + " + (1 / 5) * math.sqrt(22) * y_4_7 * z\n", + " + (1 / 5) * math.sqrt(11) * y_4_8 * y\n", + ")\n", + "y_5_10 = (1 / 10) * math.sqrt(110) * (-y_4_0 * x + y_4_8 * z)\n", + "\n", + "y_6_0 = (1 / 6) * math.sqrt(39) * (y_5_0 * z + y_5_10 * x)\n", + "y_6_1 = (\n", + " (1 / 6) * math.sqrt(13) * y_5_0 * y\n", + " + (1 / 12) * math.sqrt(130) * y_5_1 * z\n", + " + (1 / 12) * math.sqrt(130) * y_5_9 * x\n", + ")\n", + "y_6_2 = (\n", + " -1 / 132 * math.sqrt(286) * y_5_0 * z\n", + " + (1 / 33) * math.sqrt(715) * y_5_1 * y\n", + " + (1 / 132) * math.sqrt(286) * y_5_10 * x\n", + " + (1 / 44) * math.sqrt(1430) * y_5_2 * z\n", + " + (1 / 44) * math.sqrt(1430) * y_5_8 * x\n", + ")\n", + "y_6_3 = (\n", + " -1 / 132 * math.sqrt(858) * y_5_1 * z\n", + " + (1 / 22) * math.sqrt(429) * y_5_2 * y\n", + " + (1 / 22) * math.sqrt(286) * y_5_3 * z\n", + " + (1 / 22) * math.sqrt(286) * y_5_7 * x\n", + " + (1 / 132) * 
math.sqrt(858) * y_5_9 * x\n", + ")\n", + "y_6_4 = (\n", + " -1 / 66 * math.sqrt(429) * y_5_2 * z\n", + " + (2 / 33) * math.sqrt(286) * y_5_3 * y\n", + " + (1 / 66) * math.sqrt(2002) * y_5_4 * z\n", + " + (1 / 66) * math.sqrt(2002) * y_5_6 * x\n", + " + (1 / 66) * math.sqrt(429) * y_5_8 * x\n", + ")\n", + "y_6_5 = (\n", + " -1 / 66 * math.sqrt(715) * y_5_3 * z\n", + " + (1 / 66) * math.sqrt(5005) * y_5_4 * y\n", + " + (1 / 66) * math.sqrt(3003) * y_5_5 * x\n", + " + (1 / 66) * math.sqrt(715) * y_5_7 * x\n", + ")\n", + "y_6_6 = (\n", + " -1 / 66 * math.sqrt(2145) * y_5_4 * x\n", + " + (1 / 11) * math.sqrt(143) * y_5_5 * y\n", + " - 1 / 66 * math.sqrt(2145) * y_5_6 * z\n", + ")\n", + "y_6_7 = (\n", + " -1 / 66 * math.sqrt(715) * y_5_3 * x\n", + " + (1 / 66) * math.sqrt(3003) * y_5_5 * z\n", + " + (1 / 66) * math.sqrt(5005) * y_5_6 * y\n", + " - 1 / 66 * math.sqrt(715) * y_5_7 * z\n", + ")\n", + "y_6_8 = (\n", + " -1 / 66 * math.sqrt(429) * y_5_2 * x\n", + " - 1 / 66 * math.sqrt(2002) * y_5_4 * x\n", + " + (1 / 66) * math.sqrt(2002) * y_5_6 * z\n", + " + (2 / 33) * math.sqrt(286) * y_5_7 * y\n", + " - 1 / 66 * math.sqrt(429) * y_5_8 * z\n", + ")\n", + "y_6_9 = (\n", + " -1 / 132 * math.sqrt(858) * y_5_1 * x\n", + " - 1 / 22 * math.sqrt(286) * y_5_3 * x\n", + " + (1 / 22) * math.sqrt(286) * y_5_7 * z\n", + " + (1 / 22) * math.sqrt(429) * y_5_8 * y\n", + " - 1 / 132 * math.sqrt(858) * y_5_9 * z\n", + ")\n", + "y_6_10 = (\n", + " -1 / 132 * math.sqrt(286) * y_5_0 * x\n", + " - 1 / 132 * math.sqrt(286) * y_5_10 * z\n", + " - 1 / 44 * math.sqrt(1430) * y_5_2 * x\n", + " + (1 / 44) * math.sqrt(1430) * y_5_8 * z\n", + " + (1 / 33) * math.sqrt(715) * y_5_9 * y\n", + ")\n", + "y_6_11 = (\n", + " -1 / 12 * math.sqrt(130) * y_5_1 * x\n", + " + (1 / 6) * math.sqrt(13) * y_5_10 * y\n", + " + (1 / 12) * math.sqrt(130) * y_5_9 * z\n", + ")\n", + "y_6_12 = (1 / 6) * math.sqrt(39) * (-y_5_0 * x + y_5_10 * z)\n", + "\n", + "y_7_0 = (1 / 14) * math.sqrt(210) * (y_6_0 * z + y_6_12 * 
x)\n", + "y_7_1 = (\n", + " (1 / 7) * math.sqrt(15) * y_6_0 * y\n", + " + (3 / 7) * math.sqrt(5) * y_6_1 * z\n", + " + (3 / 7) * math.sqrt(5) * y_6_11 * x\n", + ")\n", + "y_7_2 = (\n", + " -1 / 182 * math.sqrt(390) * y_6_0 * z\n", + " + (6 / 91) * math.sqrt(130) * y_6_1 * y\n", + " + (3 / 91) * math.sqrt(715) * y_6_10 * x\n", + " + (1 / 182) * math.sqrt(390) * y_6_12 * x\n", + " + (3 / 91) * math.sqrt(715) * y_6_2 * z\n", + ")\n", + "y_7_3 = (\n", + " -3 / 182 * math.sqrt(130) * y_6_1 * z\n", + " + (3 / 182) * math.sqrt(130) * y_6_11 * x\n", + " + (3 / 91) * math.sqrt(715) * y_6_2 * y\n", + " + (5 / 182) * math.sqrt(858) * y_6_3 * z\n", + " + (5 / 182) * math.sqrt(858) * y_6_9 * x\n", + ")\n", + "y_7_4 = (\n", + " (3 / 91) * math.sqrt(65) * y_6_10 * x\n", + " - 3 / 91 * math.sqrt(65) * y_6_2 * z\n", + " + (10 / 91) * math.sqrt(78) * y_6_3 * y\n", + " + (15 / 182) * math.sqrt(78) * y_6_4 * z\n", + " + (15 / 182) * math.sqrt(78) * y_6_8 * x\n", + ")\n", + "y_7_5 = (\n", + " -5 / 91 * math.sqrt(39) * y_6_3 * z\n", + " + (15 / 91) * math.sqrt(39) * y_6_4 * y\n", + " + (3 / 91) * math.sqrt(390) * y_6_5 * z\n", + " + (3 / 91) * math.sqrt(390) * y_6_7 * x\n", + " + (5 / 91) * math.sqrt(39) * y_6_9 * x\n", + ")\n", + "y_7_6 = (\n", + " -15 / 182 * math.sqrt(26) * y_6_4 * z\n", + " + (12 / 91) * math.sqrt(65) * y_6_5 * y\n", + " + (2 / 91) * math.sqrt(1365) * y_6_6 * x\n", + " + (15 / 182) * math.sqrt(26) * y_6_8 * x\n", + ")\n", + "y_7_7 = (\n", + " -3 / 91 * math.sqrt(455) * y_6_5 * x\n", + " + (1 / 13) * math.sqrt(195) * y_6_6 * y\n", + " - 3 / 91 * math.sqrt(455) * y_6_7 * z\n", + ")\n", + "y_7_8 = (\n", + " -15 / 182 * math.sqrt(26) * y_6_4 * x\n", + " + (2 / 91) * math.sqrt(1365) * y_6_6 * z\n", + " + (12 / 91) * math.sqrt(65) * y_6_7 * y\n", + " - 15 / 182 * math.sqrt(26) * y_6_8 * z\n", + ")\n", + "y_7_9 = (\n", + " -5 / 91 * math.sqrt(39) * y_6_3 * x\n", + " - 3 / 91 * math.sqrt(390) * y_6_5 * x\n", + " + (3 / 91) * math.sqrt(390) * y_6_7 * z\n", + " + (15 / 91) * 
math.sqrt(39) * y_6_8 * y\n", + " - 5 / 91 * math.sqrt(39) * y_6_9 * z\n", + ")\n", + "y_7_10 = (\n", + " -3 / 91 * math.sqrt(65) * y_6_10 * z\n", + " - 3 / 91 * math.sqrt(65) * y_6_2 * x\n", + " - 15 / 182 * math.sqrt(78) * y_6_4 * x\n", + " + (15 / 182) * math.sqrt(78) * y_6_8 * z\n", + " + (10 / 91) * math.sqrt(78) * y_6_9 * y\n", + ")\n", + "y_7_11 = (\n", + " -3 / 182 * math.sqrt(130) * y_6_1 * x\n", + " + (3 / 91) * math.sqrt(715) * y_6_10 * y\n", + " - 3 / 182 * math.sqrt(130) * y_6_11 * z\n", + " - 5 / 182 * math.sqrt(858) * y_6_3 * x\n", + " + (5 / 182) * math.sqrt(858) * y_6_9 * z\n", + ")\n", + "y_7_12 = (\n", + " -1 / 182 * math.sqrt(390) * y_6_0 * x\n", + " + (3 / 91) * math.sqrt(715) * y_6_10 * z\n", + " + (6 / 91) * math.sqrt(130) * y_6_11 * y\n", + " - 1 / 182 * math.sqrt(390) * y_6_12 * z\n", + " - 3 / 91 * math.sqrt(715) * y_6_2 * x\n", + ")\n", + "y_7_13 = (\n", + " -3 / 7 * math.sqrt(5) * y_6_1 * x\n", + " + (3 / 7) * math.sqrt(5) * y_6_11 * z\n", + " + (1 / 7) * math.sqrt(15) * y_6_12 * y\n", + ")\n", + "y_7_14 = (1 / 14) * math.sqrt(210) * (-y_6_0 * x + y_6_12 * z)\n", + "\n", + "y_8_0 = (1 / 4) * math.sqrt(17) * (y_7_0 * z + y_7_14 * x)\n", + "y_8_1 = (\n", + " (1 / 8) * math.sqrt(17) * y_7_0 * y\n", + " + (1 / 16) * math.sqrt(238) * y_7_1 * z\n", + " + (1 / 16) * math.sqrt(238) * y_7_13 * x\n", + ")\n", + "y_8_2 = (\n", + " -1 / 240 * math.sqrt(510) * y_7_0 * z\n", + " + (1 / 60) * math.sqrt(1785) * y_7_1 * y\n", + " + (1 / 240) * math.sqrt(46410) * y_7_12 * x\n", + " + (1 / 240) * math.sqrt(510) * y_7_14 * x\n", + " + (1 / 240) * math.sqrt(46410) * y_7_2 * z\n", + ")\n", + "y_8_3 = (\n", + " (1 / 80)\n", + " * math.sqrt(2)\n", + " * (\n", + " -math.sqrt(85) * y_7_1 * z\n", + " + math.sqrt(2210) * y_7_11 * x\n", + " + math.sqrt(85) * y_7_13 * x\n", + " + math.sqrt(2210) * y_7_2 * y\n", + " + math.sqrt(2210) * y_7_3 * z\n", + " )\n", + ")\n", + "y_8_4 = (\n", + " (1 / 40) * math.sqrt(935) * y_7_10 * x\n", + " + (1 / 40) * math.sqrt(85) * 
y_7_12 * x\n", + " - 1 / 40 * math.sqrt(85) * y_7_2 * z\n", + " + (1 / 10) * math.sqrt(85) * y_7_3 * y\n", + " + (1 / 40) * math.sqrt(935) * y_7_4 * z\n", + ")\n", + "y_8_5 = (\n", + " (1 / 48)\n", + " * math.sqrt(2)\n", + " * (\n", + " math.sqrt(102) * y_7_11 * x\n", + " - math.sqrt(102) * y_7_3 * z\n", + " + math.sqrt(1122) * y_7_4 * y\n", + " + math.sqrt(561) * y_7_5 * z\n", + " + math.sqrt(561) * y_7_9 * x\n", + " )\n", + ")\n", + "y_8_6 = (\n", + " (1 / 16) * math.sqrt(34) * y_7_10 * x\n", + " - 1 / 16 * math.sqrt(34) * y_7_4 * z\n", + " + (1 / 4) * math.sqrt(17) * y_7_5 * y\n", + " + (1 / 16) * math.sqrt(102) * y_7_6 * z\n", + " + (1 / 16) * math.sqrt(102) * y_7_8 * x\n", + ")\n", + "y_8_7 = (\n", + " -1 / 80 * math.sqrt(1190) * y_7_5 * z\n", + " + (1 / 40) * math.sqrt(1785) * y_7_6 * y\n", + " + (1 / 20) * math.sqrt(255) * y_7_7 * x\n", + " + (1 / 80) * math.sqrt(1190) * y_7_9 * x\n", + ")\n", + "y_8_8 = (\n", + " -1 / 60 * math.sqrt(1785) * y_7_6 * x\n", + " + (1 / 15) * math.sqrt(255) * y_7_7 * y\n", + " - 1 / 60 * math.sqrt(1785) * y_7_8 * z\n", + ")\n", + "y_8_9 = (\n", + " -1 / 80 * math.sqrt(1190) * y_7_5 * x\n", + " + (1 / 20) * math.sqrt(255) * y_7_7 * z\n", + " + (1 / 40) * math.sqrt(1785) * y_7_8 * y\n", + " - 1 / 80 * math.sqrt(1190) * y_7_9 * z\n", + ")\n", + "y_8_10 = (\n", + " -1 / 16 * math.sqrt(34) * y_7_10 * z\n", + " - 1 / 16 * math.sqrt(34) * y_7_4 * x\n", + " - 1 / 16 * math.sqrt(102) * y_7_6 * x\n", + " + (1 / 16) * math.sqrt(102) * y_7_8 * z\n", + " + (1 / 4) * math.sqrt(17) * y_7_9 * y\n", + ")\n", + "y_8_11 = (\n", + " (1 / 48)\n", + " * math.sqrt(2)\n", + " * (\n", + " math.sqrt(1122) * y_7_10 * y\n", + " - math.sqrt(102) * y_7_11 * z\n", + " - math.sqrt(102) * y_7_3 * x\n", + " - math.sqrt(561) * y_7_5 * x\n", + " + math.sqrt(561) * y_7_9 * z\n", + " )\n", + ")\n", + "y_8_12 = (\n", + " (1 / 40) * math.sqrt(935) * y_7_10 * z\n", + " + (1 / 10) * math.sqrt(85) * y_7_11 * y\n", + " - 1 / 40 * math.sqrt(85) * y_7_12 * z\n", + " - 1 / 
40 * math.sqrt(85) * y_7_2 * x\n", + " - 1 / 40 * math.sqrt(935) * y_7_4 * x\n", + ")\n", + "y_8_13 = (\n", + " (1 / 80)\n", + " * math.sqrt(2)\n", + " * (\n", + " -math.sqrt(85) * y_7_1 * x\n", + " + math.sqrt(2210) * y_7_11 * z\n", + " + math.sqrt(2210) * y_7_12 * y\n", + " - math.sqrt(85) * y_7_13 * z\n", + " - math.sqrt(2210) * y_7_3 * x\n", + " )\n", + ")\n", + "y_8_14 = (\n", + " -1 / 240 * math.sqrt(510) * y_7_0 * x\n", + " + (1 / 240) * math.sqrt(46410) * y_7_12 * z\n", + " + (1 / 60) * math.sqrt(1785) * y_7_13 * y\n", + " - 1 / 240 * math.sqrt(510) * y_7_14 * z\n", + " - 1 / 240 * math.sqrt(46410) * y_7_2 * x\n", + ")\n", + "y_8_15 = (\n", + " -1 / 16 * math.sqrt(238) * y_7_1 * x\n", + " + (1 / 16) * math.sqrt(238) * y_7_13 * z\n", + " + (1 / 8) * math.sqrt(17) * y_7_14 * y\n", + ")\n", + "y_8_16 = (1 / 4) * math.sqrt(17) * (-y_7_0 * x + y_7_14 * z)\n", + "\n", + "y_9_0 = (1 / 6) * math.sqrt(38) * (y_8_0 * z + y_8_16 * x)\n", + "y_9_1 = (1 / 9) * math.sqrt(19) * (y_8_0 * y + 2 * y_8_1 * z + 2 * y_8_15 * x)\n", + "y_9_2 = (\n", + " -1 / 306 * math.sqrt(646) * y_8_0 * z\n", + " + (4 / 153) * math.sqrt(646) * y_8_1 * y\n", + " + (2 / 153) * math.sqrt(4845) * y_8_14 * x\n", + " + (1 / 306) * math.sqrt(646) * y_8_16 * x\n", + " + (2 / 153) * math.sqrt(4845) * y_8_2 * z\n", + ")\n", + "y_9_3 = (\n", + " -1 / 306 * math.sqrt(1938) * y_8_1 * z\n", + " + (1 / 306) * math.sqrt(67830) * y_8_13 * x\n", + " + (1 / 306) * math.sqrt(1938) * y_8_15 * x\n", + " + (1 / 51) * math.sqrt(1615) * y_8_2 * y\n", + " + (1 / 306) * math.sqrt(67830) * y_8_3 * z\n", + ")\n", + "y_9_4 = (\n", + " (1 / 306) * math.sqrt(58786) * y_8_12 * x\n", + " + (1 / 153) * math.sqrt(969) * y_8_14 * x\n", + " - 1 / 153 * math.sqrt(969) * y_8_2 * z\n", + " + (2 / 153) * math.sqrt(4522) * y_8_3 * y\n", + " + (1 / 306) * math.sqrt(58786) * y_8_4 * z\n", + ")\n", + "y_9_5 = (\n", + " (1 / 153) * math.sqrt(12597) * y_8_11 * x\n", + " + (1 / 153) * math.sqrt(1615) * y_8_13 * x\n", + " - 1 / 153 * 
math.sqrt(1615) * y_8_3 * z\n", + " + (1 / 153) * math.sqrt(20995) * y_8_4 * y\n", + " + (1 / 153) * math.sqrt(12597) * y_8_5 * z\n", + ")\n", + "y_9_6 = (\n", + " (1 / 153) * math.sqrt(10659) * y_8_10 * x\n", + " + (1 / 306) * math.sqrt(9690) * y_8_12 * x\n", + " - 1 / 306 * math.sqrt(9690) * y_8_4 * z\n", + " + (2 / 51) * math.sqrt(646) * y_8_5 * y\n", + " + (1 / 153) * math.sqrt(10659) * y_8_6 * z\n", + ")\n", + "y_9_7 = (\n", + " (1 / 306) * math.sqrt(13566) * y_8_11 * x\n", + " - 1 / 306 * math.sqrt(13566) * y_8_5 * z\n", + " + (1 / 153) * math.sqrt(24871) * y_8_6 * y\n", + " + (1 / 306) * math.sqrt(35530) * y_8_7 * z\n", + " + (1 / 306) * math.sqrt(35530) * y_8_9 * x\n", + ")\n", + "y_9_8 = (\n", + " (1 / 153) * math.sqrt(4522) * y_8_10 * x\n", + " - 1 / 153 * math.sqrt(4522) * y_8_6 * z\n", + " + (4 / 153) * math.sqrt(1615) * y_8_7 * y\n", + " + (1 / 51) * math.sqrt(1615) * y_8_8 * x\n", + ")\n", + "y_9_9 = (1 / 51) * math.sqrt(323) * (-2 * y_8_7 * x + 3 * y_8_8 * y - 2 * y_8_9 * z)\n", + "y_9_10 = (\n", + " -1 / 153 * math.sqrt(4522) * y_8_10 * z\n", + " - 1 / 153 * math.sqrt(4522) * y_8_6 * x\n", + " + (1 / 51) * math.sqrt(1615) * y_8_8 * z\n", + " + (4 / 153) * math.sqrt(1615) * y_8_9 * y\n", + ")\n", + "y_9_11 = (\n", + " (1 / 153) * math.sqrt(24871) * y_8_10 * y\n", + " - 1 / 306 * math.sqrt(13566) * y_8_11 * z\n", + " - 1 / 306 * math.sqrt(13566) * y_8_5 * x\n", + " - 1 / 306 * math.sqrt(35530) * y_8_7 * x\n", + " + (1 / 306) * math.sqrt(35530) * y_8_9 * z\n", + ")\n", + "y_9_12 = (\n", + " (1 / 153) * math.sqrt(10659) * y_8_10 * z\n", + " + (2 / 51) * math.sqrt(646) * y_8_11 * y\n", + " - 1 / 306 * math.sqrt(9690) * y_8_12 * z\n", + " - 1 / 306 * math.sqrt(9690) * y_8_4 * x\n", + " - 1 / 153 * math.sqrt(10659) * y_8_6 * x\n", + ")\n", + "y_9_13 = (\n", + " (1 / 153) * math.sqrt(12597) * y_8_11 * z\n", + " + (1 / 153) * math.sqrt(20995) * y_8_12 * y\n", + " - 1 / 153 * math.sqrt(1615) * y_8_13 * z\n", + " - 1 / 153 * math.sqrt(1615) * y_8_3 * x\n", + " 
- 1 / 153 * math.sqrt(12597) * y_8_5 * x\n", + ")\n", + "y_9_14 = (\n", + " (1 / 306) * math.sqrt(58786) * y_8_12 * z\n", + " + (2 / 153) * math.sqrt(4522) * y_8_13 * y\n", + " - 1 / 153 * math.sqrt(969) * y_8_14 * z\n", + " - 1 / 153 * math.sqrt(969) * y_8_2 * x\n", + " - 1 / 306 * math.sqrt(58786) * y_8_4 * x\n", + ")\n", + "y_9_15 = (\n", + " -1 / 306 * math.sqrt(1938) * y_8_1 * x\n", + " + (1 / 306) * math.sqrt(67830) * y_8_13 * z\n", + " + (1 / 51) * math.sqrt(1615) * y_8_14 * y\n", + " - 1 / 306 * math.sqrt(1938) * y_8_15 * z\n", + " - 1 / 306 * math.sqrt(67830) * y_8_3 * x\n", + ")\n", + "y_9_16 = (\n", + " -1 / 306 * math.sqrt(646) * y_8_0 * x\n", + " + (2 / 153) * math.sqrt(4845) * y_8_14 * z\n", + " + (4 / 153) * math.sqrt(646) * y_8_15 * y\n", + " - 1 / 306 * math.sqrt(646) * y_8_16 * z\n", + " - 2 / 153 * math.sqrt(4845) * y_8_2 * x\n", + ")\n", + "y_9_17 = (1 / 9) * math.sqrt(19) * (-2 * y_8_1 * x + 2 * y_8_15 * z + y_8_16 * y)\n", + "y_9_18 = (1 / 6) * math.sqrt(38) * (-y_8_0 * x + y_8_16 * z)\n", + "\n", + "y_10_0 = (1 / 10) * math.sqrt(105) * (y_9_0 * z + y_9_18 * x)\n", + "y_10_1 = (\n", + " (1 / 10) * math.sqrt(21) * y_9_0 * y\n", + " + (3 / 20) * math.sqrt(42) * y_9_1 * z\n", + " + (3 / 20) * math.sqrt(42) * y_9_17 * x\n", + ")\n", + "y_10_2 = (\n", + " -1 / 380 * math.sqrt(798) * y_9_0 * z\n", + " + (3 / 95) * math.sqrt(399) * y_9_1 * y\n", + " + (3 / 380) * math.sqrt(13566) * y_9_16 * x\n", + " + (1 / 380) * math.sqrt(798) * y_9_18 * x\n", + " + (3 / 380) * math.sqrt(13566) * y_9_2 * z\n", + ")\n", + "y_10_3 = (\n", + " -3 / 380 * math.sqrt(266) * y_9_1 * z\n", + " + (1 / 95) * math.sqrt(6783) * y_9_15 * x\n", + " + (3 / 380) * math.sqrt(266) * y_9_17 * x\n", + " + (3 / 190) * math.sqrt(2261) * y_9_2 * y\n", + " + (1 / 95) * math.sqrt(6783) * y_9_3 * z\n", + ")\n", + "y_10_4 = (\n", + " (3 / 95) * math.sqrt(665) * y_9_14 * x\n", + " + (3 / 190) * math.sqrt(133) * y_9_16 * x\n", + " - 3 / 190 * math.sqrt(133) * y_9_2 * z\n", + " + (4 / 95) * 
math.sqrt(399) * y_9_3 * y\n", + " + (3 / 95) * math.sqrt(665) * y_9_4 * z\n", + ")\n", + "y_10_5 = (\n", + " (21 / 380) * math.sqrt(190) * y_9_13 * x\n", + " + (1 / 190) * math.sqrt(1995) * y_9_15 * x\n", + " - 1 / 190 * math.sqrt(1995) * y_9_3 * z\n", + " + (3 / 38) * math.sqrt(133) * y_9_4 * y\n", + " + (21 / 380) * math.sqrt(190) * y_9_5 * z\n", + ")\n", + "y_10_6 = (\n", + " (7 / 380) * math.sqrt(1482) * y_9_12 * x\n", + " + (3 / 380) * math.sqrt(1330) * y_9_14 * x\n", + " - 3 / 380 * math.sqrt(1330) * y_9_4 * z\n", + " + (21 / 95) * math.sqrt(19) * y_9_5 * y\n", + " + (7 / 380) * math.sqrt(1482) * y_9_6 * z\n", + ")\n", + "y_10_7 = (\n", + " (3 / 190) * math.sqrt(1729) * y_9_11 * x\n", + " + (21 / 380) * math.sqrt(38) * y_9_13 * x\n", + " - 21 / 380 * math.sqrt(38) * y_9_5 * z\n", + " + (7 / 190) * math.sqrt(741) * y_9_6 * y\n", + " + (3 / 190) * math.sqrt(1729) * y_9_7 * z\n", + ")\n", + "y_10_8 = (\n", + " (3 / 190) * math.sqrt(1463) * y_9_10 * x\n", + " + (7 / 190) * math.sqrt(114) * y_9_12 * x\n", + " - 7 / 190 * math.sqrt(114) * y_9_6 * z\n", + " + (6 / 95) * math.sqrt(266) * y_9_7 * y\n", + " + (3 / 190) * math.sqrt(1463) * y_9_8 * z\n", + ")\n", + "y_10_9 = (\n", + " (3 / 190) * math.sqrt(798) * y_9_11 * x\n", + " - 3 / 190 * math.sqrt(798) * y_9_7 * z\n", + " + (3 / 190) * math.sqrt(4389) * y_9_8 * y\n", + " + (1 / 190) * math.sqrt(21945) * y_9_9 * x\n", + ")\n", + "y_10_10 = (\n", + " -3 / 190 * math.sqrt(1995) * y_9_10 * z\n", + " - 3 / 190 * math.sqrt(1995) * y_9_8 * x\n", + " + (1 / 19) * math.sqrt(399) * y_9_9 * y\n", + ")\n", + "y_10_11 = (\n", + " (3 / 190) * math.sqrt(4389) * y_9_10 * y\n", + " - 3 / 190 * math.sqrt(798) * y_9_11 * z\n", + " - 3 / 190 * math.sqrt(798) * y_9_7 * x\n", + " + (1 / 190) * math.sqrt(21945) * y_9_9 * z\n", + ")\n", + "y_10_12 = (\n", + " (3 / 190) * math.sqrt(1463) * y_9_10 * z\n", + " + (6 / 95) * math.sqrt(266) * y_9_11 * y\n", + " - 7 / 190 * math.sqrt(114) * y_9_12 * z\n", + " - 7 / 190 * math.sqrt(114) * y_9_6 
* x\n", + " - 3 / 190 * math.sqrt(1463) * y_9_8 * x\n", + ")\n", + "y_10_13 = (\n", + " (3 / 190) * math.sqrt(1729) * y_9_11 * z\n", + " + (7 / 190) * math.sqrt(741) * y_9_12 * y\n", + " - 21 / 380 * math.sqrt(38) * y_9_13 * z\n", + " - 21 / 380 * math.sqrt(38) * y_9_5 * x\n", + " - 3 / 190 * math.sqrt(1729) * y_9_7 * x\n", + ")\n", + "y_10_14 = (\n", + " (7 / 380) * math.sqrt(1482) * y_9_12 * z\n", + " + (21 / 95) * math.sqrt(19) * y_9_13 * y\n", + " - 3 / 380 * math.sqrt(1330) * y_9_14 * z\n", + " - 3 / 380 * math.sqrt(1330) * y_9_4 * x\n", + " - 7 / 380 * math.sqrt(1482) * y_9_6 * x\n", + ")\n", + "y_10_15 = (\n", + " (21 / 380) * math.sqrt(190) * y_9_13 * z\n", + " + (3 / 38) * math.sqrt(133) * y_9_14 * y\n", + " - 1 / 190 * math.sqrt(1995) * y_9_15 * z\n", + " - 1 / 190 * math.sqrt(1995) * y_9_3 * x\n", + " - 21 / 380 * math.sqrt(190) * y_9_5 * x\n", + ")\n", + "y_10_16 = (\n", + " (3 / 95) * math.sqrt(665) * y_9_14 * z\n", + " + (4 / 95) * math.sqrt(399) * y_9_15 * y\n", + " - 3 / 190 * math.sqrt(133) * y_9_16 * z\n", + " - 3 / 190 * math.sqrt(133) * y_9_2 * x\n", + " - 3 / 95 * math.sqrt(665) * y_9_4 * x\n", + ")\n", + "y_10_17 = (\n", + " -3 / 380 * math.sqrt(266) * y_9_1 * x\n", + " + (1 / 95) * math.sqrt(6783) * y_9_15 * z\n", + " + (3 / 190) * math.sqrt(2261) * y_9_16 * y\n", + " - 3 / 380 * math.sqrt(266) * y_9_17 * z\n", + " - 1 / 95 * math.sqrt(6783) * y_9_3 * x\n", + ")\n", + "y_10_18 = (\n", + " -1 / 380 * math.sqrt(798) * y_9_0 * x\n", + " + (3 / 380) * math.sqrt(13566) * y_9_16 * z\n", + " + (3 / 95) * math.sqrt(399) * y_9_17 * y\n", + " - 1 / 380 * math.sqrt(798) * y_9_18 * z\n", + " - 3 / 380 * math.sqrt(13566) * y_9_2 * x\n", + ")\n", + "y_10_19 = (\n", + " -3 / 20 * math.sqrt(42) * y_9_1 * x\n", + " + (3 / 20) * math.sqrt(42) * y_9_17 * z\n", + " + (1 / 10) * math.sqrt(21) * y_9_18 * y\n", + ")\n", + "y_10_20 = (1 / 10) * math.sqrt(105) * (-y_9_0 * x + y_9_18 * z)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": 
"fae57ecb-2ef5-43f5-a19b-fedb24d9f491", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "See the first cell for details about `PRECISION_TOL`: this impacts the\n", + "number of terms that will appear in the final implementation. More terms\n", + "make it more precise, at the cost of being more difficult to maintain\n", + "and, from a performance perspective, may eventually exceed the number of\n", + "registers on a given platform!\n", + "\"\"\"\n", + "nsimplify = partial(sympy.nsimplify, tolerance=PRECISION_TOL, rational=True)\n", + "\n", + "\n", + "def replace_floating_integers(expr):\n", + " \"\"\"Dumb, but straightforward way to replace floats that should be integers.\"\"\"\n", + " replace_dict = {sympy.Float(value): sympy.Integer(value) for value in range(1, 15)}\n", + " return expr.subs(replace_dict)\n", + "\n", + "\n", + "def optimization_chain(expr):\n", + " \"\"\"Sequentially run a sequence of optimization steps on an expression\"\"\"\n", + " opt_chain = [collect_sqrt, collect_const, nsimplify]\n", + " for func in opt_chain:\n", + " expr = func(expr)\n", + " return expr.n()\n", + "\n", + "\n", + "def make_eval_expr(expr, **kwargs):\n", + " \"\"\"\n", + " For a given expression, prepare it for implementation by\n", + " running it through a gamut of manipulations/optimizations,\n", + " particularly in terms of collecting like variables and replacing\n", + " floating point symbols with integer ones if they are numerically\n", + " equivalent (i.e. 
2.0 -> 2).\n", + "\n", + " The main computationally intensive part is a loop over various\n", + " possible combinations of xyz collections: some may lead to\n", + " fewer arithmetic operations, and the loop works to find the\n", + " combination with the fewest.\n", + " \"\"\"\n", + " new_expr = expr.expand(func=True).simplify(**kwargs)\n", + " # first thing to do is to replace floats that are actually\n", + " # integers\n", + " new_expr = replace_floating_integers(new_expr)\n", + " # collect terms differently, minimizing number of operations\n", + " combos = chain(\n", + " combinations([x, y, z], 1),\n", + " combinations([x, y, z], 2),\n", + " combinations([x, y, z], 3),\n", + " )\n", + " best_solution = new_expr\n", + " best_num_ops = sympy.count_ops(best_solution)\n", + " for combo in combos:\n", + " # runs a sequence of chained collections to try and\n", + " # minimize the number of operations\n", + " temp = optimization_chain(sympy.collect(new_expr, combo))\n", + " # count the number of computational operations we perform\n", + " counts = sympy.count_ops(temp)\n", + " # if we end up with fewer ops, go for it\n", + " if counts < best_num_ops:\n", + " best_solution = temp\n", + " best_num_ops = counts\n", + " return best_solution\n", + "\n", + "\n", + "def take_derivative(expr, symbols: list[Symbol], optimize: bool = True):\n", + " \"\"\"\n", + " Function to take the derivative of a symbolic equation with respect\n", + " to a list of symbols.\n", + "\n", + " We loop through each symbol, and if it is used in the equation,\n", + " we take the first derivative with respect to that symbol.\n", + " \"\"\"\n", + " return_dict = {}\n", + " for symbol in symbols:\n", + " if symbol in expr.free_symbols:\n", + " deriv = diff(expr, symbol)\n", + " if optimize:\n", + " deriv = optimization_chain(deriv)\n", + " return_dict[str(symbol)] = deriv\n", + " if len(return_dict) == 0:\n", + " raise RuntimeError(\"None of the requested symbols were used in the expression!\")\n", + " 
return return_dict\n", + "\n", + "\n", + "@cache\n", + "def driver(*expressions):\n", + " \"\"\"\n", + " Creates the expected data structure, generates the optimized forward\n", + " pass code, computes derivatives, and aggregates them to be in\n", + " terms of xyz for ease of implementation.\n", + " \"\"\"\n", + " outputs = {\"fwd\": [], \"bwd\": {\"x\": 0, \"y\": 0, \"z\": 0}}\n", + " for index, expr in enumerate(expressions):\n", + " fwd = make_eval_expr(expr, ratio=1.5)\n", + " outputs[\"fwd\"].append(fwd)\n", + " bwd = take_derivative(fwd, [x, y, z], optimize=True)\n", + " grad_term = symbols(f\"g_{index}\")\n", + " # collect up all the terms for the backward pass w.r.t.\n", + " # axes, with the corresponding gradient term in the forward\n", + " for key in [\"x\", \"y\", \"z\"]:\n", + " if key in bwd:\n", + " outputs[\"bwd\"][key] += bwd[key] * grad_term\n", + " # now do a final optimization on the derivatives\n", + " for key in [\"x\", \"y\", \"z\"]:\n", + " outputs[\"bwd\"][key] = optimization_chain(outputs[\"bwd\"][key])\n", + " return outputs\n", + "\n", + "\n", + "def wrapper(n: int):\n", + " \"\"\"\n", + " High level function that will take the symbolic e3nn expressions and\n", + " run them through an optimization chain to obtain any of the spherical\n", + " harmonics directly as functions of xyz.\n", + "\n", + " Also computes the derivative, running that also through the optimization chain.\n", + " \"\"\"\n", + " num_terms = 2 * n + 1\n", + " return driver(*[eval(f\"y_{n}_{m}\") for m in range(num_terms)])" + ] + }, + { + "cell_type": "markdown", + "id": "53ce9de1-8bdb-4922-ba77-707af4a6b6d8", + "metadata": {}, + "source": [ + "### Generating the expressions _en masse_\n", + "\n", + "The `driver` function that's called by `wrapper` is cached, which means the first time will take a hot minute but subsequent calls with the same `n` value should be free." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "bc287731-f493-4e58-b10e-a45daebbaf17", + "metadata": {}, + "outputs": [], + "source": [ + "second_order_expressions = wrapper(2)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6cd00e51-8b97-4ddc-a95d-f9037ae4535d", + "metadata": {}, + "outputs": [], + "source": [ + "third_order_expressions = wrapper(3)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "fe8c8edb-6672-49c0-ac93-46c97b31e9e5", + "metadata": {}, + "outputs": [], + "source": [ + "fourth_order_expressions = wrapper(4)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1cf41f26-2562-4625-a9e4-8b0d86be6de4", + "metadata": {}, + "outputs": [], + "source": [ + "fifth_order_expressions = wrapper(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "5e996499-c56b-44bd-aaee-d796347764bd", + "metadata": {}, + "outputs": [], + "source": [ + "sixth_order_expressions = wrapper(6)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "9b6d7ca3-a427-4ff9-a657-a202623ba1bc", + "metadata": {}, + "outputs": [], + "source": [ + "seventh_order_expressions = wrapper(7)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "429fd2f5-8e13-46f5-a4eb-9cae1419568a", + "metadata": {}, + "outputs": [], + "source": [ + "eighth_order_expressions = wrapper(8)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "d37250a7-ece0-4c3e-88bf-0c987ab9ce5e", + "metadata": {}, + "outputs": [], + "source": [ + "ninth_order_expressions = wrapper(9)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f5b04e7e-fec5-4871-987a-2951526f6380", + "metadata": {}, + "outputs": [], + "source": [ + "tenth_order_expressions = wrapper(10)" + ] + }, + { + "cell_type": "markdown", + "id": "f323ee83-1a65-4f45-a524-9bee84cd9443", + "metadata": {}, + "source": [ + "## Numerical checks with `e3nn`\n", + "\n", + "This is an elementwise, manual 
check on specific projections within the spherical harmonics, comparing the symbolic -> numerical evaluation with the corresponding `e3nn` value." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "2a32d49e-7346-401e-99a0-606e89bd6eba", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1.0000, 1.7321, 1.7321, 1.7321, 3.8730, 3.8730, 0.0000, 3.8730, 0.0000]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# second order\n", + "_spherical_harmonics(2, test_tensor[:, 0], test_tensor[:, 1], test_tensor[:, 2])" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "edc9c855-c4b7-482f-ad9d-494d03c4e7ef", + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$\\displaystyle 0$" + ], + "text/plain": [ + "0" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# compare each of the forward terms with above\n", + "second_order_expressions[\"fwd\"][-1].subs({\"x\": 1.0, \"y\": 1.0, \"z\": 1.0})" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "9fabc07f-2de0-4c5b-9d6d-2c00e88c71f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 1.0000, 1.7321, 1.7321, 1.7321, 3.8730, 3.8730, 0.0000, 3.8730,\n", + " 0.0000, 4.1833, 10.2470, 3.2404, -5.2915, 3.2404, 0.0000, -4.1833]])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# third order\n", + "_spherical_harmonics(3, test_tensor[:, 0], test_tensor[:, 1], test_tensor[:, 2])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "005edb58-51d0-4acb-9038-74d5f54ea375", + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$\\displaystyle -5.29150262212918$" + ], + "text/plain": [ + "-5.29150262212918" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + 
], + "source": [ + "third_order_expressions[\"fwd\"][3].subs({\"x\": 1.0, \"y\": 1.0, \"z\": 1.0})" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "f907e089-8ab7-4b8a-b766-dc3a69a3d7ea", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 1.0000e+00, 1.7321e+00, 1.7321e+00, 1.7321e+00, 3.8730e+00,\n", + " 3.8730e+00, 0.0000e+00, 3.8730e+00, 0.0000e+00, 4.1833e+00,\n", + " 1.0247e+01, 3.2404e+00, -5.2915e+00, 3.2404e+00, 0.0000e+00,\n", + " -4.1833e+00, 0.0000e+00, 1.2550e+01, 1.3416e+01, -4.7434e+00,\n", + " -1.0500e+01, -4.7434e+00, -5.9605e-08, -1.2550e+01, -8.8741e+00]])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# fourth order\n", + "_spherical_harmonics(4, test_tensor[:, 0], test_tensor[:, 1], test_tensor[:, 2])" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "29a82d20-e9f9-40ce-a70c-963569439cf3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$\\displaystyle -8.87411967464942$" + ], + "text/plain": [ + "-8.87411967464942" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fourth_order_expressions[\"fwd\"][-1].subs({\"x\": 1.0, \"y\": 1.0, \"z\": 1.0})" + ] + }, + { + "cell_type": "markdown", + "id": "a08d5640-8b8d-48f1-86cf-21eeae568001", + "metadata": {}, + "source": [ + "## Operations count comparison\n", + "\n", + "This set of cells was used to analyze the number of floating point operations required to compute a given order of spherical harmonics, compared to the copy-pasted `e3nn` equations. Note that this should be interpreted as an upper bound: `torchscript` and `torch.compile` will likely fuse some operations and eliminate differences, bringing the two closer in actual computation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "bc255e1f-b274-4010-8699-0fbe7049d7db", + "metadata": {}, + "outputs": [], + "source": [ + "def count_operations_per_n(n: int):\n", + " projections = range(2 * n + 1)\n", + " e3nn_impl = [sympy.count_ops(eval(f\"y_{n}_{m}\")) for m in projections]\n", + " direct = wrapper(n)[\"fwd\"]\n", + " direct_counts = [sympy.count_ops(expr) for expr in direct]\n", + " return e3nn_impl, direct_counts" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "68340cc8-c716-414f-8bf2-48de8976ea17", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "from matplotlib.lines import Line2D" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "a98325e2-8738-4f52-a02f-b581bf0ec944", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAnYAAAHWCAYAAAD6oMSKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB7oElEQVR4nO3deViU5f4G8HsGGDYHBNnXUVEUF0QUXKjcyjbTMk/1KwXc8oSKmKa2aFm5pIULnMwlUdvN0rTFyrU5qSi44MaijiAoiyDbsM7M7w8PJDGaIzPz4sz9uS6uS953lvsVlC/P+zzfR6TRaDQgIiIiovueWOgARERERKQfLOyIiIiITAQLOyIiIiITwcKOiIiIyESwsCMiIiIyESzsiIiIiEwECzsiIiIiE8HCjoiIiMhEWAodQGhqtRp5eXmQSqUQiURCxyEiIiJqQqPRoLy8HF5eXhCL7zwmZ7aFXWJiIhITE1FbW4sLFy4IHYeIiIjojnJycuDj43PHx4jMfUux0tJStG3bFjk5OXBwcBA6DhEREVETZWVl8PX1xY0bN+Do6HjHx5rtiF2DhtuvDg4OLOyIiIio1bqbKWNcPEFERERkIljYEREREZkIsy3sEhMTERQUhL59+wodhYiIiEgvzH7xRFlZGRwdHVFaWso5dkRERNTq6FKrmO2IHREREZGpYWFHREREZCLMvt2JIWnq6yGytIRGpULt1StQV1ZAbN8GEk8fiCwsGs+T8ahUalhYiKFWa3CtRAllTT3srC3h4WQHsVjUeJ6IiOh+ZLZVRcPOEyqVSu+vrVHVAyIxyuR7ULz9SyjTUqCprWk8L5JYw65HKJxHvQCHBx8GNGqILMz2S2EUKrUaIpEIh88X4KfkbJzNLkFtvbrxvMRSjCA/Jzwe5of+Qe7QaDSw+IdtW4iIiFobLp7Q8+IJjUqF2tzLyFk4C9Xpp//x8bZdesDnrWWQePtDZGHR4ven5lRqDa4WV2LZ1pPIyiv7x8d38nbErGd7wtPZHhZi7h9MRETC4uIJgWhUKlSeSEbWhKfvqqgDgKrzacia8DQqTyRDY4DRQ3OnUmtwWnEd0//z510VdQCQ
mVuK6f/5E6cV16FS31+/9yQlJaFt27b3/Pz9+/dDJBLhxo0bestERETGw8JOTzSqetTmXsbluVOgqa7S7bnVVbg8dwpqc7Nv3sYlvVCp1bhaXIl3PktFTZ1uRXNNnQrvfJaKq8WVUKnU//yEuxAVFQWRSNTs49FHH9XL6wPAc889h4yMDADAoEGDtL5fw8egQYOaPX/AgAG4evVq416ELS0UiYjIuFjY6YtIjJyFs3Qu6hpoqqtw5d1ZgMjwX5Lr16/j0UcfhZeXF6ytreHr64upU6eirOzuRrTuFyKRCMu2ntS5qGtQU6fC8m9PQaTH27GPPvoorl692uTjyy+/1Nvr29raws3NDQDw3XffNb5HcnIyAOD3339vPPbdd981eW5dXR0kEgk8PDzuaj9CIiJzJ5fLMSQ0BD083TAkNARyuVzoSCzs9EFTX4+yg7/d9e3X26k6n4ayP36Hpt6wo3ZisRgjR47EDz/8gIyMDCQlJeH333/HlClTDPq+xqRSqXHobP5d3369nczcUhw+l496PY3aWVtbw8PDo8mHk5PTzffKzMSDDz4IGxsbBAUF4bfffoNIJML27dsBaL9NeuLECYhEIigUCgBNR9icnZ0b38PV1RUA0K5du8Zj7dq1w8cff4ynnnoK9vb2eP/995u8x/79+xEdHY3S0tLGUb63334bAFBSUoJx48bByckJdnZ2eOyxx5CZmdmYqyHH7t270bVrV7Rp06axqCUiMgUxUeMQ++QjiK4vwlqZFNH1RYh98hHERI0TNBcLOz0QWVqieLt+Rl2Kt3+hlxYoarUaixcvRvv27WFra4vg4GB8++23AAAnJyf8+9//Rp8+feDv74+hQ4filVdewR9//NH4/Lfffhu9evXCli1bIJPJ4OjoiOeffx7l5eWNjxk0aBCmT5+O1157rbGIaPjBLzQLCzF+Ss7Wy2v9mJwNSwO3QFGr1XjmmWcgkUhw5MgRrFmzBnPmzDHoewI3v85PP/000tLSMH78+CbnBgwYgBUrVsDBwaFxlG/WrFkAbt5WPnbsGH744QccOnQIGo0Gjz/+OOrq6hqfr1QqsXz5cmzZsgUHDx5EdnZ24/OJiO5ncrkch77/Fhu7eyPEwRZSSwuEONhiY3dvHPr+W0FH7sy2sNPnXrEalQrKtBQ9pAKUaal6WUSxePFibN68GWvWrMGZM2cQFxeHl156CQcOHGj22Ly8PHz33Xd46KGHmhy/cOECtm/fjl27dmHXrl04cOAAlixZ0uQxmzZtgr29PY4cOYIPPvgACxcuxG+//dbi/C2lVmtwNrtEL691LrsEaj0toti1axfatGnT5GPRokX4/fffcf78eWzevBnBwcF48MEHsWjRIr2855383//9H6Kjo9GhQwf4+fk1OSeRSODo6AiRSNQ4ytemTRtkZmbihx9+wPr16/HAAw8gODgYn3/+OXJzcxtHF4Gbt3bXrFmDPn36oHfv3pg6dSr27Nlj8GsiIjK0uInjESdrB/Hfpq2IRSJM93dG3MTxt3mm4ZltYRcTE4OzZ8/i6NGjLX6t2qtXmvSpawlNTTVqr15p0WvU1NRg0aJF+PTTTzF8+HB06NABUVFReOmll/DJJ580Pu6FF16AnZ0dvL294eDggPXr1zd5HbVajaSkJHTv3h0PPPAAxo4d2+wHc8+ePbFgwQJ06tQJ48aNQ58+fVrFD+9rJcomfepaoqZOjWslSr281uDBg3HixIkmH1OmTMG5c+fg6+sLLy+vxsf2799fL+95J3369NH5OefOnYOlpSXCw8Mbj7Vr1w6BgYE4d+5c4zE7Ozt07Nix8XNPT08UFBS0LDARUStQmJONADuJ1nOd7a1RkKOfO0b3gl1x9UBdWaHf11NWtuj5WVlZUCqVePjhh5scr62tRUhISOPn8fHxWLBgATIyMjBv3jzMnDkT//nPfxrPy2QySKXSxs+1/WDu2bNnk89byw9vZY1+5ylW6en17O3tERAQcE/P
Ff+vYfKtrSdvvfV5r3kMxcrKqsnnIpEIZt42k4hMhEqjQZayFm4SS3yfX4bc6jp421jhaXcHXKupg1rA/+tY2OmB2L6Nfl/PrmU/bCsqbhaaP/74I7y9vZucs7a2bvxzw+21Ll26wNnZGQ888ADeeusteHp6AtD+g1mtbjoKdjePEYKdtX6/tW31/Hp/17VrV+Tk5ODq1auNf/+HDx9u8piGBRBXr15tXHBx4sQJg+aSSCTNdmfp2rUr6uvrceTIEQwYMADAzZXW6enpCAoKMmgeIqLWwN3PD8su5sBFYoloHycE2EmQpazF0ouFKKqth/vfprYYk9neitUniacPRBLrf37gXRBZ20Di6dOi1wgKCoK1tTWys7MREBDQ5MPX11frcxqKsZoa/dxSFpqHkx0klvr59ra2EsPDyU4vr1VTU4Nr1641+SgqKsKwYcPQuXNnREZG4uTJk/jjjz/wxhtvNHluw9fv7bffRmZmJn788Ud8+OGHesl1OzKZDBUVFdizZw+KioqgVCrRqVMnjBw5EpMmTYJcLsfJkyfx0ksvwdvbGyNHjjRoHiKi1iB62gxYikVY0dWzyeKJFV09YSkWIXraDMGysbDTA5GFBex6hOrltex69G7x1mJSqRSzZs1CXFwcNm3ahAsXLiA1NRWrV6/Gpk2b8NNPP2Hjxo04ffo0FAoFfvzxR0yZMgUDBw6ETCbTy3UITSwWIcjPSS+v1dXPCWI99bL75Zdf4Onp2eQjIiICYrEY33//PaqqqhAWFoaJEyfi/fffb/JcKysrfPnllzh//jx69uyJpUuX4r333tNLrtsZMGAApkyZgueeew6urq744IMPAAAbN25EaGgonnzySfTv3x8ajQY//fRTsxFcIiJTtO3T9YiTudxm8UQ7bPt0/W2eaXi8FasHmvp6OI96AZUpf7b4tZxH/R809fUtbnny7rvvwtXVFYsXL8bFixfRtm1b9O7dG6+//jpUKhXWrVuHuLg41NTUwNfXF8888wzmzp3b4vythUqlxuNhfjhx8XqLX+uJMD/Uq9QtbnmSlJSEpKSk257v3Llzk5Yz2gwcOBCnTp1qcuzWeWtRUVGIiopq9jyZTNZsfpu2+W6DBg1qdvzjjz/Gxx9/3OSYk5MTNm/efNuc2nKMGjWKc+yIyCQU5uUiQCbVeq6zvTUKFblGTvQXkcbM/6fVZWPdO9Go1bj48hhUnU+759ew7dIDHT7ZCpGYA6n6oNZoMPOTQ8jMLb3n1+jk7YiPXu7f7LcyYxGJRPj+++8xatQoQd6fiIiaGxIaguj6IoQ42DY7l1KqxCYrV+xNOa6399OlVmEFoS8aNXzeWgaRTfMv8t0Q2djC563lgEb4hQemQqPRYNazPWFtdW+3tq2tLDDr2Z7Q6KmHHRERmYaFK1djVXZxs9Wvao0GCTklWLhytUDJzLiw02eDYgAQWVhC4u0P/yVrdC7uRDa28F+yBhJvP4gseHdcXyzEYng622PBS711Lu6srSyw4KXe8HS2h4WBd524E41Gw9E6IqJWJiIiAmEjRyP6dC6OlSpRVq/CsVIlok/nImzkaERERAiWzWwLO302KG4gsrCAfa8wBGzYDtsuPe7qObZdeiBgw3bY9wpr8aIJas5CLEJ3WTusemUAOnk73tVzOnk7YtUrA9Bd1g4Welo0QUREpqdOrcHOgnK8m1WAnQXlqGsFd3g4x05Pc+xupVHVAyIxyv74DcXffwllWkqTnSlE1jaw69EbzqP+Dw4PDAM0ao7UGZhKpYZILMKhs/n4KTkbZ7NLmuxMYW0lRlc/JzwR5od+Xd2hUWsEHakjIqLWSy6XI27EcGzo5tVkDrZao8GEM3mI37lbr6N2utQqLOwMUNg1aFjdqlGpUHv1CtTKSojt7G/2vbOw0MvqV9KNSqWGhYUYarUG10qUqKqph621JTyc7CAWi/Sy+pWIiExbw+KJ2+08IeTiCVYVBtRQtIksLGDt43/b82Q8DaNw
YrEIXu2a7/DBoo6IiP5JYV4uLrQBNl4pabbzxAPOdigsFK7dCSsLIiIiIh1YOzpib0EeEoL+uhUb4mCL4K42iDmbB2s3L8GycXiCiIiISAciiDDJ11nrzhMTfJwggnAL71jYEREREemguvQGAuwkWs91trdGdekN4wa6BQs7IiIiIh24enkjS1mr9VxmZQ1cvbyNnOgvLOyIiIiIdMCdJ4iIiIhMRMPOExPO5CHlfztPpJQqMeFMnuA7T3BVLBEREZGOEpM2Qy6fjPmx01CoyIWrlzfid34laFEHmHFhl5iYiMTERKhUKqGjEBER0X0oIiJCr42I9YE7Txhw5wkiIiKiltKlVuEcOyIiIiITwcKOiIiIyESY7Rw7IiIiopZQKBRYuyIeF8+fQ4cuXTF5RhxkMpmgmThiR0RERKSjLevXYfLDg9B133eILb2Arvu+w+SHB2HL+nWC5mJhR0RERKQDhUKBLUvfR5wj8N8SJRZdKMR/S5SIcwS2LH0fCoVCsGws7IiIiIh0sHZFPDrVlGPZpSIMdLLD6x1dMdDJDssuFaFjdTnWrogXLBvn2BERERHp4NjhQ6irrMHKrp4Qi0QAgBAHWwR3tUHsuauwOnxIsGwcsSMiIiLSwZXsyxjv49RY1DUQi0SI8m6LK9mXBUrGwo6IiIhINyoVAuwkWk91srcGBNzVioUdERERkQ48fHyRpazVei6zsgYePr5GTvQXFnZEREREOli4cjVWZRdD/bddWdUaDRJySrBw5WqBkrGwIyIiItJJREQEwkaOxoQzeUgpVaKsXoWUUiUmnMlD2MjRiIiIECwbV8USERER6SgxaTPk8smYHzsNhYpcuHp5I37nV4IWdYCJFHYymQwODg4Qi8VwcnLCvn37hI5EREREJi4iIgJ7U44LHaMJkyjsAODPP/9EmzZthI5BREREJBjOsSMiIiIyEYIXdgcPHsSIESPg5eUFkUiE7du3N3tMYmIiZDIZbGxsEB4ejuTk5CbnRSIRHnroIfTt2xeff/65kZITERERtS6CF3aVlZUIDg5GYmKi1vNff/01Zs6ciQULFiA1NRXBwcEYPnw4CgoKGh8jl8uRkpKCH374AYsWLcKpU6eMFZ+IiIio1RBpNH9rwiIgkUiE77//HqNGjWo8Fh4ejr59+yIhIQEAoFar4evri2nTpmHu3LnNXmP27Nno1q0boqKitL5HTU0NampqGj8vKyuDr68vSktL4eDgoNfrISIiImqpsrIyODo63lWtIviI3Z3U1tYiJSUFw4YNazwmFosxbNgwHDp0c4PdyspKlJeXAwAqKiqwd+9edOvW7bavuXjxYjg6OjZ++PoK1x2aiIiISJ9adWFXVFQElUoFd3f3Jsfd3d1x7do1AEB+fj4iIiIQHByMfv36Ydy4cejbt+9tX3PevHkoLS1t/MjJyTHoNRAREZkyuVyOIaEh6OHphiGhIZDL5UJHMmv3fbuTDh064OTJk3f9eGtra1hbWxswERERkXmIiRqHw9u/xQz/dgiQSZGlLELsk4+g36hnkZi0Weh4ZqlVj9i5uLjAwsIC+fn5TY7n5+fDw8OjRa+dmJiIoKCgO47uERERkXZyuRyHvv8WG7t7I8TBFlJLC4Q42GJjd28c+v5bjtwJpFUXdhKJBKGhodizZ0/jMbVajT179qB///4teu2YmBicPXsWR48ebWlMIiIisxM3cTziZO0gFomaHBeLRJju74y4ieMFSmbeBL8VW1FRgaysrMbPL126hBMnTsDZ2Rl+fn6YOXMmIiMj0adPH4SFhWHFihWorKxEdHS0gKmJiIjMW2FONgJ6ems919neGgUXso2ciIBWUNgdO3YMgwcPbvx85syZAIDIyEgkJSXhueeeQ2FhIebPn49r166hV69e+OWXX5otqCAiIiLjUWk0yFLWIsTBttm5zMoaqFtPNzWz0qr62BlTYmIiEhMToVKpkJGRwT52REREOujTpRPsCvKQEOTV5Has
WqNBzNk8VLl54dj5TAETGp5cLsf82GkozMuFq5c3Fq5cjYiICL2/jy597My2sGugy18WERER3SSXyzFx+FA4WVnAy8YKSpUadhZi5FXXoaROhfW79xikyGktYqLGIXnHNkz3c0aAnQRZylqsyi5G2MjRel8RbDINiomIiKh1ioiIgNTbF7UaDUa5O2B+gBtGuTugVqOB1NvXpIs6uVyO5B3bsKGbV5MVwRu6eSF5xzZBVwSzsCMiIiKdyeVyoPAqNvbwadrupIcPUHjVpNudzI+dhul+zlpXBE/1dcL82GkCJTPjwo597IiIiO5day5uDK0wLxcBdhKt5zrZW6MwL9fIif5itoUd+9gRERHdu9Zc3Biaq5c3spS1Ws9lVtbA1Ut7GxhjMNvCjoiIiO6djWPb2xY3GZU1sHFsa9xARrRw5Wqsyi5u1tJFrdEgIacEC1euFigZCzsiIiK6BxposC5He3Gz4UoJNDDdphsREREIGzkaE87kIaVUibJ6FVJKlZhwJg9hI0cLunDEbAs7zrEjIiK6dzWlpRjSzh4zzl1FamkVyupVSC2twoxzVzGknT1qSkuFjmhQiUmbEb9zNzZZueJlRTk2Wbkifuduvbc60RX72LGPHRERkc6GhIYgur4IbhJLfJ9fhtzqOnjbWOFpdwdcq6nDJitX7E05LnRMk6BLrSL4lmJERER0/1m4cjXiRgzHhm5emOrfrvG4WqPBm1kFiN/5lYDpzBcLOyIiItJZwzyzl77bik52VlCqNLCzECFTWYcHnhlj0g2KWzMWdkRERHTPrMQijHJ3bNxWa8Xl60JHMmtmu3iCiIiI7l3Dtlobu3s33Xmiu7fg22qZM7Mt7LgqloiI6N6Z884TrZnZFnbceYKIiOjemfPOE62Z2RZ2REREdO9a87Za5oyFHREREemsNW+rZc5Y2BEREZHOWvO2WuaM7U6IiIjoniQmbYZcPhnzY6ehUJELVy9vxO/8ikWdgFjYERER0T3z8fFBvwcexMXz59ChS1f4+PgIHcmsme2tWLY7ISIiapkt69dh8sOD0HXfd4gtvYCu+77D5IcHYcv6dUJHM1sijeZvsx7NjC4b6xIREdFNCoUCkx8ehOXuVk162ak1GszKr8Pa3/ZDJpMJF9CE6FKrmO2IHREREd27tSviMdZOpbVB8Ut2aqxdES9QMvPGwo6IiIh0dvH8uds2KA6ws8LF9HNGTkQACzsiIiK6Bx26dL1tg+IsZR06BHY1ciLjk8vlGBIagh6ebhgSGtIq9sdlYUdEREQ6mzwjDluUFlobFH+mFGPyjDiBkhlHTNQ4xI0Yjuj6IqyVSRFdX4S4EcMREzVO0Fws7IiIiEhnMpkMY+e8gVn5dUgtq0ZZvQqpZdWYlV+HsXPeMOmFE3K5HMk7tmFDNy+EONhCammBEAdbbOjmheQd2wQdueOqWK6KJSIiumcKhQJrV8TjYvo5dAjsiskz4ky6qAOAIaEhiK4vQoiDbbNzKaVKbLJyxd6U43p7P11qFbNtUJyYmIjExESoVCqhoxAREd23ZDIZFq1YKXQMoyrMy0WATKr1XCd7axQqco2c6C9meys2JiYGZ8+exdGjR4WOQkRERPcRVy/v2y4cyaysgauXt5ET/cVsCzsiIiJquda4MtTQFq5cjVXZxVoXjiTklGDhytUCJWNhR0RERPeota4MNbSIiAiEjRyNCWfykFKqRFm9CimlSkw4k4ewkaMREREhWDYunuDiCSIiIp3J5XLEjRiODd28mm0pNuFMHuJ37ha0wDEGuVyO+bHTUJiXC1cvbyxcudog16xLrcLCjoUdERGRzoy9MtScca9YIiIiMqjCvNzbbinWyd4ahXnCrQw1ZyzsiIiISGeteWWoOWNhR0RERDprzStDzRkLOyIiItJZa14Zas5Y2BEREbWQOfZyA4DEpM2I37kbm6xc8bKiHJusXBG/czcSkzYLHc1scVUsV8USEVELxESNQ/KObZju54wA
OwmylLVYlV2MsJGjWeCQXnBVLBERkRHI5XIk79iGhR1d8d8SJRZdKMR/S5RY2NEVyTu2mc3IHbUeZjtil5iYiMTERKhUKmRkZHDEjoiIdDYkNARdC7NxrrIG0T5OjSN2G6+UINBOgnQ3f/ZyoxZjg2Id8FYsERHdq04uTvDR1GFlV89muy/EnruKKyIrZBaVCJiQTAFvxRIRERlBXV0dxvs4NSnqAEAsEiHKuy3q6uoESkbmqsWFnUqlwokTJ1BSwt9IiIjIvNhYSe64+4K1lfZzRIaic2E3Y8YMbNiwAcDNou6hhx5C79694evri/379+s7HxERUavl5e9/290XMipr4O3vb+REZO50Luy+/fZbBAcHAwB27tyJS5cu4fz584iLi8Mbb7yh94BEREStVcPuCzlVtUi4fB3z0q8h4fJ15FTVIpG7L5AAdF48YWNjg6ysLPj4+GDy5Mmws7PDihUrcOnSJQQHB6OsrMxQWQ2CiyeIiKgl+nYOgDo/FzNkLo2rYlcoiiB298bRjCyh45EJMOjiCXd3d5w9exYqlQq//PILHn74YQCAUqmEhYXFvSUmIiK6D8nlcqDwKjb28EGIgy2klhYIcbDFxh4+QOFV9rEjo9O5sIuOjsa//vUvdO/eHSKRCMOGDQMAHDlyBF26dNF7QCIiotZqfuw0TPdz1roqdqqvE+bHThMoGZkrS12f8Pbbb6N79+7IycnBmDFjYG1tDQCwsLDA3Llz9R6QiIiotSrMy0WATKr1XCd7axQqco2ciMydzoUdADz77LPNjkVGRrY4DBER0f3ExrEtspRVCHGwbXYuo7IGNo5tjR+KzNo9FXZ79uzBnj17UFBQALVa3eTcp59+qpdgRERErZ0GGqzLKUZCkFeznSc2XCmBxq15wWdq5HI55sdOQ2FeLly9vLFw5WpEREQIHcts6TzH7p133sEjjzyCPXv2oKioCCUlJU0+iIiIzEVNaSmGtLPHjHNXkVpahbJ6FVJLqzDj3FUMaWePmtJSoSMaVEzUOMSNGI7o+iKslUkRXV+EuBHDERM1TuhoZkvnEbs1a9YgKSkJY8eONUQeIiKi+4arlzc61hehf1t7fJ9fhq3XSuFtY4U5HVxxraYOrg6uQkc0GLlcjuQd27Ch21+jlSEOttjQzQsTdmyDXD6ZI3cC0HnErra2FgMGDDBElhZRKpXw9/fHrFmzhI5CRERmoqFBsae1Jab6t8PiQA9M9W8HT2tLJJh4g2KuCG6ddC7sJk6ciC+++MIQWVrk/fffR79+/YSOQUREZiQiIgJhI0djwpk8pJQqUVavQkqpEhPO5CFs5GiTHrEqzMu94z65hXlcESwEnW/FVldXY+3atfj999/Rs2dPWFlZNTn/0Ucf6S3c3crMzMT58+cxYsQInD592ujvT0RE5isxaTPk8sk3FxAobi4giN/5lUkXdcDN29BZyiKtK4IzK2vg6uUtQCrSecTu1KlT6NWrF8RiMU6fPo3jx483fpw4cULnAAcPHsSIESPg5eUFkUiE7du3N3tMYmIiZDIZbGxsEB4ejuTk5CbnZ82ahcWLF+v83kRERPoQERGBvSnHkXa1AHtTjpt8UQf8dRta/bedSdUajcnfhm7NdB6x27dvn14DVFZWIjg4GOPHj8czzzzT7PzXX3+NmTNnYs2aNQgPD8eKFSswfPhwpKenw83NDTt27EDnzp3RuXNn/Pnnn3rNRkRERNo13obesQ1TfZ3Qyd4amZU1SMgpMfnb0K2ZSKP5W6mtgytXrgAAfHx89BNGJML333+PUaNGNR4LDw9H3759kZCQAABQq9Xw9fXFtGnTMHfuXMybNw+fffYZLCwsUFFRgbq6Orz66quYP3/+Xb2nLhvrEhERUVPsY2d4utQqOhd2arUa7733Hj788ENUVFQAAKRSKV599VW88cYbEIt1vrv7V5i/FXa1tbWws7PDt99+26TYi4yMxI0bN7Bjx44mz09KSsLp06exfPny
275HTU0NampqGj8vKyuDr68vCzsiIiJqlXQp7HSuwt544w0kJCRgyZIljXPrFi1ahNWrV+Ott96659DaFBUVQaVSwd3dvclxd3d3XLt27Z5ec/HixXB0dGz88PX11UdUIiIyY1u3bkWQlzs6SO0Q5OWOrVu3Ch2JzJTOc+w2bdqE9evX46mnnmo81rNnT3h7e+OVV17B+++/r9eAuoiKivrHx8ybNw8zZ85s/LxhxI6IiOhePBYxAAVpJzBP1g4B/p7IUtZiycRIfLoyHj/LOffblLXG29A6j9gVFxejS5cuzY536dIFxcXFegnVwMXFBRYWFsjPz29yPD8/Hx4eHvf0mtbW1nBwcGjyQUREdC+2bt2KgrQT2NjDGyEOtpBaWiDEwRYbe3ijIO04R+5MWGvdTk3nwi44OLhxIcOtEhISEBwcrJdQDSQSCUJDQ7Fnz57GY2q1Gnv27EH//v1b9NqJiYkICgpC3759WxqTiIjM1ILYqZgha6d194Xp/u2wIHaqQMnIkG7dTu3Wgn5DNy8k79gGuVwuWDadF08cOHAATzzxBPz8/BqLq0OHDiEnJwc//fQTHnjgAZ0CVFRUICsrCwAQEhKCjz76CIMHD4azszP8/Pzw9ddfIzIyEp988gnCwsKwYsUKfPPNNzh//nyzuXf3gqtiiYjoXnWQ2uHz7p6QWlo0O1dWr8KLp6/iUrlSgGRkSENCQxBdr705c0qpEpusXLE35bje3k+XWkXnOXYPPfQQMjIykJiYiPPnzwMAnnnmGbzyyivw8vLSOeyxY8cwePDgxs8b5r9FRkYiKSkJzz33HAoLCzF//nxcu3YNvXr1wi+//KKXoo6IiKglbKRSZClrb7v7gq1UKkAqMrTCvFwEyLR/bTvZW6NQIdx2ai3qY3c/S0xMRGJiIlQqFTIyMjhiR0REOtu6dSuWTIzExh7eTW7HqjUaRKddwdz1mzFmzBgBE5IhtOYRu7sq7E6dOoXu3btDLBbj1KlTd3xsz549dUsrMN6KJSKilri5KvY4pvu3a9x9YdXl63DrEcJVsSZKLpcjbsRwbOjm1aygn3AmD/E7d+t1dazeCzuxWIxr167Bzc0NYrEYIpEI2p4mEomgUqnuPbkAWNgREVFLbd26FQtip6KqvBy2UineWZnAkToTFxM1Dsm32U4tMWmzXt9L74Xd5cuX4efnB5FIhMuXL9/xsf7+/rqlFRgLOyIiIroXxupjp/fFE7cWa5cvX8aAAQNgadn0qfX19fjzzz/vm8Lu1jl2RERERLqKiIjQ61w6fdB58YSFhQWuXr0KNze3JsevX78ONze3+65Q4ogdERERtWYG3StWo9FA9LdGjMDNws7e3l7XlyMiIiIiPbnrPnbPPPMMgJsLJKKiomBtbd14TqVS4dSpUxgwYID+ExIRERHRXbnrws7R0RHAzRE7qVQKW9u/erdIJBL069cPkyZN0n9CIiIiIrord13Ybdy4EQAgk8kwa9as+/62KxdPEBERkakx250nGnDxBBEREbVmBt0rFgC+/fZbfPPNN8jOzkZtbW2Tc6mpqffykkRERETUQjqvil21ahWio6Ph7u6O48ePIywsDO3atcPFixfx2GOPGSIjERFRqyaXyzEkNAQ9PN0wJDQEcrlc6EhkpnQu7P7zn/9g7dq1WL16NSQSCV577TX89ttvmD59OkpLSw2RkYiIqNWKiRqHuBHDEV1fhLUyKaLrixA3YjhiosYJHY3MkM6FXXZ2dmNbE1tbW5SXlwMAxo4diy+//FK/6YiIiFoxuVyO5B3bsKGbF0IcbCG1tECIgy02dPNC8o5tHLkjo9O5sPPw8EBxcTEAwM/PD4cPHwYAXLp0CffTOozExEQEBQWhb9++QkchIqL71PzYaZju5wzx3xr3i0UiTPV1wvzYaQIlI3Olc2E3ZMgQ/PDDDwCA6OhoxMXF4eGHH8Zzzz2Hp59+Wu8BDSUmJgZnz57F0aNHhY5CRHTfM9c5ZoV5
uQiwk2g918neGoV5uUZOROZO51Wxa9euhVqtBnCzOGrXrh3+/PNPPPXUU3j55Zf1HpCIiFq3mKhxOLz9W8zwb4cAmRRZyiLEPvkI+o16FolJm4WOZ1CuXt7IUhYhxMG22bnMyhq4enkLkIrMmU597Orr67Fo0SKMHz8ePj4+hsxlNOxjR0R07+RyOaY/8QiSeng3uR2p1mgQlZaLVT/+ioiICAETGpZcLkfciOHY0M2r2fVPOJOH+J27Tfr6yTh0qVV0uhVraWmJDz74APX19S0KSEREpiFu4njEydppnWM23d8ZcRPHC5TMOCIiIhA2cjQmnMlDSqkSZfUqpJQqMeFMHsJGjmZRR0an8xy7oUOH4sCBA4bIQkRE95nCnOzbzjHrbG+NgpxsIycyvsSkzYjfuRubrFzxsqIcm6xcEb9zt8nfhm5grvMrWyud59g99thjmDt3LtLS0hAaGtpsz9innnpKb+EMiXvFEhG1nEqjQZay9rZzzNT3UbeEloiIiMDelONCxzC6mKhxSN6xDdP9nBvnV8aNGI6wkaPNprBtbXTeK1Ysvv0gn0gkuu8KJc6xIyK6d326dIJdQR4SgprPMYs5m4cqNy8cO58pYEIyFM4vNB6DzbEDALVafduP+62oIyKillmxfiMKauox49xVpJZWoaxehdTSKsw4dxWFNfVYsX6j0BHJQNjDr3XS+Vbsraqrq2FjY6OvLEREdJ+JiIiA1NsXBVdzsKOgFEqVBnYWIhTU1kHq7csRGxNWmJeLAJlU67lO9tYoVLCHnxB0HrFTqVR499134e3tjTZt2uDixYsAgLfeegsbNmzQe0AiImq95HI5UHgVywI94SqxgqVIBFeJFZYFegKFVzmR3oTd7OFXq/Uce/gJR+fC7v3330dSUhI++OADSCR/rYTq3r071q9fr9dwRETUujXcjvO1lWCqfzssDvTAVP92Nz/n7TiTtnDlaqzKLm62QEat0SAhpwQLV64WKJl507mw27x5M9auXYsXX3wRFhYWjceDg4Nx/vx5vYYjIqLWjVtqmS/28GuddJ5jl5ubi4CAgGbH1Wo16urq9BKKiIjuDw1barlJLPF9fhlyq+vgbWOFp90dcK2mjrfjTFxi0mbI5ZMxP3YaChW5cPXyRvzOr1jUCUjnwi4oKAh//PEH/P39mxz/9ttvERISordgRETU+i1cuRoTHx0GL2tLRPs4IcBOgixlLZZeLEReTT3W//KV0BHJwMy1h19rpXNhN3/+fERGRiI3NxdqtRrfffcd0tPTsXnzZuzatcsQGQ2CDYqJiFrOx8cHUmsJVnR1b2x7EeJgi+CuNphwLt9k9hUnul/o3KAYAP744w8sXLgQJ0+eREVFBXr37o358+fjkUceMURGg2KDYiKiezdtQjTCjv2udeeJY6VKpPR9GKs3sJcdUUvoUqvcU2FnSljYERHdu7COMqx0s4DU0qLZubJ6FWIL1Dh64ZIAyYhMhy61yj03KD527BjOnTsH4Oa8u9DQ0Ht9KSIiuk+pNbjjXrGalvXBJyId6fwv7sqVK3jhhRfw3//+F23btgUA3LhxAwMGDMBXX33F+RRERGYkbNBgrNv+jda9YjdcKUHYqH8JmI7I/Ojcx27ixImoq6vDuXPnUFxcjOLiYpw7dw5qtRoTJ040REYiImqlXpu/AJUSW617xVZKbPHa/AVCRyQyKzrPsbO1tcWff/7ZrLVJSkoKHnjgASiVSr0GNDTOsSMiapkt69dh7XvvwK26HEqVGnYWYhTYSDH5zQUYO3GS0PGI7nu61Co6j9j5+vpqbUSsUqng5eWl68sREdF9buzESdiyX47A56Pg2Kc/Ap+Pwpb9chZ1ZPLkcjmGhIagh6cbhoSGtIq9kXUesduxYwcWLVqExMRE9OnTB8DNhRTTpk3DnDlzMGrUKEPkNBiO2BEREZGuYqLGIXnHNkz3c25szL0quxhhI0cjMWmzXt/LoO1OnJycoFQqUV9fD0vL
m2svGv5sb2/f5LHFxcU6Rjc+FnZERC2nUCiwdkU8Lp4/hw5dumLyjDjIZDKhYxEZhFwuR9yI4djQrfmioQln8hC/c7det1UzaGG3adOmu35sZGSkLi8tCBZ2REQts2X9OmxZ+j7G2qkaRy62KC0wds4bvB1LJmlIaAii64u0tvlJKVVik5WrXrdZY4NiHbCwIyK6dwqFApMfHoQ4R2BHQTlyq+vgbWOFkW5SxJcCa3/bz5E7Mjk9PN2wVia9bWPulxXlSLtaoLf3M3iDYpVKhe3btzc2KO7WrRueeuopWFg0v8DWinvFEhG13NoV8ehUU45ll2oQ7ePUOGK37FIRAu2ssXZFPBatWCl0TCK9cvXyRpZS+4hdZmUNXL28BUh1k84jdllZWXj88ceRm5uLwMBAAEB6ejp8fX3x448/omPHjgYJaigcsSMiuneP9AtDXeZZrOzq2WyuUey5q7DqFIRfDycLmJBI/1rzHDud251Mnz4dHTt2RE5ODlJTU5Gamors7Gy0b98e06dPv+fQRER0/7mSfRnjfZya/HADALFIhCjvtriSfVmgZESGExERgbCRozHhTB5SSpUoq1chpVSJCWfyEDZytF6LOl3pfCv2wIEDOHz4MJydnRuPtWvXDkuWLMHAgQP1Go6IiFo51c0FE9p0srcG8quNHIjIOBKTNkMun4z5sdNQqMiFq5c34nd+JWhRB9zDiJ21tTXKy8ubHa+oqIBEov0fNxERmSYPH19kKWu1nsusrIGHj6+REwmjNTaqJcOLiIjA3pTjSLtagL0pxwUv6oB7KOyefPJJTJ48GUeOHIFGo4FGo8Hhw4cxZcoUPPXUU4bISERErdTClauxKrsY6r9N11ZrNEjIKcHClasFSmY8MVHjEDdiOKLri7BWJkV0fRHiRgxHTNQ4oaORGdK5sFu1ahU6duyI/v37w8bGBjY2Nhg4cCACAgKwciVXPhERmZPWPNfIGORyOZJ3bMOGbl4IcbCF1NICIQ622NDNC8k7tnHkjozunvvYZWVlNbY76dq1KwICAvQazFi4KpaIqOXkcvnNuUZ5N+caLVy52uSLOsD4jWrJPBm8jx0ABAQE3LfFHBER6VfDXCNzU5iXiwCZVOu5TvbWKFTkGjkRmTudb8USERHRTTcb1d5+8YiQjWrJPLGwIyIiukdcPEKtDQs7IiKie2Tui0eo9bnnOXZERETUehvVknnSubDbuHEj2rRpgzFjxjQ5vnXrViiVSkRGRuotHBER0f3AXBePUOuj863YxYsXw8XFpdlxNzc3LFq0SC+hiIiIiEh3Ohd22dnZaN++fbPj/v7+yM7O1ksoXdy4cQN9+vRBr1690L17d6xbt87oGYiIiIhaA50LOzc3N5w6darZ8ZMnT6Jdu3Z6CaULqVSKgwcP4sSJEzhy5AgWLVqE69evGz0HERGROeI+ua2LzoXdCy+8gOnTp2Pfvn1QqVRQqVTYu3cvYmNj8fzzzxsi4x1ZWFjAzs4OAFBTU9O4fy0RkbHxBxyZG+6T2/roXNi9++67CA8Px9ChQ2FrawtbW1s88sgjGDJkyD3NsTt48CBGjBgBLy8viEQibN++vdljEhMTIZPJYGNjg/DwcCQnJzc5f+PGDQQHB8PHxwezZ8/WOgeQiMiQ+AOOzA33yW2ddC7sJBIJvv76a5w/fx6ff/45vvvuO1y4cAGffvopJBKJzgEqKysRHByMxMREree//vprzJw5EwsWLEBqaiqCg4MxfPhwFBQUND6mbdu2OHnyJC5duoQvvvgC+fn5OucgIrpX/AFH5mh+7DRM93OGWCRqclwsEmGqrxPmx04TKJl5E2la0X1LkUiE77//HqNGjWo8Fh4ejr59+yIhIQEAoFar4evri2nTpmHu3LnNXuOVV17BkCFD8Oyzz2p9j5qaGtTU1DR+XlZWBl9f37vaWJeISBtuBE/mqIenG9bKpJBaWjQ7V1avwsuKcqRdLdDyTNJVWVkZHB0d76pW
uas+djNnzsS7774Le3t7zJw5846P/eijj+4+6T+ora1FSkoK5s2b13hMLBZj2LBhOHToEAAgPz8fdnZ2kEqlKC0txcGDB/Hvf//7tq+5ePFivPPOO3rLSETEjeDJHN3cJ1f7LzTcJ1c4d1XYHT9+HHV1dY1/NpaioiKoVCq4u7s3Oe7u7o7z588DAC5fvozJkyc3LpqYNm0aevTocdvXnDdvXpPitGHEjojoXvEHHJmjhStXI27EcGzo5tXkdmzDPrnxO78SMJ35uqvCbt++fVr/3BqEhYXhxIkTd/14a2trWFtbGy4QEZkd/oAjc9S4T+6ObZjq64RO9tbIrKxBQk4J98kVkM6LJ8aPH4/y8vJmxysrKzF+/Hi9hGrg4uICCwuLZosh8vPz4eHh0aLXTkxMRFBQEPr27dui1yEiavgB99KpK1iQeQ2zz1/FgsxreOnUFf6AI5OWmLQZ8Tt3Y5OVK15WlGOTlSvid+5GYtJmoaOZLZ0Lu02bNqGqqqrZ8aqqKmzerN8vpEQiQWhoKPbs2dN4TK1WY8+ePejfv3+LXjsmJgZnz57F0aNHWxqTiAgAYCUWYZS7I+YHuGGUuyOsxKJ/fhLRfa5hn9y0qwXYm3Kcv8gI7K5uxQI356I1zGMrLy+HjY1N4zmVSoWffvoJbm5uOgeoqKhAVlZW4+eXLl3CiRMn4OzsDD8/P8ycORORkZHo06cPwsLCsGLFClRWViI6Olrn9yIiMoSGdicbu3s33ooNcbDFxu7emLBjG+TyyfxhR0RGcdftTsRiMUSi2//2KRKJ8M477+CNN97QKcD+/fsxePDgZscjIyORlJQEAEhISMCyZctw7do19OrVC6tWrUJ4eLhO73M7uiwhJiLShu1OiMiQdKlV7rqwO3DgADQaDYYMGYJt27bB2dm58ZxEIoG/vz+8vLxaltyIEhMTkZiYCJVKhYyMDBZ2RHTP2M+LiAzJIIVdg8uXL8PPz++Oo3f3E47YEVFL9e3SGdOtq7SO2B0rVSKx1g5Hz2cIkIyITIEutYrOiyf8/f0hl8vx0ksvYcCAAcjNvdl4c8uWLdw2h4jMkgYarMspRk5VLRIuX8e89GtIuHwdOVW12HClBBq0mg1+iMjE6VzYbdu2DcOHD4etrS1SU1Mbt+cqLS3FokWL9B6QiKi1qykthbvEEvMzCzDQyQ6vd3TFQCc7zM8sgIfEEjWlpUJHNDi5XI4hoSHo4emGIaEh/EWfSCA6F3bvvfce1qxZg3Xr1sHKyqrx+MCBA5GamqrXcIbEPnZEpC9t2rkgv7YeG3p4I8TBFlJLC4Q42GJDD29cq61Hm3YuQkc0qJiocYgbMRzR9UVYK5Miur4IcSOGIyZqnNDRiMyOzoVdeno6HnzwwWbHHR0dcePGDX1kMgr2sSMiffGVtcckX+cmu04AgFgkwgQfJ/jK2guUzPAaWr1s6ObVtKjt5oXkHds4ckdkZDoXdh4eHk36zjWQy+Xo0KGDXkIREd1PFOfOIMBOovVcZ3trXDp31siJjGd+7DRM99Ne1E71dcL82GkCJSMyTzoXdpMmTUJsbCyOHDkCkUiEvLw8fP7555g1axb+/e9/GyIjEVGrptYAWcparecyK2tMevFEYV7ubYvaTvbWKMzLNXIiIvN21ztPNJg7dy7UajWGDh0KpVKJBx98ENbW1pg1axamTeNvZkRkfsIGDca67d8gIciryciVWqPBhislCBv1LwHTGZarlzeylNqbM2dW1sDVy1uAVETmS+c+dg1qa2uRlZWFiooKBAUFoU2bNvrOZlBsUExE+qJQKDC6Xx841lUjytsJAfYSZFXWIim3BKVWNth2+BhkMpnQMQ1CLpcjbsRwbOjWvKidcCYP8Tt3m8V2anK5HPNjp6EwLxeuXt5YuHK1WVw3GYdBGxSbGjYoJiJ92LJ+HVa9/RZUZaWoUqthKxbDwsER099+F2MnThI6nkHFRI1D8o5tmOrrhE721sisrEFCTgnC
Ro5GYtJmoeMZXMP1T/dzRoCdBFnKWqzKLjab6yfD06VW0flWbHV1NVavXo19+/ahoKAAarW6yfn7qeUJEZE+OUisENXepfGHe1K50ImMIzFpM+TyyTdHrBQ3R6zid35lFiNWt64KbhixbFgVPGHHNsjlk83i74FaD51H7F588UX8+uuvePbZZ+Hu7t5sa7EFCxboNaChccSOiFpKoVBg8sODsNzdqtntyFn5dVj7236TvRVr7oaEhiC6Xvscw5RSJTZZuWJvynEBkpEpMeiI3a5du/DTTz9h4MCB9xyQiMiUrF0Rj7F2KohFTVeHikUivGSnxtoV8Vi0YqVA6ciQCvNyESCTaj3Xyd4ahQquCibj0rndibe3N6RS7d/E9xPuPEFE+nLx/LnbtvwIsLPCxfRzRk5ExnJzVfDtW91wVTAZm86F3Ycffog5c+bg8uXLhshjNNx5goj0pUOXrrf94Z6lrEOHwK5GTkTGsnDlaqzKLob6b7Oa1BoNEnJKsHDlaoGSkbnSubDr06cPqqur0aFDB0ilUjg7Ozf5ICIyN5NnxGGL0kLrD/fPlGJMnhEnUDIytIiICISNHI0JZ/KQUqpEWb0KKaVKTDiTh7CRo7lwgoxO5zl2L7zwAnJzc7Fo0SKtiyeIiMyNTCbD2DlvYNbS9/GSnRoBdlbIUtbhM6UYY+e8wYUTJs6cVwVT66Pzqlg7OzscOnQIwcHBhspkVFwVS0T6olAosHZFPC6mn0OHwK6YPCOORR0RtZhBV8V26dIFVVVV9xyOiMhUyWQyrn4lIkHpPMduyZIlePXVV7F//35cv34dZWVlTT6IiIiISBg634oVi2/Wgn+fW6fRaCASiaBSqfSXzoC4VywRERHdDwy6V+yBAwfueP6hhx7S5eUExzl2RERE1JoZdI7d/Va4EREREZkLnQs7ALhx4waSk5NRUFAAtVrd5Ny4ceP0EoyIiIiIdKNzYbdz5068+OKLqKiogIODQ5O5diKRiIUdERERkUB0XhX76quvYvz48aioqMCNGzdQUlLS+FFcXGyIjERERER0F3Qu7HJzczF9+nTY2dkZIg8RERER3SOdC7vhw4fj2LFjhshCRERERC1wV3Psfvjhh8Y/P/HEE5g9ezbOnj2LHj16wMrKqsljn3rqKf0mJCIiIqK7cld97BqaEv/ji7FBMRGRWZLL5ZgfOw2Feblw9fLGwpWrERERIXQsIpNg0AbFpoYNiolIXxQKBdauiMfF8+fQoUtXTJ4RB5lMJnQsg4uJGofkHdsw3c8ZAXYSZClrsSq7GGEjRyMxabPQ8Yjue7rUKjrPsdu8eTNqamqaHa+trcXmzfwHTETmacv6dZgw5AF03fcdYksvoOu+7zBhyAPYsn6d0NEMSi6XI3nHNmzo5oUQB1tILS0Q4mCLDd28kLxjG+RyudARicyKziN2FhYWuHr1Ktzc3Jocv379Otzc3O6bW7ENOGJHRC2lUCgwLiIcCTIHiG/p7anWaDBVUYbN8iMmO3I3JDQE0fVFCHGwbXYupVSJTVau2JtyXIBkRKbDoCN2Go2mSVPiBleuXIGjo6OuL0dEdN/78N13MKmdpElRBwBikQjjna3w4bvvCJTM8ArzchFgJ9F6rpO9NQrzco2ciMi83fXOEyEhIRCJRBCJRBg6dCgsLf96qkqlwqVLl/Doo48aJCQRUWt2ZP8+/J+b9uKms701EvfvN24gI3L18kaWUvuIXWZlDVy9vAVIRWS+7rqwGzVqFADgxIkTGD58ONq0adN4TiKRQCaTYfTo0XoPSETU2qk1QJay9rbFjebetuW+LyxcuRpxI4ZjQzevZrehE3JKEL/zKwHTEZmfu/7fZsGCBQAAmUyG5557DjY2NgYLRUT3L3NsexE2aDDWbf8GCUHNi5sNV0oQNupfAqYzrIiICISNHI0JO7Zhqq8TOtlbI7OyBgk5JQgbOdrkv/ZErQ3bnXDxBJHemGvbC4VCgdH9+sCxrhpR3k4I
sJcgq7IWSbklKLWywbbDx0x28UQDcyzoiYxF733snJ2dkZGRARcXFzg5OWldPNGguLhY98QCYmFHpB9yufy2t+QmnMlD/M7dJv2Dfsv6dVj73jtwqy6HUqWGnYUYBTZSTH5zAcZOnCR0PCK6j+m9sNu0aROef/55WFtbY9OmTXd8bGRkpG5pBcbCjkg/7tT24lipEpvNoO1FY4Pi9HPoEGg+DYqJyLB0qVXuao5dQ7FWX18PkUiE4cOHw93dveVJichkXL54AQFdXLWe62xvjcvnLxg5kfHJZDIsWrFS6BhEZMZ06mNnaWmJKVOmoLq62lB5jCYxMRFBQUHo27ev0FGITEJVVRWylLVaz2VW1qCqqsrIiYiIzI/ODYrDwsJw/Pj9fzslJiYGZ8+exdGjR4WOQmQSRAA+vVIC9d9md6g1GiTl3sDtZ+YSEZG+6Nxc6ZVXXsGrr76KK1euIDQ0FPb29k3O9+zZU2/hiOj+YWVtjSB7a8w4d7XZytAudhJcqhc6IRGR6dO53YlY3HyQTyQSNW41xr1iicxTYmIiPn19Ft7r5I4dBeXIra6Dt40VRrpJ8WZmPsYvWo6YmBihYxIR3Xf0vir2VpcvX77jeX9/f11eTnAs7Ij0p2/nAKjzczHdv11jo9pVl69D7O6NoxlZQscjIrov6X1V7K3ut8KNiIznaEYWEhMTseD1OdDU1UFkZYXZHKkjIjKae9554uzZs8jOzkZtbdNVcE899ZReghkLR+yIiIioNTPoiN3Fixfx9NNPIy0trXFuHYDG3Sjutzl2RERERKZC53YnsbGxaN++PQoKCmBnZ4czZ87g4MGD6NOnD/bv32+AiERERER0N3QesTt06BD27t0LFxcXiMViiMViREREYPHixZg+fbpJ9LgjIiIiuh/pPGKnUqkglUoBAC4uLsjLywNwc1FFenq6ftMREd1H5HI5hoSGoIenG4aEhkAulwsdiYjMjM4jdt27d8fJkyfRvn17hIeH44MPPoBEIsHatWvRoUMHQ2QkImr1YqLGIXnHNkz3c0aATIosZRHiRgxH2MjRSEzaLHQ8IjITOo/Yvfnmm1Cr1QCAhQsX4tKlS3jggQfw008/YdWqVXoPSETU2snlciTv2IYN3bwQ4mALqaUFQhxssaGbF5J3bOPIHREZzT23O7lVcXExnJycGlfG3k/Y7oSIWmpIaAii64sQ4mDb7FxKqRKbrFyxN4Xzj4no3hi03Yk2zs7O+ngZIjIBCoUCa1fE4+L5c+jQpSsmz4iDTCYTOpZBFeblIkAm1Xquk701ChW5Rk5EROZK51uxrU1OTg4GDRqEoKAg9OzZE1u3bhU6EpHZ2rJ+HSY/PAhd932H2NIL6LrvO0x+eBC2rF8ndDSDcvXyRpayVuu5zMoauHp5GzkREZkrvdyKFdLVq1eRn5+PXr164dq1awgNDUVGRgbs7e3v6vm8FUukHwqFApMfHoTl7lYQ3zItQ63RYFZ+Hdb+tt9kR+7kcjniRgzHhm5eza59wpk8xO/cjYiICAETEtH9TJda5b4fsfP09ESvXr0AAB4eHnBxcUFxcbGwoYjM0NoV8Rhrp2pS2ACAWCTCS3ZqrF0RL1Ayw4uIiEDYyNGYcCYPKaVKlNWrkFKqxIQzeQgbOZpFHREZzV0Vdr1790ZJSQmAmythlUql3gIcPHgQI0aMgJeXF0QiEbZv397sMYmJiZDJZLCxsUF4eDiSk5O1vlZKSgpUKhV8fX31lo+I7s7F8+cQYCfRei7AzgoX088ZOZFxJSZtRvzO3dhk5YqXFeXYZOWK+J272eqEiIzqrgq7c+fOobKyEgDwzjvvoKKiQm8BKisrERwcjMTERK3nv/76a8ycORMLFixAamoqgoODMXz4cBQUFDR5XHFxMcaNG4e1a9fqLRsR3b0OXbredp5ZlrIOHQK7GjmR8UVERGBvynGkXS3A3pTjHKkjIqO7qzl2/fv3R5s2bRAREYF33nkHs2bNQps2
bbQ+dv78+fceRiTC999/j1GjRjUeCw8PR9++fZGQkAAAUKvV8PX1xbRp0zB37lwAQE1NDR5++GFMmjQJY8eOveN71NTUoKampvHzsrIy+Pr6co4dUQuZ8xw7IjJPCoUCKxPWID3zAgI7dUTs1CkG+X9O7+1OkpKSsGDBAuzatQsikQg///wzLC2bP1UkErWosPu72tpapKSkYN68eY3HxGIxhg0bhkOHDgEANBoNoqKiMGTIkH8s6gBg8eLFeOedd/SWkYhukslkGDvnDUx5Zz4865RQqtSwsxDjqpUd/r1gIYs6IjIp6zduwvKEjXDq8SSknfvgcKECP4+Owqyp0ZgYHSlYrrsq7AIDA/HVV18BuFlY7dmzB25ubgYNBgBFRUVQqVRwd3dvctzd3R3nz58HAPz3v//F119/jZ49ezbOz9uyZQt69Oih9TXnzZuHmTNnNn7eMGJHRC13WP4HasvLMMrfGQF2EmQpa7HycjEOy//A2ImThI5HRKQXCoUCyxM2osNjsyES3ZzV5uzbDU4+XbE8YRmGDX5IsF9mdW5Q3LCdWGsRERGhUyZra2tYW1sbMBGReWrYVuvT7n+1/AhxsMWn3b0wYcc2yOWTOeeMiEzCyoQ1cOrxZGNR10AkEsO5x5NYmbAG8cuXCJLtnnaeuHDhAlasWIFz526ucgsKCkJsbCw6duyo13AuLi6wsLBAfn5+k+P5+fnw8PBo0WsnJiYiMTERKpWqRa9DRDfNj52G6X7OWtudTPV1wvzYadxWi4hMQnrmBUg794HyRj6unPgVyhvXYNfWAz69HkEbV3+kZxwRLJvOfex2796NoKAgJCcno2fPnujZsyeOHDmCbt264bffftNrOIlEgtDQUOzZs6fxmFqtxp49e9C/f/8WvXZMTAzOnj2Lo0ePtjQmEeF/22rdpt1JJ3trFOZxWy0iMg2BnTri0pHvcXb3x3Dp2BvdHn0FLh174+zuj6E4sh2BnfQ70KULnUfs5s6di7i4OCxZsqTZ8Tlz5uDhhx/W6fUqKiqQlZXV+PmlS5dw4sQJODs7w8/PDzNnzkRkZCT69OmDsLAwrFixApWVlYiOjtY1OhEZ0M1ttYoQ4mDb7By31SIiUzJ61JPYvHUKwscubTrH7l9dcWTLHIxe9G/Bsum8pZiNjQ3S0tLQqVOnJsczMjLQs2dPVFdX6xRg//79GDx4cLPjkZGRSEpKAgAkJCRg2bJluHbtGnr16oVVq1YhPDxcp/f5u1tvxWZkZLDdCVELcVstIjIXE16Owdma9nD27dbs3PXs0+hmo8CGT7T3570XurQ70bmw8/X1xUcffYQxY8Y0Of7NN99g1qxZyM7O1j2xgLhXLJH+xESNQ/KObZjq64RO9tbIrKxBQk4JwkaO5g4MRGQyOnXrjfaPvw4rm+b70tdVV+DST4uReSZVb++n9z52t5o0aRImT56MixcvYsCAAQButhxZunRpkzYiRGR+EpM2Qy6fjPmx01CoyIWrlzfid37FkToiMikajRrlhQqtI3ZlBQpoNMJ1ENF5xE6j0WDFihX48MMPkZeXBwDw8vLC7NmzMX36dIj+tiKuteOIHRER0b0z1u4LrcnQRx5D2sUi9Hn+7SYtTzQaNY599TZ6dHDBnl9/1tv7GfRW7K3Ky8sBAFKp9F5fQnAs7IiIqKXMsbgBbu6+8N6yBKis20FVWw0LiQ0saq7jzdlTBd19wdA8fDqgWm0Fa6kzOvQbDambP8oLLuPi4W2oKS+GjbgO165c1Nv76VKr6Nzu5FZSqfS+LeoSExMRFBSEvn37Ch2FyKTI5XIMCQ1BD083DAkNgVwuFzoSkUGt37gJj46OwuEiF9R3fg6Hi1zw6OgorN+4SehoBqVQKDBvwWJUqSzhEzwM3R+fCp/gYahSWWLegsVQKBRCRzSYqtp6+IY8ijplKc7v3YgjW+bi/N6NqFOWwjdkOKpq6wTL1qIRO1PAETsi/WlYPDHd768txVZlF3Px
BJkshUKBR0dHNdlaCrh5S+7iz8vwy7Ykkx25+9f/jcX+5PO3vR05KKwLvvlii4AJDcdHFghbvz4ozctAx4jnIHWVobxQgQvyr+Ho1RlV2cdwRZGut/cz2ogdEWmnUCjw+oxYPP/oI3h9RqxJ/+baoGFLsQ3dvBDiYAuppQVCHGyxoZsXknds48gdmaS72VrKVO07IEfHiOdQVVqIjP1bcGL7MmTs34Kq0kJ0GPgv7Dtguv/m582ejuuXTqDP82/D2bcbrGzs4ezbDX2efxvXL53AvNnTBcvGwo5Iz7asX4fJDw9C133fIbb0Arru+w6THx6ELevXCR3NoO5mSzEiU5OeeQFSV5nWc21c/ZGeecG4gYyoXqVBRWEOzv36SZPdF879+gkqi3JQr2pde8vr08E/kxE4NFprQd95SCQO/pksUDIzLuw4x44MQaFQYMvS97Hc3arJqNVydytsWfq+SY/ccUsxMkeBnTqivFCh9VxF4WVBt5YyNOe2UuSn/4kuD09C0YVUnPnlPyi6kIouD09CfvohOLc13elNv+89eNuC3sGtPX7fe9C4gW5xT4Xd1KlTUVxcrO8sRsW9YskQ1q6Ix1g7ldZRqxdtVVi7Il6gZIZ3c0uxWq3nuKUYmarYqVNQkrYLlSVXm9yOrCy5iuK0XYidOkXoiAYjkUjg6NUZ539b12TE7vxv6+Do1QkSifZf9ExBdXXVbQv6sgIFqqurjBvoFndd2F25cqXxz1988QUqKioAAD169EBOTo7+kxHdh86eOH6HUSsJzp44buRExrNw5Wqsyi6G+m/rsdQaDRJySrBw5WqBkhEZjkwmQ7+QLji9K75JcXN6Vzz6hXQx2YUTAFBcfANl1y5oHbEru3YRxcU3hI5oMBq1ChfkXzdrRKzRqHHxv99Ao1YJlEyHnSe6dOmCdu3aYeDAgaiurkZOTg78/PygUChQVyfcsl6i1qSsvBxZ9bUIcbBtdi6zsgZlteUCpDKOiIgIhI0cjQm32VKMu0+QKVIoFNj353GEvbSkyWbwYS8twb7vFkKhUJhscWdlaQlbj444/9s6tO8/unFl6Pnf1sHBoyOqskuEjmgwft7eyMm/ij83zkQbNxnUtdUQS2xQUaBAXXUl/LyFu0Nx1yN2N27cwNatWxEaGgq1Wo3HH38cnTt3Rk1NDXbv3o38/HxD5iS6L9i3aYNPr5RoHbVKyr0B+zZtBEpmHIlJmxG/czc2WbniZUU5Nlm5In7nbrY6IZP17uJl8Aobo3USvWffZ/Hu4mUCJTO8NlI7lF27gN5j3myyMrT3mDdRdu0C2kjthI5oMNNiJkFsYQWJXVv4Bj+M7o9Pg2/ww5DYtYXYwgrTYiYJlu2u+9hVVVXB1vbmKISTkxNSUlJw9epVDBs2DN27d8eZM2fg6+uL9HT99W0xBvaxI316fUYsSr/7DOcqaxDl7YQAewmyKmuRlFuCLvbWaPvMS1i0YqXQMYlIT4y9GXxr0tbVB12fmKF1v9Tr2adx/qeVuFF4Rcsz7389Q8JxrVx92x5+HlIxTh0/orf3M0gfu7Zt2yI8PBwzZ85EbW0tqqqqMHDgQFhaWuLrr79GSUkJNmzY0OLwxsJVsWQIk2fEIdNaitntXfDnDSUWXyjEnzeUmN3eBVnWUkyeESd0RCLSo4bN4LURejN4Q6tTq++wMlSGOhNud5KdexUdI57TOlLbYeC/kJ17VaBkOhR2ubm5ePPNN2FtbY36+nqEhobigQceQG1tLVJTUyESie6rOTRcFUuGIJPJMHbOG4gvBQY42WNeR1cMcLJHfCkwds4bJjvXhkihUCBu1lw8PnIM4mbNNenWPrd6cGD/O06if3Bgf4GSGZ6FSHTHotZCLNJ6zhRIJFZ3LGolEivjBrrFXRd2Li4uGDFiBBYvXgw7OzscPXoU06ZNg0gkwqxZs+Do6IiHHnrIkFmJ7gtjJ07C2t/24/zgp7GqbUecH/w01v62H2MnCjfngsiQ1m/c
hEGPj8G3e07gRFYhvt1zAoMeH2Pye6UCQGiv7qgqv45jXy/E9ezTqKuuwPXs0zj29UJUlV9HaK/uQkc0mAcGhCFj3yatrV4y92/GAwPChI5oMIMfirhDUXsJgx8SbqDrnvaKdXJywsmTJ+Hn5wepVIqTJ0/Czs4OBw4cwHPPPWeInAbDOXZEpC8KhQIrE9YgPfMCAjt1ROzUKSY/SqtQKND3wUchtnVCh1tWRl48tA3qqhIcPfiLSf8d+MgCobJzR1VpPqRu7aGqrYaFxAblBZdg6+gOC2W+XvcMbU0UCgW6hfSHtbQdAodENX7t0/cmoab8Os4cP2SyX3uFQoE+g0Y0WQ0N3BypTd4yB8cO7NLrtRt8r9hTp07Bx8cHAODv7w8rKyt4eHjcd0UdEZG+rN+4CY+OjsLhIhfUd34Oh4tc8OjoKJMftXrt9bcgkkjR9W+9zLo+PAkiiRSvvf6W0BEN6kZZGdT1tRgQ/RF6PhmLkGfmoOeTsRgQ/RHU9bW4UVYmdESDuXLlCmwc2qHfuKVNVsX2G7cUNg7tmvS/NTVXrlxBXbUSKVvfQ3H2GdRVV6A4+wxStr6HupoqQa/9nkbsTAlH7MgQFAoF1q6Ix8Xz59ChS1dMnhFnsr+50s2v96Ojo9DhsdnNfnu/+PMy/LItyWS//q7e7dG24wCU519s0svs0qFtkLp3wI0Lf6Iw95LQMQ2mjbMnejz16m1Xhp7e+REqioWbSG9Iffs/BMvAEbe9dlXGLhw9dECAZIbXcO02UhdcOfErlDeuwa6tB3x6PYKqskK9X7vBR+xMAVfFkqFsWb8OYwdFIP2rJJSmHEL6V0kYOygCW9avEzoaGcjKhDVw6vGk1hVyzj2exMqENQIlM7ya2ro79jKrqdW+zZypsLG1ueMkehsbG+MGMqKcvKt3vPacPNMsaAEg71oBpK4y2LV1R+dBY9Fr1Gx0HjQWdm3d4eAmQ961AsGymW1hx1WxZAgKhQIr3pwHq/IbGOXugPkBbhjl7gCr8htY8eY8s1kpaG7SMy/c9gdcG1d/pGdeMG4gI7KxlqDDgGe1FrXt+4+GjbW1QMmMo72f3x0n0bf39zNuICOy/IdVsZYi010V69Ku7R2v3aVdW6PmuZXZFnZEhvDBwndgX1uFFV09EeJgC6mlBUIcbLGiqyfsa6vwwcJ3hI5IBhDYqeNt/5OvKLyMwE4djRvIiGxsbP9hxKr59nqm5LWZU5G+N0nrytCMfZvx2sypQkc0GB8/H2Qd/EJrq5cLf3wJXz8fgZIZXueADsjYt0nrtWfu34zOAR0ESsbCjkivkvfvwyRfZ4j/9puqWCTCBB8nJB/YJ1AyMqTYqVNQkrZL63/yxWm7EDt1ikDJDE8isbrjyIWQ/byM4c8jKQBEOPXDR3Dp2BvdHn0FLh1749QPH0EE0f/Om6agwM6oq67UvoCguhJdAzsLHdFgKqtqYdfOF0e2zG3S5ubIlrmwa+eDyirhpiBYCvbORCZIpVYhwE6i9Vwne2uoykx7vpG5kslkmDU1Gos/fBe1Vm2hqq2ChcQWkrobmPfqv0124QRws0HvD79/hrAX32+2cCTr4Od4apjpNugFgOMnT8HKxh6h/3qr8fobVoamfPMujp88JXBCw6lUKtH1kUmNCwiyU3+CXVsPBD3yMqrKClCpPCN0RIMJ7NQRJU4u0GiAC398iVplKSR2jug0KBJiERDoUiRYNhZ2RHpk7+CALGUZQhya337KrKyBvYOjAKnIWMSWVvAJHta4MrToxA6hIxnc+MgXsW3nr0jZ+h469BsNqZs/ygsu4+LhbairKsf4yBeFjmhQlRWVd5xjWJmxS6Bkhpd6Ig3tHx8BKxt7dB40tsk5Kxt7pP70lUDJDC926hT8/L+V8O1efK/xeMNK+NhtSYJlY2FHpEdOTs749PRlrOxq0+R2rFqjQVLuDTh1by9gOjIUhUKB5Qkb0enJuU1GbZx8umJ5wjIMG/yQyY7abdu+
C0GP/hvV5cU488t/oK6vhdhSgo4Rz8OmjRO2bd91X203qauaulo43WGO4bUzpjtK37BPrrZ2J6a+T27DKP3yhGVw7vEk2rj6o6LwMorTdmHW1GhB/72zsCPSox4hvVF66SxePp0LLxsrKFVq2FmIkVddhxAHW7QN6S10RDKAhnYnVaWFzXpaNbQ7iV++ROiYBpGeeQFXck5BeT0H3R57pXG0MmPfJti180G6bxuhIxpUTvYV2N+huMnJNt0mvQ8O7I9d+76G0/NvN+/f+N9v8ORg074NPzE6EsMGP3Rzt5mMIwjv1BGxraBnJRsUs0Ex6ZFCocDofn3QprYKE32dEWAnQZayFutzilEhscW2w8cE/0dP+vf4yDHIqfdCYVZysya9rgF94Wt5FT/t2Cp0TIN44qmncfhkFsLHNt9a6ciWuegXHIAff/hewISGJW3nBbt2fk3m2AE3rz/lm3ehvH4Z5ddNs5+bQqHAwKFPod7SvtlteMv6Svx3zw/8/05P2KD4LrBBMRmKlVqF1UFeTdqdrA7ygpVaJXQ0MhAPNxfkp/+ptUlvfvoheLi5CB3RYI6mnkTnwZFa55h1GjQOR1NPCpTMOBylUjh4dETq31aGpm59Dw4eHeEoNd0BA5lMhnfefBVtrDS4cup3nP4pAVdO/Y42Vhq88+arLOoEYraFHRsUkyF8+O47iPFso7XdyRQPe3z4LvvYmSKRWIyOEc9pLW46DPwXRGLT/a9WWVVzxz52yqoa4wYysvhl7+G64iS6PDwJRRdTceaXj1F0MRVdHp6E64qTiF/23j+/yH1sYnQk9uz6Cs8O7YVeAW54dmgv7Nn1FSZGRwodzWxxjh2RHh3YvRv/56u9IWtne2t8tHu3kRORMVy9VgBp58Fazzm4yXA1I9nIiYxHLMIdJ9CLTXfzAQDAmDFjsPI/63Dqh4/QeXAk2rvJUFagwKkfPkJQRy+MGTNG6IgGJ5PJTHYO6T9RKBQ359hlXkBgp46InTpF8JFK0/01kkgAJTdKkKXUvgous7IGJTdKjJyIjOFOO0+UFShMeueJBwaE3bED/wMDwgRKZjzyfb/ikw8X4MqB9TiyaSauHFiPTz5cAPm+X4WORga0fuMmPDo6CoeLXFDf+TkcLnLBo6OjsH7jJkFzsbAj0iOxSIRPr5RA/bc1SQ3tTv5+i9YUKRQKxM2ai8dHjkHcrLlmsT9uQHtfpO9N0lrcZOzbhID2vgIlM7zEVfGory7T2oG/vroMiavihY5oFGPGjMEVRTpKC3JwRZFuFiN15qyhxVGHx2Y3mVfb4bHZWJ6wUdD/97gqlqtiSY8ienZD8I1rOFdZgyhvJwTYS5BVWYuk3BIE2kmQ5uQJ+SnT7ca+fuMmvLcsASrrdlDVVsNCYgOLmut4c/ZUk55z4yMLhK1fH5TnX0T7W1YHXjq8DVL3DqjKPoYrinShYxrM+o2b8Pb7H6Kyph7quhqIraxhb22Jt9941aS/7mS+4mbNxeEiF61TEEpyziDcpUivt6d1qVVY2LGwIz2Sy+WY8eQjeDfADTsKypFbXQdvGyuMdJPirawCrNj1q8k2a1UoFOj74KMQ2zqhwy0tPy4e2gZ1VQmOHvxF8LknhuLo5ofwyI9QV13RrI+dlY09jmyaidKCHKFjGlRrnGtEZCiPjxyD+s7PwcrGvtm5uuoKWGZ8o9cWR2x3QiSQiIgIhI96FnMy8nFZbYViiQMuq60wJyMf4aOeNdmiDgBee/0tiCRShP6t5UfomDchkkjx2utvCR3RYKR2tnecYye1szNuIAE0TKD/acdWxC9fwqLOjGzduhU+skA4uvnBRxaIrVtNs2fjrSzEmjv8m78EC7FwY2ZcFUukZyEPDcXPqVm40fsZSF1luFGoQEXqdwh5aKjQ0Qxq3wE5Ah6OuW3Lj32//0egZIYXv+w9jJ86B3Zt3ZuMVp799RMob+Tj04SlQkckAzPXEcuIwY/g7IU8BA6Z1Ph9//Kr72Dl
f9aZ9OKRo0dToba5fJtdN7ZCUX1dsGwcsSPSo4YJtV2eeqPJqFWXp94QfEKtodWrNHfsZ1avMt19Iz09PWElsdU6WmklsYWnp6fQEcmAWuvqSEPbunUrzl7IQ79xS5t83/cbtxRnL+SZ9MhdVW093AMHaG1M7R7YH1W1dYJl4xw7zrEjPTL2hNrWpGvPPmjb+3nYSF2azTOrKitA6fFvcO7UMaFjGkTf/g/BMnCE1q/79ezTUGXswtFDBwRIRoamUCjw6OgodHhsdvORm5+X4ZdWsHeoofjIAuE7aNJtv++vHFhvsouGGq5d+/93hXq/ds6xIxJIeuaF245atXH1R3rmBeMGMqLuQYE4/fN/cPbXT+DSsTe6PfoKXDr2xtlfP8GZX9age1Cg0BENJu9awR1HK/OuFRg3EBnNyoQ1cOrxpNYpCM49nsTKhDUCJTO8cmXVHb/vy5VK4wYyovhl7yF9bxJsHV3RedBY9Bo1G50HjYWtoysy9m0SdMcRsy3suFcsGcKdGtVWFF426Ua1hddLILHVvnhCYitF4XXTbc7s5eF2x8UTXh5uxg0kAHPsXwiY9y9z5rxoaMyYMQjq6IXDm+c06d94ePMcwXccMdvCjnvFkiHETp2CkrRdWhvVFqftQuzUKQIlM7zKikp0HjRW+2bwD72EyopKgZIZXvyy95G5f7PWr3vWgS2IX/a+QMmMw1znmAHm/ctcw6jV7Rpzm/o+ua11xxHOseMcO9Kz9Rs3YXnCRjj3eBJtXP1RUXgZxWm7MGtqtEk3ax08fASsekbetq9T/anN2Lt7pwDJjKNhdWDnwZFw+N9+oRn7NiGoo5fg/9EbkjnPMQNuXn/YkJHo83+Lml3/sS9eR/LeHSZ7/QqFAsHhD8HSxgGdBo1r/L7P3L8Z9dVlOHnkgMleu7HpUquw3QmRnk2MjsSwwQ/dbH2QcQThnToi1sR/uAFArx7dcPg2m8GXFyjQr0fz46ZCoVCgqKwWnQeNw4U/vkKtshQSO0d0HjQORWd+gkKhMNmvf8Mcs+Kcc7jwx5eorSqFxNYRHR94oXGOmakuGAKAo0ePoqqyHClb30OHW3YduXh4G6oqy3H06FGT/toHPTYNGg1ufu3/933faVAkxCKY/Ne+teKIHUfsiPTCnEduzHk19OMjxyAtpxLK6znoPDiysZdZxr5NsGvngx6+bfTagb+1MfbqyNbE2LsvmDOO2BGR0clkMsyaGo3lCcu03oY21aIO+N8E+s59oLyR3+yHextXf6RnHBE6osGINPVQXs9B+NgljQW9s283hI9dgiNb5kLkEyBwQsNqWBlqZWOPzoPGNjlnZWNv0itDAzt1vO0ofUXhZYSb8PzC1sxsF08Qkf5NjI7EL9uSEO5SBMuMbxDuUoRftiWZ9NxC4OYPuEtHvse5v7V6OffrJ1Ac2W7SE+hTT6Sh8+BI7YtmBo1D6ok0gZIZhzmvDDXnxWKtGQs7IqIWGj3qSZRkp6H331q99B7zJoqz0zB61JNCRzQYZXXtHXuZKatrjBvIyMx5ZWjDKP3Fn5ehJOfm7gslOWdw8edlJj9K35qxsCMivTHXthfbtu9Cp0HjbtPqZSy2bd8lUDLDM+cRK6B19zMzBnMdpW/NOMeOyADMcUPwhn1yb1084ezbDU4+XbE8YRmGDX7IZP8OGubYaSN1kyE9I9nIiYwnftl7ePnVd9Bv3NJmi2Yy9m3CJx+a7ohVA/m+X7F161bEzX4T5UolpHZ2+OTD90y+qGsgk8lMdnHQ/YgjdkR6Zq6jVnfaWsmp+xMmvbWSOTepHTNmDFwcrLSOWLk4WJlNcTNmzBhcUaSjtCAHVxTpZnPd1PqwsCPSo1tHrW6da9XhsdlYnrDRpLdZOpF25rZzraRuMpxMO2PcQEZkzpPIFQoFxNaOjT38jmyZhwt/fIXOg8ZBbO1o0t/zRK0Rb8US6dHdbAhuqrcsKsrKYHmb1gdlBQqoysoESGUcDZPI31v+Ljz6
PAMHt/YoK7iEa8e+w5uz/m2yt6CBv77nnX27oZ1/jybnxGKxSX/PE7VGHLEj0iNz3hDcvo09Lv75rdZRq0uHtsG+TfMmpqYkafPnKCkpRu6pvTj9UyJyT+1FSUkxkjZ/LnQ0gzLn73mi1ogjdmQwcrkc82OnoTAvF65e3li4cjUiIiKEjmVQ5tywMyS4J3JrcnB402xoIIK6rhpiKxuIoIFL+xCEBPsKHdFgtm7dirMX8jBwfHyzBQSHN8/B1q1bTXbOlTl/zxO1RiYxYvf000/DyckJzz77rNBR6H9iosYhbsRwRNcXYa1Miuj6IsSNGI6YqHFCRzOo2KlTcOHgFq2jVlkHt5j0XKvYqVNQkC6HSGyBLkOjET52KboMjYZIbIH8dLlJX3vc7DcROCRK6y34zoMjETf7TYGSGZ45zy8kao1MorCLjY3F5s2bhY5B/yOXy5G8Yxs2dPNCiIMtpJYWCHGwxYZuXkjesQ1yuVzoiAZz64bgxdk3G3YWZ59Bytb3GjcEN1VHjx6FhbUU4WOXNFk4Ej52CSyspSZ97Q3bSmnj4CYz6W2l2KSWqHURaTQajdAh9GH//v1ISEjAt99+q9PzdNlYl+7OkNAQRNcXwU1iic+v10FRJ4LMSoMX21nhWk0dNlm5Ym/KcaFjGkTDhuAaDXDhjy9RW1UKia0jOj7wAgCNSW8I3nDt2m7JXc8+zWs30WtvYI69G4mMRZdaRfARu4MHD2LEiBHw8vKCSCTC9u3bmz0mMTERMpkMNjY2CA8PR3Ky6Tb7NAWFebk4V61BTJE9Mh+cAZvn45H54AzEFNkjvebmeVNVrqxCRWEOFIe3IeDBFxD+0hIEPPgCFIe3obIox6RHbsx51Mqct5Vq0NCk9qcdWxG/fAmLOiKBCF7YVVZWIjg4GImJiVrPf/3115g5cyYWLFiA1NRUBAcHY/jw4SgoKDByUrpbEgdHfFHjiM4vLm1yS67zi0vxeY0jJA6OQkc0GFuJJfLT/9S6Z2h++iHYSqyEjmgw5ry1lLlvK0VErYfgq2Ife+wxPPbYY7c9/9FHH2HSpEmIjo4GAKxZswY//vgjPv30U8ydO1fn96upqUFNzV+bUpeZcG8todRK7OHZ/yVUlRbiyolfobxxDXZtPeDT6xF4PhSJytQvhI5oMH379kaRYz+tk+g7DBwD1zLTHW02962lzH1bKSJqHQQfsbuT2tpapKSkYNiwYY3HxGIxhg0bhkOHDt3Tay5evBiOjo6NH76+ptuCQShl5ZWoKMzBuV8/gUvH3uj26Ctw6dgb5379BJVFV1BWXil0RINRqUV3uB3ZHiq1yLiBjIijVtxWioiE16oLu6KiIqhUKri7uzc57u7ujmvXrjV+PmzYMIwZMwY//fQTfHx87lj0zZs3D6WlpY0fOTk5Bstvrmysre54O9LG2nRvR5rznqHAzVGrTz5cgCsH1uPIppm4cmA9PvlwAeT7fhU6GhGRWRD8Vqw+/P7773f9WGtra1hbWxswDd2oUKLjw+Nvezvy0u//ESiZ4cVOnYKfR0fByadrs9uRxWm7ELstSbhwRjJmzBiOVBERCaRVj9i5uLjAwsIC+fn5TY7n5+fDw8OjRa+dmJiIoKAg9O3bt0Wv80/kcjn69RsIf1kn9Os30KR7uDWoqVXd8XZkTW29cQMZEXt63Wx7ETdrLh4fOQZxs+ZyE3giIiNq1YWdRCJBaGgo9uzZ03hMrVZjz5496N+/f4teOyYmBmfPnjVo09RJ4ydi9IsvQ9TlaQSOWQxRl6cx+sWXMWn8RIO9Z2vQxt7mjqsj29jbGjeQkU2MjsQv25IQ7lIEy4xvEO5ShF+2JWFidKTQ0Qxu/cZNeHR0FA4XuaC+83M4XOSCR0dHYf3GTUJHIyIyC4Lfiq2oqEBWVlbj55cuXcKJEyfg7OwMPz8/zJw5E5GRkejTpw/CwsKwYsUKVFZWNq6Sba3kcjl+2HMIoWM/aLwl5+zbDU5jP8APW15D
pFxusvumzps9A/OXfozw26yOfHfODOHCGUlDTy9zolAosDxhIzo8Nrvp97xPVyxPWIZhgx8yixFLIiIhCT5id+zYMYSEhCAkJAQAMHPmTISEhGD+/PkAgOeeew7Lly/H/Pnz0atXL5w4cQK//PJLswUVrc2sWXPQYXC09nlmgyIxa9YcgZIZXm5mBnrUFSD5b6sjkzfPQY+6AuRmZggdkQxgZcIaOPV4Uuv3vHOPJ7EyYY1AyYiIzIfgI3aDBg3CP+1qNnXqVEydOlWv75uYmIjExESoVCq9vm6Dq9cKEPiATOs5qVt7pB8w3QbLF8+fw2KZFBmVN7D8lw+hUFnA1UKFD90t0cleilXp54SOaHB/9TOrgtTOFvHLTL+fWXrmBUg799F6ro2rP9Izjhg5ERGR+RG8sBNKTEwMYmJiGvdf0zdPDzeUFyq07h1ZXnAJnh5uen/P1qJDl67I2ncOoY52+PJvf7WpZdXoEN5VmGBGEjH4EZy9kIfAIZMgdZWhvFCBl199Byv/s86k234EduqIw7f5nq8ovIxwE2/1QkTUGgh+K9ZULV++FBf3bdS6d+TF/ZuwfPlSgZIZ3uQZcdiitID6byOxao0GnynFmDwjTqBkhrd161acvZCHfuOabqfWb9xSnL2Qh61btwod0WBip05BSdourd/zxWm7EDt1ikDJiIjMBws7A4mIiMBTQ/sjZctrKM5OQ111BYqz05Cy5TU8NbS/yS6cAG4uHBg75w3Myq9Dalk1yupVSC2rxqz8Ooyd84ZJT6CPm/0mAodEaZ1n1nlwJOJmvylQMsNjqxciIuGJNP80wc1E3TrHLiMjA6WlpXBwcND7+8jlcsyaNQdXrxXA08MNy5cvNemi7lYKhQJrV8TjYvo5dAjsiskz4kz+h7ujmx/CIz+ClY19s3N11RU4smkmSgtMe7cThUKBlQlrkJ55AYGdOiJ26hST/7oTERlSw7Sxu6lVzLawa6DLXxbp5uYCgtdRUVmNNvY2iF+2yOQXEPjIAuE7aJLWeWbXs0/jyoH1uKJIFyAZERHdr3SpVXgrlgxi0KChmPLqO/Ad9DLCouLhO+hlTHn1HQwaNFToaAYVv+w9pO9N0jrPLGPfJsQve0+gZEREZA5Y2JHebd26FWkX8xH+twUE4eOWIu1ivkkvIBgzZgyCOnrh8N96+B3ePAdBHb1MfsSSiIiExVuxvBWrdz6yTvAd9PIdbkeuxRVFpgDJjOevPnZKSO3szKKPHRERGYYutYrZ9rEzdINic1ZRWQ2pq0zrOQc3GSoqq4wbSABjxoxhIUdEREZntrdiY2JicPbsWRw9etSg7yOXy9Gv30D4yzqhX7+BkMvlBn2/1qCNvQ3KCxVaz5UVKNDG3ta4gYiIiMyE2RZ2xjBp/ESMfvFliLo8jcAxiyHq8jRGv/gyJo2fKHQ0g4pftgjpe7U3Z87Ym4T4ZYsESkZERGTaWNgZiFwuxw97DiF07AdNFhCEjv0AP+w5ZNIjd3379oVIVYuUre+hOPvM/5ozn0HK1vcgUtWib9++QkckIiIySVw8YaDFE/36DYSoy9NaFxAUZ6dBk74Dhw//V2/v15rEzZqLw0UusJG64MqJX6G8cQ12bT3g0+sR1JQXIdylCPHLlwgdk4iI6L7AxRN3wdCLJ65eK0DgAzKt56Ru7ZF+oMAg79sapGdegLRzH1jZ2KPzoLFNzlnZ2CM944hAyYiIiEyb2d6KNfTiCU8Pt9suICgvuARPDzeDvG9rENip422vvaLwMgI7dTRuICIiIjNhtoWdoS1fvhQX92lfQHBx/yYsX75UoGSGFzt1CkrSdmm99uK0XYidOkWgZERERKaNhZ2BRERE4Kmh/ZGy5TUUZ6f9bwFBGlK2vIanhvZHRESE0BENRiaTYdbUaFz8eRlKcm4unijJOYOLPy/DrKnR3BCeiIjIQLh4wsA7T8jlcsyaNQdXrxXA08MNy5cvNemi7lYKhQIrE9YgPfMCAjt1
ROzUKSzqiIiIdKRLrcLCjluKERERUSumS63CW7FEREREJsJsC7vExEQEBQWxWS4RERGZDN6K5a1YIiIiasV4K5aIiIjIDLGwIyIiIjIRLOyIiIiITAQLOzIYuVyOfv0Gwl/WCf36DYRcLhc6EhERkUljYUcGMWn8RIx+8WWIujyNwDGLIeryNEa/+DImjZ8odDQiIiKTxcKO9E4ul+OHPYcQOvYDOPt2g5WNPZx9uyF07Af4Yc8hjtwREREZCAs70rtZs+agw+BoiERNv71EIjE6DIrErFlzBEpGRERk2sy2sGODYsO5eq0AUleZ1nNSt/a4eq3AuIGIiIjMhNkWdjExMTh79iyOHj0qdBST4+nhhvJChdZz5QWX4OnhZtxAREREZsJsCzsynOXLl+Livo3QaNRNjms0alzcvwnLly8VKBkREZFpY2FHehcREYGnhvZHypbXUJydhrrqChRnpyFly2t4amh/RERECB2RiIjIJFkKHYBM07pP1yNSLsesWXOQfqAAnh5u2Pb5JyzqiIiIDIiFHRlMREQEDh/+r9AxiIiIzAZvxRIRERGZCBZ2RERERCaChZ2BxcXFwd7RDfbOXrB3dENcXJzQkYiIiMhEcY6dAbm6eUNl44SeT78GqasM5YUKbNqWhM8+90ZhQa7Q8YiIiMjEcMTOQOLi4qCycUK/cUub7Jfab9xSqGycOHJHREREesfCzkDWfvo5AodEad0vtfPgSKz99HOBkhEREZGpMtvCzuB7xVpY3na/VAc3GSDmXXAiIiLSL7Mt7Ay+V6yq/rb7pZYVKAB1vWHel4iIiMyW2RZ2hjZ5/ItI35ukdb/UjH2bMHn8iwIlIyIiIlMl0mg0GqFDCKmsrAyOjo4oLS2Fg4ODXl+7YVVs58GRcHCToaxAgYx9m2BRXcJVsURERHRXdKlVONHLgAoLchEXF4e1ny67OadOXY/J419EfHy80NGIiIjIBHHEzoAjdkREREQtpUutwjl2RERERCaChR0RERGRiWBhR0RERGQiWNgRERERmQgWdkREREQmgoUdERERkYlgYUdERERkIljYEREREZkIFnZEREREJoKFHREREZGJYGFHREREZCIshQ4gtIatcsvKygROQkRERNRcQ43SULPcidkXduXl5QAAX19fgZMQERER3V55eTkcHR3v+BiR5m7KPxOmVquRl5cHqVQKkUhkkPcoKyuDr68vcnJy4ODgYJD3aK147bx2Xrv5MOdrB8z7+nnthr12jUaD8vJyeHl5QSy+8yw6sx+xE4vF8PHxMcp7OTg4mN03fANeO6/d3PDazfPaAfO+fl674a79n0bqGnDxBBEREZGJYGFHREREZCJY2BmBtbU1FixYAGtra6GjGB2vnddubnjt5nntgHlfP6+99Vy72S+eICIiIjIVHLEjIiIiMhEs7IiIiIhMBAs7IiIiIhPBwo6IiIjIRLCwM6DFixejb9++kEqlcHNzw6hRo5Ceni50LKP4+OOP0bNnz8aGjf3798fPP/8sdCyjW7JkCUQiEWbMmCF0FKN4++23IRKJmnx06dJF6FhGk5ubi5deegnt2rWDra0tevTogWPHjgkdy+BkMlmzr7tIJEJMTIzQ0QxOpVLhrbfeQvv27WFra4uOHTvi3Xffvas9PU1BeXk5ZsyYAX9/f9ja2mLAgAE4evSo0LEM4uDBgxgxYgS8vLwgEomwffv2Juc1Gg3mz58PT09P2NraYtiwYcjMzDR6ThZ2BnTgwAHExMTg8OHD+O2331BXV4dHHnkElZWVQkczOB8fHyxZsgQpKSk4duwYhgwZgpEjR+LMmTNCRzOao0eP4pNPPkHPnj2FjmJU3bp1w9WrVxs/5HK50JGMoqSkBAMHDoSVlRV+/vlnnD17Fh9++CGcnJyEjmZwR48ebfI1/+233wAAY8aMETiZ4S1duhQff/wxEhIScO7cOSxduhQffPABVq9eLXQ0o5g4cSJ+++03bNmyBWlpaXjk
kUcwbNgw5ObmCh1N7yorKxEcHIzExESt5z/44AOsWrUKa9aswZEjR2Bvb4/hw4ejurrauEE1ZDQFBQUaAJoDBw4IHUUQTk5OmvXr1wsdwyjKy8s1nTp10vz222+ahx56SBMbGyt0JKNYsGCBJjg4WOgYgpgzZ44mIiJC6BitQmxsrKZjx44atVotdBSDe+KJJzTjx49vcuyZZ57RvPjiiwIlMh6lUqmxsLDQ7Nq1q8nx3r17a9544w2BUhkHAM3333/f+LlardZ4eHholi1b1njsxo0bGmtra82XX35p1GwcsTOi0tJSAICzs7PASYxLpVLhq6++QmVlJfr37y90HKOIiYnBE088gWHDhgkdxegyMzPh5eWFDh064MUXX0R2drbQkYzihx9+QJ8+fTBmzBi4ubkhJCQE69atEzqW0dXW1uKzzz7D+PHjIRKJhI5jcAMGDMCePXuQkZEBADh58iTkcjkee+wxgZMZXn19PVQqFWxsbJoct7W1NZuR+gaXLl3CtWvXmvyf7+joiPDwcBw6dMioWSyN+m5mTK1WY8aMGRg4cCC6d+8udByjSEtLQ//+/VFdXY02bdrg+++/R1BQkNCxDO6rr75Camqqyc4zuZPw8HAkJSUhMDAQV69exTvvvIMHHngAp0+fhlQqFTqeQV28eBEff/wxZs6ciddffx1Hjx7F9OnTIZFIEBkZKXQ8o9m+fTtu3LiBqKgooaMYxdy5c1FWVoYuXbrAwsICKpUK77//Pl588UWhoxmcVCpF//798e6776Jr165wd3fHl19+iUOHDiEgIEDoeEZ17do1AIC7u3uT4+7u7o3njIWFnZHExMTg9OnTZvVbTGBgIE6cOIHS0lJ8++23iIyMxIEDB0y6uMvJyUFsbCx+++23Zr/FmoNbRyl69uyJ8PBw+Pv745tvvsGECRMETGZ4arUaffr0waJFiwAAISEhOH36NNasWWNWhd2GDRvw2GOPwcvLS+goRvHNN9/g888/xxdffIFu3brhxIkTmDFjBry8vMzi675lyxaMHz8e3t7esLCwQO/evfHCCy8gJSVF6Ghmi7dijWDq1KnYtWsX9u3bBx8fH6HjGI1EIkFAQABCQ0OxePFiBAcHY+XKlULHMqiUlBQUFBSgd+/esLS0hKWlJQ4cOIBVq1bB0tISKpVK6IhG1bZtW3Tu3BlZWVlCRzE4T0/PZr+0dO3a1WxuRQPA5cuX8fvvv2PixIlCRzGa2bNnY+7cuXj++efRo0cPjB07FnFxcVi8eLHQ0YyiY8eOOHDgACoqKpCTk4Pk5GTU1dWhQ4cOQkczKg8PDwBAfn5+k+P5+fmN54yFhZ0BaTQaTJ06Fd9//z327t2L9u3bCx1JUGq1GjU1NULHMKihQ4ciLS0NJ06caPzo06cPXnzxRZw4cQIWFhZCRzSqiooKXLhwAZ6enkJHMbiBAwc2a2eUkZEBf39/gRIZ38aNG+Hm5oYnnnhC6ChGo1QqIRY3/VFqYWEBtVotUCJh2Nvbw9PTEyUlJdi9ezdGjhwpdCSjat++PTw8PLBnz57GY2VlZThy5IjR55bzVqwBxcTE4IsvvsCOHTsglUob77M7OjrC1tZW4HSGNW/ePDz22GPw8/NDeXk5vvjiC+zfvx+7d+8WOppBSaXSZnMo7e3t0a5dO7OYWzlr1iyMGDEC/v7+yMvLw4IFC2BhYYEXXnhB6GgGFxcXhwEDBmDRokX417/+heTkZKxduxZr164VOppRqNVqbNy4EZGRkbC0NJ8fLSNGjMD7778PPz8/dOvWDcePH8dHH32E8ePHCx3NKHbv3g2NRoPAwEBkZWVh9uzZ6NKlC6Kjo4WOpncVFRVN7j5cunQJJ06cgLOzM/z8/DBjxgy899576NSpE9q3b4+33noLXl5eGDVqlHGDGnUNrpkBoPVj48aNQkczuPHjx2v8/f01EolE4+rqqhk6dKjm119/FTqWIMyp3clzzz2n8fT01EgkEo23t7fm
ueee02RlZQkdy2h27typ6d69u8ba2lrTpUsXzdq1a4WOZDS7d+/WANCkp6cLHcWoysrKNLGxsRo/Pz+NjY2NpkOHDpo33nhDU1NTI3Q0o/j66681HTp00EgkEo2Hh4cmJiZGc+PGDaFjGcS+ffu0/kyPjIzUaDQ3W5689dZbGnd3d421tbVm6NChgvx7EGk0ZtIem4iIiMjEcY4dERERkYlgYUdERERkIljYEREREZkIFnZEREREJoKFHREREZGJYGFHREREZCJY2BERERGZCBZ2RERERCaChR0RERGRiWBhR0T0DwYOHIjJkycLHYOI6B+xsCMiugO1Wo2TJ0+id+/eQkchIvpHLOyIiO4gPT0dlZWVLOyI6L7Awo6I6A5SU1NhaWmJnj17Ch2FiOgfsbAjIrqD1NRUBAUFwcbGRugoRET/iIUdEdEdpKam8jYsEd03WNgREd3BiRMnEBoaqvXcU089hdjYWPTr1w+BgYFITk7GyJEj4e/vj//85z+Nj/vss88QFhaGHj164IknnkBNTQ2Am6ttjxw5AgCYMGEC4uPjDX9BRGTSWNgREd3GhQsXcOPGjduO2KWlpaFnz544fPgwhg4ditmzZ+Ozzz7Dvn37sHHjxsbHPfbYY0hOTkZaWhq8vLywf/9+AMBbb72FJUuW4KOPPoJYLEZcXJwxLouITJil0AGIiFqr1NRUAICFhQVOnz7deFwikcDT0xMajQYTJkxoPD59+nRIpVIUFBTAwcEBAKDRaLBu3Tps27YNtbW1yMnJwUsvvQQAePTRR/HGG2/gxx9/xC+//GLEKyMiU8XCjojoNhoKu379+jU5HhERgWXLlqFv376Nx9LS0vD6668DAE6fPo0ePXoAAJKSknD+/HkcPHgQtra26NixI4KCggAAR48eRXFxMfz9/WFlZWWMSyIiE8dbsUREt7F48WJoNJpmH3/88UfjbdgGV65cgY+PD4CbRV5DYXfmzBkMHDgQtra2SExMhFKphKurK3JzczFx4kTs3bsXCoWiyYggEdG9YmFHRHQPbi3scnJy4Ovr2+RcQ2E3duxYfPDBB+jXrx8uXbqEHj16oKqqCmPGjMHq1avRvn17zJs3D++++64g10FEpkWk0Wg0QocgIiIiopbjiB0RERGRiWBhR0RERGQiWNgRERERmQgWdkREREQmgoUdERERkYlgYUdERERkIljYEREREZkIFnZEREREJoKFHREREZGJYGFHREREZCJY2BERERGZCBZ2RERERCbi/wFE/wmSpEzmwwAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "\n", + "for n in range(2, 11):\n", + " num_projections = 2 * n + 1\n", + " x_points = np.ones(num_projections) * n\n", + " e3nn_impl, direct_counts = count_operations_per_n(n)\n", + " ax.scatter(\n", + " x_points,\n", + " e3nn_impl,\n", + " label=\"e3nn\",\n", + " facecolor=\"#d33c25\",\n", + " edgecolor=\"k\",\n", + " s=30.0,\n", + " lw=0.5,\n", + " )\n", + " ax.scatter(\n", + " x_points,\n", + " direct_counts,\n", + " label=\"EquiTriton\",\n", + " edgecolor=\"k\",\n", + " facecolor=\"#4a7cb6\",\n", + " s=30.0,\n", + " lw=0.5,\n", + " )\n", + "ax.set(yscale=\"log\", xlabel=\"$L_{max}$\", ylabel=\"# of arithmetic operations\")\n", + "\n", + "leg = [\n", + " Line2D(\n", + " [0],\n", + " [0],\n", + " marker=\"o\",\n", + " color=\"w\",\n", + " markerfacecolor=\"#d33c25\",\n", + " markersize=15.0,\n", + " label=\"e3nn\",\n", + " ),\n", + " Line2D(\n", + " [0],\n", + " [0],\n", + " marker=\"o\",\n", + " color=\"w\",\n", + " markerfacecolor=\"#4a7cb6\",\n", + " markersize=15.0,\n", + " label=\"EquiTriton\",\n", + " ),\n", + "]\n", + "ax.legend(handles=leg, ncols=2, frameon=False)\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "7c4bb376-60fb-4a91-b941-21f43eb76161", + "metadata": {}, + "outputs": [], + "source": [ + "fig.savefig(\"equitriton_algorithmic_scaling.png\", dpi=150)" + ] + }, + { + "cell_type": "markdown", + "id": "c5b89738-0f3c-4dbe-84f1-b4c1851a7924", + "metadata": {}, + "source": [ + "## Dumping expressions to JSON\n", + "\n", + "Helps with implementation and just to have a nice record of them outside the notebook." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "257b5427-64b5-4eeb-a151-fa32099e646f", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "dbaee485-0654-461a-8436-b9a2c673907c", + "metadata": {}, + "outputs": [], + "source": [ + "def write_expressions_to_json(expr_set, n: int):\n", + " os.makedirs(\"direct_sph_harm\", exist_ok=True)\n", + " write_dict = {}\n", + " write_dict[\"fwd\"] = [str(expr) for expr in expr_set[\"fwd\"]]\n", + " write_dict[\"bwd\"] = {axis: str(expr) for axis, expr in expr_set[\"bwd\"].items()}\n", + " with open(f\"direct_sph_harm/l_{n}.json\", \"w+\") as write_file:\n", + " json.dump(write_dict, write_file, indent=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "aabda44e-e1a1-4ef7-bb4d-2a37884ba38d", + "metadata": {}, + "outputs": [], + "source": [ + "for e, n in zip(\n", + " [\n", + " second_order_expressions,\n", + " third_order_expressions,\n", + " fourth_order_expressions,\n", + " fifth_order_expressions,\n", + " sixth_order_expressions,\n", + " seventh_order_expressions,\n", + " eighth_order_expressions,\n", + " ninth_order_expressions,\n", + " tenth_order_expressions,\n", + " ],\n", + " range(2, 11),\n", + "):\n", + " write_expressions_to_json(e, n)" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "id": "c72e6e01-92c6-4d08-8c70-eeb2708bc8d9", + "metadata": {}, + "outputs": [], + "source": [ + "def collect_symbols(expr, agg_set):\n", + " \"\"\"\n", + " This function recursively walks a single expression tree, and collects\n", + " its unique leaves (symbols and numeric constants) and power terms into\n", + " ``agg_set``. Power terms are added whole rather than descended into.\n", + " \"\"\"\n", + " if len(expr.args) != 0 and not isinstance(expr, sympy.Pow):\n", + " for arg in expr.args:\n", + " collect_symbols(arg, agg_set)\n", + " else:\n", + " agg_set.add(expr)\n", + "\n", + "\n", + "def ordered_power_replacement(expr):\n", + " \"\"\"\n", + " Replaces powers of x, y, and z (exponents 10 down to 2) with placeholder\n", + " names VAR00 to VAR26, highest exponent first, and returns the rewritten\n", + " expression along with the power-to-name mapping.\n", + " \"\"\"\n", + " x, y, z = symbols(\"x y z\")\n", + " mapping = {}\n", + " char_counter 
= 0\n", + " for axis in [x, y, z]:\n", + " for exponent in range(10, 1, -1):\n", + " varname = f\"VAR{char_counter:02}\"\n", + " expr = expr.subs({axis**exponent: varname})\n", + " mapping[axis**exponent] = varname\n", + " char_counter += 1\n", + " return expr, mapping" + ] + }, + { + "cell_type": "markdown", + "id": "946b001a-dcc1-428a-848c-08664954afa6", + "metadata": {}, + "source": [ + "## Generating forward code" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "3d837ed4-bb65-4843-9f20-7f19978b0586", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_fwd_implementation(exprs, output_file: Path | None = None):\n", + " \"\"\"\n", + " This function takes a set of expressions for a particular forward\n", + " pass, determines the set of variables and constants used by each\n", + " kernel, and formats them as a 'machine-generated' implementation for\n", + " the collection of kernels, which is written to `output_file` if given.\n", + " \"\"\"\n", + " variable_set = set()\n", + " for expr in exprs:\n", + " collect_symbols(expr, variable_set)\n", + " mapping = {}\n", + " const_counter = 0\n", + " for sym in variable_set:\n", + " if isinstance(sym, (sympy.Float, sympy.Integer)):\n", + " varname = f\"CONST{const_counter:03}\"\n", + " const_counter += 1\n", + " mapping[sym] = varname\n", + " variable_sec = \"\"\n", + " kernel_sec = \"\"\n", + " # sort the dict to make the output cleaner\n", + " mapping = dict(sorted(mapping.items(), key=lambda x: x[1]))\n", + " # now generate the actual kernels\n", + " for index, kernel in enumerate(exprs):\n", + " # this ensures that higher order powers are replaced\n", + " kernel, new_mapping = ordered_power_replacement(kernel)\n", + " mapping.update(new_mapping)\n", + " kernel_str = str(kernel.subs(mapping))\n", + " kernel_sec += f\"Y{index:02} = {kernel_str}\\n\"\n", + " for sym, char in mapping.items():\n", + " variable_sec += f\"{char} = {sym}\\n\"\n", + " fmt_string = \"# -------------------- 
variable and constant definitions\\n\"\n", + " fmt_string += variable_sec\n", + " fmt_string += \"# -------------------- kernel implementations\\n\"\n", + " fmt_string += kernel_sec\n", + " if output_file:\n", + " with open(output_file, \"w+\") as write_file:\n", + " write_file.write(fmt_string)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "32600151-71ec-4738-b691-17862c4d3a94", + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs(\"fwd_implementations\", exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "6685cbaf-0098-41e2-ace7-9d8d92ded1de", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "for order, expressions in zip(\n", + " range(2, 11),\n", + " [\n", + " second_order_expressions,\n", + " third_order_expressions,\n", + " fourth_order_expressions,\n", + " fifth_order_expressions,\n", + " sixth_order_expressions,\n", + " seventh_order_expressions,\n", + " eighth_order_expressions,\n", + " ninth_order_expressions,\n", + " tenth_order_expressions,\n", + " ],\n", + "):\n", + " output_path = Path(f\"fwd_implementations/fwd_{order}.py\")\n", + " generate_fwd_implementation(expressions[\"fwd\"], output_path)" + ] + }, + { + "cell_type": "markdown", + "id": "ffb012be-79d0-4a48-9810-4903b8ebed5e", + "metadata": {}, + "source": [ + "## Generating backward code" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "id": "22dee377-3e18-44bb-ad7c-621cdb27ded2", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_bwd_implementation(exprs: dict, output_file: Path | None = None):\n", + " \"\"\"\n", + " This function takes a set of expressions for a particular backward\n", + " pass, determines the set of variables and constants used by each\n", + " kernel, and formats them as a 'machine-generated' implementation for\n", + " the collection of kernels, which is written to `output_file` if given.\n", + " \"\"\"\n", + " variable_set = set()\n", + " mapping = {}\n", + " # we expect a 
dictionary, where the key corresponds to the cart axis\n", + " for expr in exprs.values():\n", + " collect_symbols(expr, variable_set)\n", + " const_counter = 0\n", + " for sym in variable_set:\n", + " if isinstance(sym, (sympy.Float, sympy.Integer)):\n", + " varname = f\"CONST{const_counter:03}\"\n", + " const_counter += 1\n", + " mapping[sym] = varname\n", + " variable_sec = \"\"\n", + " kernel_sec = \"\"\n", + " for axis, grad_kernel in exprs.items():\n", + " # this ensures we replace exponents\n", + " grad_kernel, new_mapping = ordered_power_replacement(grad_kernel)\n", + " mapping.update(new_mapping)\n", + " # now replace constants\n", + " kernel_str = str(grad_kernel.subs(mapping))\n", + " kernel_sec += f\"g_{axis} = {kernel_str}\\n\"\n", + " # sort the dict to make the output cleaner\n", + " mapping = dict(sorted(mapping.items(), key=lambda x: x[1]))\n", + " for sym, char in mapping.items():\n", + " variable_sec += f\"{char} = {sym}\\n\"\n", + " fmt_string = \"# -------------------- variable and constant definitions\\n\"\n", + " fmt_string += variable_sec\n", + " fmt_string += \"# -------------------- kernel implementations\\n\"\n", + " fmt_string += kernel_sec\n", + " if output_file:\n", + " with open(output_file, \"w+\") as write_file:\n", + " write_file.write(fmt_string)" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "53e9e681-72df-4f1d-aa38-21620816d372", + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs(\"bwd_implementations\", exist_ok=True)\n", + "\n", + "for order, expressions in zip(\n", + " range(2, 11),\n", + " [\n", + " second_order_expressions,\n", + " third_order_expressions,\n", + " fourth_order_expressions,\n", + " fifth_order_expressions,\n", + " sixth_order_expressions,\n", + " seventh_order_expressions,\n", + " eighth_order_expressions,\n", + " ninth_order_expressions,\n", + " tenth_order_expressions,\n", + " ],\n", + "):\n", + " output_path = Path(f\"bwd_implementations/bwd_{order}.py\")\n", + " 
generate_bwd_implementation(expressions[\"bwd\"], output_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "13130ed1-0c5c-422c-a1a9-f1fa07ee4b20", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x \n", + " g_0*(11.6340690431164*VAR06 - 69.8044142586986*VAR08*VAR26 + 11.6340690431164*VAR24) + g_1*y*(-88.2963759165686*VAR08*z + 29.4321253055229*VAR25) + g_10*(46.5362761724657*VAR07*z - 46.5362761724657*VAR25*x) + g_2*(8.67152307844476*VAR06 + 3.0*VAR08*(-13.8744369255116*VAR17 - 3.4686092313779*VAR26) + 41.6233107765348*VAR17*VAR26 - 5.20291384706685*VAR24) + g_3*(-50.9779364038993*VAR08*y*z + 33.9852909359329*VAR16*z - 16.9926454679664*VAR25*y) + g_4*(8.02827036166571*VAR06 + 3.0*VAR08*(-19.2678488679977*VAR17 + 3.21130814466628*VAR26) + 12.8452325786651*VAR15 - 19.2678488679977*VAR17*VAR26 + 1.60565407233314*VAR24) + g_5*(-33.166247903554*VAR16*x + y*(24.8746859276655*VAR07 + 24.8746859276655*VAR26*x)) + g_6*(6.42261628933256*VAR25*x + z*(6.42261628933256*VAR07 - 38.5356977359954*VAR17*x)) + g_7*(33.9852909359329*VAR07*y - 33.9852909359329*VAR16*x) + g_8*(6.9372184627558*VAR25*x + z*(20.8116553882674*VAR07 - 83.2466215530696*VAR17*x)) + g_9*y*(29.4321253055229*VAR07 - 88.2963759165686*VAR26*x) \n", + " g_0*(CONST009*VAR06 + CONST009*VAR24 + CONST040*VAR08*VAR26) + g_1*y*(CONST038*VAR08*z - CONST052*VAR25) + g_10*(CONST029*VAR07*z + CONST043*VAR25*x) + g_2*(CONST001*VAR08*(CONST059*VAR17 + CONST064*VAR26) + CONST006*VAR06 - CONST045*VAR17*VAR26 + CONST063*VAR24) + g_3*(CONST041*VAR08*y*z - CONST049*VAR16*z + CONST057*VAR25*y) + g_4*(CONST000*VAR24 + CONST001*VAR08*(CONST002*VAR26 + CONST055*VAR17) + CONST007*VAR06 + CONST010*VAR15 + CONST056*VAR17*VAR26) + g_5*(CONST048*VAR16*x + y*(CONST019*VAR07 + CONST019*VAR26*x)) + g_6*(CONST005*VAR25*x + z*(CONST004*VAR07 + CONST046*VAR17*x)) + g_7*(CONST049*VAR16*x - CONST051*VAR07*y) + g_8*(CONST008*VAR25*x + z*(CONST039*VAR17*x - CONST054*VAR07)) + 
g_9*y*(CONST024*VAR07 + CONST038*VAR26*x)\n", + "y \n", + " g_1*(-29.4321253055229*VAR07*z + 29.4321253055229*VAR25*x) + g_2*(-27.7488738510232*VAR07*y + 83.2466215530696*VAR26*x*y) + g_3*(-16.9926454679664*VAR07*z + x*(101.955872807799*VAR17*z - 16.9926454679664*VAR25)) + g_4*(-38.5356977359954*VAR07*y + x*(51.3809303146605*VAR16 - 38.5356977359954*VAR26*y)) + g_5*(6.21867148191637*VAR06 + 12.4373429638327*VAR08*VAR26 + 16.583123951777*VAR15 + 3.0*VAR17*(-16.583123951777*VAR08 - 16.583123951777*VAR26) + 6.21867148191637*VAR24) + g_6*(-38.5356977359954*VAR25*y + z*(-38.5356977359954*VAR08*y + 51.3809303146605*VAR16)) + g_7*(8.49632273398321*VAR06 + 3.0*VAR17*(-16.9926454679664*VAR08 + 16.9926454679664*VAR26) - 8.49632273398321*VAR24) + g_8*(-83.2466215530696*VAR08*y*z + 27.7488738510232*VAR25*y) + g_9*(7.35803132638072*VAR06 - 44.1481879582843*VAR08*VAR26 + 7.35803132638072*VAR24) \n", + " g_1*(CONST052*VAR07*z - CONST052*VAR25*x) + g_2*(-CONST039*VAR26*x*y + CONST053*VAR07*y) + g_3*(CONST058*VAR07*z + x*(CONST034*VAR17*z + CONST057*VAR25)) + g_4*(CONST047*VAR07*y + x*(CONST030*VAR16 + CONST046*VAR26*y)) + g_5*(CONST001*VAR17*(CONST060*VAR08 + CONST060*VAR26) + CONST011*VAR06 + CONST012*VAR24 + CONST014*VAR08*VAR26 - CONST060*VAR15) + g_6*(CONST046*VAR25*y + z*(CONST031*VAR16 + CONST046*VAR08*y)) + g_7*(CONST001*VAR17*(CONST057*VAR08 - CONST057*VAR26) - CONST061*VAR06 + CONST061*VAR24) + g_8*(CONST021*VAR25*y + CONST039*VAR08*y*z) + g_9*(CONST027*VAR06 + CONST027*VAR24 + CONST044*VAR08*VAR26)\n", + "z \n", + " g_0*(-46.5362761724657*VAR07*z + 46.5362761724657*VAR25*x) + g_1*y*(-29.4321253055229*VAR07 + 88.2963759165686*VAR26*x) + g_10*(11.6340690431164*VAR06 - 69.8044142586986*VAR08*VAR26 + 11.6340690431164*VAR24) + g_2*(-6.9372184627558*VAR07*z + x*(83.2466215530696*VAR17*z - 20.8116553882674*VAR25)) + g_3*(-16.9926454679664*VAR07*y + x*(33.9852909359329*VAR16 - 50.9779364038993*VAR26*y)) + g_4*(6.42261628933256*VAR07*z + x*(-38.5356977359954*VAR17*z + 
6.42261628933257*VAR25)) + g_5*(-33.166247903554*VAR16*z + y*(24.8746859276655*VAR08*z + 24.8746859276655*VAR25)) + g_6*(1.60565407233314*VAR06 - 19.2678488679977*VAR08*VAR17 + 12.8452325786651*VAR15 + 8.02827036166571*VAR24 + 3.0*VAR26*(3.21130814466628*VAR08 - 19.2678488679977*VAR17)) + g_7*(33.9852909359329*VAR16*z - 33.9852909359329*VAR25*y) + g_8*(5.20291384706685*VAR06 - 41.6233107765348*VAR08*VAR17 - 8.67152307844475*VAR24 + 3.0*VAR26*(3.4686092313779*VAR08 + 13.8744369255116*VAR17)) + g_9*y*(-88.2963759165686*VAR08*z + 29.4321253055229*VAR25) \n", + " g_0*(CONST029*VAR25*x + CONST043*VAR07*z) + g_1*y*(-CONST038*VAR26*x + CONST052*VAR07) + g_10*(CONST009*VAR06 + CONST009*VAR24 + CONST040*VAR08*VAR26) + g_2*(CONST062*VAR07*z + x*(-CONST039*VAR17*z + CONST054*VAR25)) + g_3*(CONST058*VAR07*y + x*(CONST042*VAR26*y - CONST049*VAR16)) + g_4*(CONST005*VAR07*z + x*(CONST046*VAR17*z + CONST050*VAR25)) + g_5*(CONST048*VAR16*z + y*(CONST019*VAR08*z + CONST020*VAR25)) + g_6*(CONST001*VAR26*(CONST002*VAR08 + CONST056*VAR17) + CONST003*VAR06 + CONST007*VAR24 + CONST017*VAR15 + CONST056*VAR08*VAR17) + g_7*(-CONST049*VAR16*z + CONST051*VAR25*y) + g_8*(CONST001*VAR26*(CONST018*VAR17 + CONST037*VAR08) + CONST036*VAR24 + CONST045*VAR08*VAR17 - CONST063*VAR06) + g_9*y*(CONST024*VAR25 + CONST038*VAR08*z)\n", + "{1.60565407233314: 'CONST000', 3.00000000000000: 'CONST001', 3.21130814466628: 'CONST002', 1.60565407233314: 'CONST003', 6.42261628933256: 'CONST004', 6.42261628933256: 'CONST005', 8.67152307844476: 'CONST006', 8.02827036166571: 'CONST007', 6.93721846275580: 'CONST008', 11.6340690431164: 'CONST009', 12.8452325786651: 'CONST010', 6.21867148191637: 'CONST011', 6.21867148191637: 'CONST012', 16.5831239517770: 'CONST013', 12.4373429638327: 'CONST014', 16.9926454679664: 'CONST015', 20.8116553882674: 'CONST016', 12.8452325786651: 'CONST017', 13.8744369255116: 'CONST018', 24.8746859276655: 'CONST019', 24.8746859276655: 'CONST020', 27.7488738510232: 'CONST021', 5.20291384706685: 
'CONST022', 29.4321253055229: 'CONST023', 29.4321253055229: 'CONST024', 33.9852909359329: 'CONST025', 33.9852909359329: 'CONST026', 7.35803132638072: 'CONST027', 41.6233107765348: 'CONST028', 46.5362761724657: 'CONST029', 51.3809303146605: 'CONST030', 51.3809303146605: 'CONST031', 83.2466215530696: 'CONST032', 88.2963759165686: 'CONST033', 101.955872807799: 'CONST034', 8.49632273398321: 'CONST035', -8.67152307844475: 'CONST036', 3.46860923137790: 'CONST037', -88.2963759165686: 'CONST038', -83.2466215530696: 'CONST039', -69.8044142586986: 'CONST040', -50.9779364038993: 'CONST041', -50.9779364038993: 'CONST042', -46.5362761724657: 'CONST043', -44.1481879582843: 'CONST044', -41.6233107765348: 'CONST045', -38.5356977359954: 'CONST046', -38.5356977359954: 'CONST047', -33.1662479035540: 'CONST048', -33.9852909359329: 'CONST049', 6.42261628933257: 'CONST050', -33.9852909359329: 'CONST051', -29.4321253055229: 'CONST052', -27.7488738510232: 'CONST053', -20.8116553882674: 'CONST054', -19.2678488679977: 'CONST055', -19.2678488679977: 'CONST056', -16.9926454679664: 'CONST057', -16.9926454679664: 'CONST058', -13.8744369255116: 'CONST059', -16.5831239517770: 'CONST060', -8.49632273398321: 'CONST061', -6.93721846275580: 'CONST062', -5.20291384706685: 'CONST063', -3.46860923137790: 'CONST064', x**10: 'VAR00', x**9: 'VAR01', x**8: 'VAR02', x**7: 'VAR03', x**6: 'VAR04', x**5: 'VAR05', x**4: 'VAR06', x**3: 'VAR07', x**2: 'VAR08', y**10: 'VAR09', y**9: 'VAR10', y**8: 'VAR11', y**7: 'VAR12', y**6: 'VAR13', y**5: 'VAR14', y**4: 'VAR15', y**3: 'VAR16', y**2: 'VAR17', z**10: 'VAR18', z**9: 'VAR19', z**8: 'VAR20', z**7: 'VAR21', z**6: 'VAR22', z**5: 'VAR23', z**4: 'VAR24', z**3: 'VAR25', z**2: 'VAR26'}\n" + ] + } + ], + "source": [ + "generate_bwd_implementation(fifth_order_expressions[\"bwd\"], None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "379cf400-8e4e-48f2-8e3a-0c02dbbd3897", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + 
"kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/bwd_implementations/bwd_10.py b/notebooks/bwd_implementations/bwd_10.py new file mode 100644 index 0000000..338aeff --- /dev/null +++ b/notebooks/bwd_implementations/bwd_10.py @@ -0,0 +1,431 @@ +# -------------------- variable and constant definitions +CONST000 = 2.00000000000000 +CONST001 = 3.21913870529156 +CONST002 = 4.00000000000000 +CONST003 = 4.82870805793735 +CONST004 = 6.00000000000000 +CONST005 = 4.97432985632550 +CONST006 = 8.00000000000000 +CONST007 = 4.97432985632550 +CONST008 = 10.5521471197994 +CONST009 = 3.00000000000000 +CONST010 = 5.00000000000000 +CONST011 = 7.00000000000000 +CONST012 = 13.2648796168680 +CONST013 = 12.8765548211663 +CONST014 = 12.1657520803952 +CONST015 = 16.7271353825295 +CONST016 = -2030.35546709287 +CONST017 = 19.3148322317494 +CONST018 = -6131.53904851919 +CONST019 = 22.8629854262320 +CONST020 = 23.2135393295190 +CONST021 = 24.6216766128653 +CONST022 = 17.5869118663323 +CONST023 = 27.2034486491732 +CONST024 = 28.9722483476241 +CONST025 = 33.9852909359329 +CONST026 = 33.9852909359329 +CONST027 = 35.5238206489124 +CONST028 = 6180.74631415980 +CONST029 = 38.6296644634988 +CONST030 = 39.7946388506040 +CONST031 = 38.6296644634988 +CONST032 = -2007.25624590353 +CONST033 = -2007.25624590353 +CONST034 = 45.8257569495584 +CONST035 = 45.7259708524640 +CONST036 = 49.2433532257305 +CONST037 = 56.3871618715269 +CONST038 = 56.2781179722634 +CONST039 = -1989.33395633909 +CONST040 = -1989.33395633909 +CONST041 = 59.6919582759060 +CONST042 = 66.9085415301178 +CONST043 = 69.6406179885570 +CONST044 = -8121.42186837148 
+CONST045 = 77.2593289269976 +CONST046 = 78.6510608948335 +CONST047 = -1969.73412902922 +CONST048 = 77.3468749368712 +CONST049 = -1969.73412902922 +CONST050 = -9.65741611587469 +CONST051 = 90.1358837481638 +CONST052 = 2141.07332896377 +CONST053 = 94.9693240781945 +CONST054 = 92.8541573180760 +CONST055 = 96.5741611587469 +CONST056 = 98.4867064514610 +CONST057 = 98.4867064514610 +CONST058 = 100.362812295177 +CONST059 = 101.517773354644 +CONST060 = 106.571461946737 +CONST061 = 106.571461946737 +CONST062 = 109.491768723557 +CONST063 = 109.491768723557 +CONST064 = 112.774323743054 +CONST065 = 112.774323743054 +CONST066 = 112.556235944527 +CONST067 = 2165.26701586663 +CONST068 = 130.522851455970 +CONST069 = 131.315608601948 +CONST070 = 133.817083060236 +CONST071 = 139.281235977114 +CONST072 = 139.281235977114 +CONST073 = 141.571909610700 +CONST074 = 142.095282595650 +CONST075 = 147.730059677192 +CONST076 = 150.544218442765 +CONST077 = 150.074981259369 +CONST078 = 154.518657853995 +CONST079 = 2202.22970505534 +CONST080 = -3939.46825805844 +CONST081 = -5968.00186901728 +CONST082 = 176.592751833137 +CONST083 = 176.178376404427 +CONST084 = 2228.49977563382 +CONST085 = 185.708314636152 +CONST086 = 196.973412902922 +CONST087 = 196.973412902922 +CONST088 = 203.035546709287 +CONST089 = 225.548647486108 +CONST090 = 225.548647486108 +CONST091 = 4330.53403173327 +CONST092 = 2285.08968653055 +CONST093 = 244.831037842559 +CONST094 = -1804.38917988886 +CONST095 = -1804.38917988886 +CONST096 = 244.831037842559 +CONST097 = 2317.77986780993 +CONST098 = 278.562471954228 +CONST099 = 284.190565191299 +CONST100 = 284.190565191299 +CONST101 = -1761.78376404427 +CONST102 = 290.050781013267 +CONST103 = -9946.66978169547 +CONST104 = 9.94865971265100 +CONST105 = 305.867618423396 +CONST106 = 305.867618423396 +CONST107 = 309.037315707990 +CONST108 = -7878.93651611688 +CONST109 = 2363.68095483506 +CONST110 = 14.5025390506634 +CONST111 = 338.322971229162 +CONST112 = 360.877835977772 +CONST113 = 
4456.99955126765 +CONST114 = -1671.37483172537 +CONST115 = 386.296644634988 +CONST116 = 2436.42656051144 +CONST117 = 393.946825805844 +CONST118 = 393.946825805844 +CONST119 = 393.946825805844 +CONST120 = -1648.19901710928 +CONST121 = 401.451249180707 +CONST122 = 406.071093418574 +CONST123 = 412.049754277320 +CONST124 = 2472.29852566392 +CONST125 = -1624.28437367430 +CONST126 = 426.285847786949 +CONST127 = 426.285847786948 +CONST128 = 2486.66744542387 +CONST129 = 450.224943778107 +CONST130 = 451.097294972216 +CONST131 = 451.097294972216 +CONST132 = 451.097294972215 +CONST133 = 6606.68911516602 +CONST134 = 6606.68911516602 +CONST135 = -1575.78730322338 +CONST136 = -1575.78730322338 +CONST137 = -3608.77835977772 +CONST138 = 492.433532257305 +CONST139 = -1545.18657853995 +CONST140 = -1545.18657853995 +CONST141 = 525.262434407792 +CONST142 = 535.268332240943 +CONST143 = 4635.55973561985 +CONST144 = 541.428124558099 +CONST145 = -3545.52143225260 +CONST146 = 557.124943908456 +CONST147 = -3523.56752808854 +CONST148 = -5571.24943908456 +CONST149 = 580.101562026534 +CONST150 = 10828.5624911620 +CONST151 = 15.7883647328499 +CONST152 = 590.920238708766 +CONST153 = 2642.67564606641 +CONST154 = 2642.67564606641 +CONST155 = 2676.34166120471 +CONST156 = 629.208487158668 +CONST157 = 4727.36190967013 +CONST158 = 4727.36190967013 +CONST159 = -1392.81235977114 +CONST160 = -1390.66792068596 +CONST161 = 2707.14062279049 +CONST162 = 663.111318779698 +CONST163 = -3427.63452979582 +CONST164 = -1378.81389032045 +CONST165 = 676.645942458323 +CONST166 = 706.371007332549 +CONST167 = -1338.17083060236 +CONST168 = -1338.17083060236 +CONST169 = 721.755671955545 +CONST170 = 734.076568351780 +CONST171 = 2785.62471954228 +CONST172 = 742.833258544608 +CONST173 = 772.593289269975 +CONST174 = 787.893651611688 +CONST175 = 787.893651611688 +CONST176 = 787.893651611688 +CONST177 = 6.63243980843400 +CONST178 = 812.142186837148 +CONST179 = 812.142186837148 +CONST180 = -1218.21328025572 +CONST181 = 
-1202.92611992591 +CONST182 = -1202.92611992591 +CONST183 = -3248.56874734859 +CONST184 = -3248.56874734859 +CONST185 = -5285.35129213281 +CONST186 = -1181.84047741753 +CONST187 = 875.934149788456 +CONST188 = 880.891882022136 +CONST189 = 880.891882022136 +CONST190 = 2936.30627340712 +CONST191 = 900.449887556215 +CONST192 = 2954.60119354383 +CONST193 = -1114.24988781691 +CONST194 = -16.5810995210850 +CONST195 = -1101.11485252767 +CONST196 = -1081.63060497797 +CONST197 = 15.7302121789667 +CONST198 = 979.324151370235 +CONST199 = 984.867064514610 +CONST200 = 984.867064514610 +CONST201 = 1015.17773354644 +CONST202 = -1027.70719569249 +CONST203 = -1021.92317475320 +CONST204 = -3065.76952425960 +CONST205 = -1015.17773354644 +CONST206 = 3090.37315707990 +CONST207 = -994.666978169547 +CONST208 = -984.867064514610 +CONST209 = -984.867064514610 +CONST210 = -979.324151370235 +CONST211 = 1070.53666448189 +CONST212 = -979.324151370235 +CONST213 = 3151.57460644675 +CONST214 = 16.0956935264578 +CONST215 = 1114.24988781691 +CONST216 = -927.111947123971 +CONST217 = -927.111947123970 +CONST218 = -5.63871618715269 +CONST219 = -2954.60119354383 +CONST220 = -902.194589944431 +CONST221 = -900.449887556215 +CONST222 = -880.891882022136 +CONST223 = -880.891882022136 +CONST224 = -875.934149788456 +CONST225 = 1181.84047741753 +CONST226 = -4944.59705132784 +CONST227 = 3248.56874734859 +CONST228 = 3248.56874734859 +CONST229 = -835.687415862684 +CONST230 = 1218.21328025572 +CONST231 = -824.099508554641 +CONST232 = -824.863625092051 +CONST233 = -824.863625092051 +CONST234 = -812.142186837148 +CONST235 = 5352.68332240943 +CONST236 = -787.893651611688 +CONST237 = -787.893651611688 +CONST238 = -772.593289269976 +CONST239 = -742.833258544608 +CONST240 = -2785.62471954228 +CONST241 = -734.076568351780 +CONST242 = 1321.33782303320 +CONST243 = 1321.33782303320 +CONST244 = -706.371007332549 +CONST245 = -696.406179885570 +CONST246 = 1353.29188491665 +CONST247 = -675.337415667161 +CONST248 = 
-675.337415667161 +CONST249 = 1378.81389032045 +CONST250 = 3427.63452979582 +CONST251 = -669.085415301178 +CONST252 = -669.085415301178 +CONST253 = -669.085415301178 +CONST254 = 3427.63452979582 +CONST255 = -663.111318779698 +CONST256 = -2707.14062279049 +CONST257 = 1392.81235977114 +CONST258 = 1392.81235977114 +CONST259 = 1412.74201466510 +CONST260 = -4727.36190967013 +CONST261 = -2676.34166120471 +CONST262 = -618.074631415980 +CONST263 = -611.735236846792 +CONST264 = -611.735236846792 +CONST265 = 1443.51134391109 +CONST266 = -590.920238708766 +CONST267 = -10828.5624911620 +CONST268 = -580.101562026534 +CONST269 = -2626.31217203896 +CONST270 = 3523.56752808854 +CONST271 = 5571.24943908456 +CONST272 = 5571.24943908456 +CONST273 = -12.8765548211663 +CONST274 = -557.124943908456 +CONST275 = -557.124943908456 +CONST276 = 3545.52143225260 +CONST277 = -541.428124558099 +CONST278 = -6685.49932690147 +CONST279 = 7664.42381064899 +CONST280 = -525.262434407792 +CONST281 = 1532.88476212980 +CONST282 = 1545.18657853995 +CONST283 = -497.333489084773 +CONST284 = -497.333489084773 +CONST285 = -492.433532257305 +CONST286 = 1575.78730322338 +CONST287 = 1575.78730322338 +CONST288 = -463.555973561985 +CONST289 = -450.224943778107 +CONST290 = -450.224943778107 +CONST291 = -450.224943778108 +CONST292 = -437.967074894228 +CONST293 = -2472.29852566392 +CONST294 = 1624.28437367430 +CONST295 = -2472.29852566392 +CONST296 = -406.071093418574 +CONST297 = -393.946825805844 +CONST298 = -393.946825805844 +CONST299 = -2436.42656051144 +CONST300 = -386.296644634988 +CONST301 = -386.296644634988 +CONST302 = -4456.99955126765 +CONST303 = -337.668707833581 +CONST304 = -337.668707833581 +CONST305 = -331.555659389849 +CONST306 = -331.555659389849 +CONST307 = -2363.68095483506 +CONST308 = 7878.93651611688 +CONST309 = -309.037315707990 +CONST310 = -4404.45941011068 +CONST311 = -309.037315707990 +CONST312 = -305.867618423396 +CONST313 = -305.867618423396 +CONST314 = -305.867618423396 +CONST315 = 
-300.731529981477 +CONST316 = 9946.66978169547 +CONST317 = 9946.66978169547 +CONST318 = -290.050781013267 +CONST319 = -284.190565191299 +CONST320 = -278.562471954228 +CONST321 = -278.562471954228 +CONST322 = -2317.77986780993 +CONST323 = -10505.2486881558 +CONST324 = -251.683394863467 +CONST325 = -251.683394863467 +CONST326 = -246.216766128653 +CONST327 = -244.831037842559 +CONST328 = -2285.08968653055 +CONST329 = -2285.08968653055 +CONST330 = 3862.96644634988 +CONST331 = -223.028471767059 +CONST332 = -220.222970505534 +CONST333 = -206.215906273013 +CONST334 = -203.035546709287 +CONST335 = -196.973412902922 +CONST336 = -196.973412902922 +CONST337 = -182.903883409856 +CONST338 = -2228.49977563382 +CONST339 = 5968.00186901728 +CONST340 = 16.4144510752435 +CONST341 = 3939.46825805844 +CONST342 = 3939.46825805844 +CONST343 = -154.518657853995 +CONST344 = -154.518657853995 +CONST345 = -150.074981259369 +CONST346 = -147.730059677191 +CONST347 = -146.815313670356 +CONST348 = -142.095282595650 +CONST349 = -131.315608601948 +CONST350 = -131.315608601948 +CONST351 = -130.522851455970 +CONST352 = -125.841697431734 +CONST353 = -125.841697431734 +CONST354 = -112.556235944527 +CONST355 = -103.107953136506 +CONST356 = -101.517773354644 +CONST357 = 1949.93730367960 +CONST358 = -98.4867064514610 +CONST359 = -98.4867064514610 +CONST360 = -2141.07332896377 +CONST361 = -2141.07332896377 +CONST362 = -92.8541573180760 +CONST363 = -88.2963759165686 +CONST364 = 1969.73412902922 +CONST365 = 1969.73412902922 +CONST366 = -77.3468749368713 +CONST367 = 8121.42186837148 +CONST368 = 8121.42186837148 +CONST369 = -67.6645942458323 +CONST370 = 1989.33395633909 +CONST371 = 1989.33395633909 +CONST372 = -59.6919582759060 +CONST373 = -49.2433532257305 +CONST374 = -49.2433532257305 +CONST375 = -45.1097294972216 +CONST376 = -45.1097294972216 +CONST377 = -42.2085884791976 +CONST378 = -27.2034486491732 +CONST379 = -24.6216766128653 +CONST380 = -22.8629854262320 +CONST381 = -19.7354559160624 +CONST382 = 
2030.35546709287 +CONST383 = -17.5869118663323 +CONST384 = -16.4144510752435 +CONST385 = -16.0956935264578 +CONST386 = -14.5025390506634 +CONST387 = 6131.53904851919 +CONST388 = -16.5810995210850 +CONST389 = -15.7883647328499 +CONST390 = -14.0695294930659 +CONST391 = -11.2774323743054 +CONST392 = -11.2774323743054 +CONST393 = -13.2648796168680 +CONST394 = -6.63243980843400 +CONST395 = -5.63871618715269 +CONST396 = -4.82870805793735 +CONST397 = -3.21913870529156 +CONST398 = -11.2774323743054 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +g_x = g_0*(CONST093*VAR02*z + CONST210*VAR08*VAR21 + CONST250*VAR06*VAR23 + CONST328*VAR04*VAR25 - CONST378*VAR19) + g_1*y*(CONST062*VAR20 + CONST063*VAR02 + CONST204*VAR04*VAR26 + CONST204*VAR08*VAR22 + CONST279*VAR06*VAR24) + g_10*(CONST000*x*(CONST089*VAR17*VAR22 + CONST169*VAR13*VAR26 + CONST220*VAR15*VAR24 + CONST355*VAR11 + CONST395*VAR20) + CONST002*VAR07*(CONST111*VAR17*VAR24 + CONST112*VAR13 + CONST220*VAR15*VAR26 + CONST392*VAR22) + CONST004*VAR05*(CONST090*VAR17*VAR26 + CONST315*VAR15 + CONST392*VAR24) + CONST006*VAR03*(CONST037*VAR17 + CONST218*VAR26) + CONST391*VAR01) + g_11*(CONST070*VAR21*x*y + VAR23*(CONST121*VAR07*y + CONST168*VAR16*x) + VAR25*(CONST121*VAR05*y + CONST261*VAR07*VAR16 - CONST361*VAR14*x) + z*(CONST070*VAR03*y + CONST167*VAR05*VAR16 + CONST263*VAR12*x - CONST361*VAR07*VAR14)) + g_12*(CONST000*x*(CONST003*VAR20 - CONST301*VAR15*VAR24 + CONST343*VAR17*VAR22 + CONST363*VAR11) + CONST002*VAR07*(CONST123*VAR13 + CONST300*VAR15*VAR26 - CONST397*VAR22) + CONST004*VAR05*(CONST301*VAR15 - CONST344*VAR17*VAR26 + CONST397*VAR24) + 
CONST006*VAR03*(CONST045*VAR17 + CONST396*VAR26) + CONST385*VAR01) + g_13*(CONST221*VAR12*x*z + VAR14*(-CONST260*VAR07*z + CONST286*VAR25*x) + VAR16*(CONST080*VAR07*VAR25 + CONST145*VAR05*z + CONST297*VAR23*x) + y*(-CONST237*VAR05*VAR25 - CONST297*VAR07*VAR23 - CONST298*VAR03*z)) + g_14*(CONST000*x*(CONST005*VAR20 - CONST159*VAR15*VAR24 + CONST193*VAR13*VAR26 + CONST320*VAR17*VAR22) + CONST002*VAR07*(CONST020*VAR22 + CONST085*VAR13 + CONST245*VAR17*VAR24 + CONST258*VAR15*VAR26) + CONST004*VAR05*(CONST020*VAR24 + CONST320*VAR15 + CONST320*VAR17*VAR26) + CONST006*VAR03*(CONST007*VAR26 + CONST043*VAR17) + CONST388*VAR01) + g_15*(VAR14*(-CONST147*VAR07*z + CONST147*VAR25*x) + VAR16*(CONST153*VAR23*x + CONST190*VAR07*VAR25 + CONST310*VAR05*z) + y*(CONST156*VAR03*z + CONST222*VAR07*VAR23 + CONST324*VAR21*x)) + g_16*(CONST000*x*(CONST047*VAR15*VAR24 + CONST175*VAR17*VAR22 + CONST380*VAR20) + CONST002*VAR07*(-CONST047*VAR15*VAR26 + CONST379*VAR22) + CONST004*VAR05*(CONST021*VAR24 + CONST236*VAR17*VAR26 + CONST349*VAR15) + CONST006*VAR03*(CONST019*VAR26 + CONST038*VAR17) + CONST383*VAR01) + g_17*(VAR16*(CONST183*VAR23*x + CONST184*VAR05*z - CONST267*VAR07*VAR25) + y*(CONST178*VAR03*z + CONST234*VAR07*VAR23 - CONST268*VAR21*x + CONST299*VAR05*VAR25)) + g_18*(CONST060*VAR20*x + CONST126*VAR03*VAR26 + CONST283*VAR05*VAR24 + CONST305*VAR07*VAR22 + CONST381*VAR01 + VAR17*(CONST039*VAR22*x + CONST081*VAR05*VAR26 + CONST316*VAR07*VAR24 - CONST319*VAR03)) + g_19*y*(CONST018*VAR05*VAR25 - CONST018*VAR07*VAR23 - CONST224*VAR03*z + CONST224*VAR21*x) + g_2*(CONST074*VAR02*z + CONST100*VAR08*VAR21 + CONST255*VAR04*VAR25 + CONST389*VAR19 + VAR17*(CONST040*VAR04*z + CONST081*VAR08*VAR23 - CONST103*VAR06*VAR25 - CONST319*VAR21)) + g_20*(CONST163*VAR05*VAR24 - CONST212*VAR03*VAR26 + CONST327*VAR20*x - CONST329*VAR07*VAR22 + CONST378*VAR01) + g_3*(VAR16*(CONST044*VAR08*VAR24 + CONST144*VAR22 + CONST277*VAR04 + CONST367*VAR06*VAR26) + y*(CONST016*VAR04*VAR26 - CONST205*VAR06*VAR24 + 
CONST230*VAR08*VAR22 - CONST351*VAR02 + CONST356*VAR20)) + g_4*(CONST008*VAR19 + CONST009*VAR08*(CONST175*VAR17*VAR23 + CONST269*VAR15*VAR25 + CONST390*VAR21) + CONST010*VAR06*(CONST175*VAR15*z + CONST176*VAR17*VAR25 + CONST373*VAR23) + CONST011*VAR04*(CONST303*VAR17*z + CONST390*VAR25) + CONST053*VAR02*z + CONST175*VAR15*VAR23 + CONST304*VAR17*VAR21) + g_5*(VAR14*(CONST185*VAR08*VAR26 - CONST222*VAR06 - CONST223*VAR24) + VAR16*(CONST079*VAR08*VAR24 + CONST133*VAR06*VAR26 + CONST202*VAR04 + CONST241*VAR22) + y*(CONST046*VAR20 + CONST073*VAR02 + CONST195*VAR06*VAR24 + CONST222*VAR04*VAR26)) + g_6*(CONST009*VAR08*(CONST098*VAR17*VAR23 + CONST239*VAR13*z + CONST393*VAR21) + CONST010*VAR06*(-CONST193*VAR15*z + CONST320*VAR17*VAR25) + CONST011*VAR04*(CONST012*VAR25 + CONST321*VAR17*z) + CONST041*VAR02*z + CONST098*VAR17*VAR21 + CONST193*VAR15*VAR23 - CONST239*VAR13*VAR25 + CONST394*VAR19) + g_7*(VAR12*(CONST289*VAR08 - CONST290*VAR26) + VAR14*(-CONST049*VAR06 + CONST186*VAR24 + CONST307*VAR08*VAR26) + VAR16*(CONST164*VAR04 + CONST192*VAR08*VAR24 + CONST199*VAR06*VAR26 - CONST266*VAR22) + y*(CONST075*VAR02 + CONST285*VAR06*VAR24 + CONST297*VAR08*VAR22 + CONST374*VAR20)) + g_8*(CONST009*VAR08*(-CONST140*VAR15*VAR25 + CONST231*VAR13*z - CONST273*VAR21 + CONST288*VAR17*VAR23) + CONST010*VAR06*(CONST017*VAR23 + CONST173*VAR15*z + CONST288*VAR17*VAR25) + CONST011*VAR04*(-CONST273*VAR25 + CONST344*VAR17*z) + CONST024*VAR02*z + CONST082*VAR11*z + CONST173*VAR15*VAR23 + CONST231*VAR13*VAR25 + CONST344*VAR17*VAR21 - CONST397*VAR19) + g_9*(CONST009*VAR08*(CONST042*VAR22*y + CONST211*VAR14*VAR26 + CONST251*VAR16*VAR24 + CONST312*VAR12) + CONST010*VAR06*(CONST058*VAR24*y + CONST142*VAR14 + CONST252*VAR16*VAR26) + CONST011*VAR04*(CONST042*VAR26*y + CONST331*VAR16) + CONST015*VAR20*y + CONST025*VAR10 + CONST076*VAR02*y + CONST142*VAR14*VAR24 + CONST312*VAR12*VAR26 + CONST331*VAR16*VAR22) +g_y = CONST000*g_18*y*(CONST027*VAR02 + CONST027*VAR20 + CONST128*VAR06*VAR24 + 
CONST207*VAR04*VAR26 + CONST207*VAR08*VAR22) + CONST000*g_2*y*(-CONST039*VAR05*VAR25 + CONST039*VAR07*VAR23 + CONST319*VAR03*z - CONST319*VAR21*x) + g_1*(CONST014*VAR01 + CONST062*VAR20*x + CONST203*VAR07*VAR22 + CONST281*VAR05*VAR24 + CONST292*VAR03*VAR26) + g_10*(CONST034*VAR10 + CONST064*VAR20*y + CONST065*VAR02*y + CONST067*VAR14*VAR24 + CONST182*VAR16*VAR22 + CONST233*VAR12*VAR26 + VAR04*(CONST131*VAR26*y + CONST181*VAR16) + VAR06*(CONST067*VAR14 + CONST137*VAR16*VAR26 + CONST165*VAR24*y) + VAR08*(CONST091*VAR14*VAR26 + CONST130*VAR22*y + CONST137*VAR16*VAR24 + CONST232*VAR12)) + g_11*(CONST015*VAR19 + VAR21*(CONST042*VAR08 + CONST253*VAR17) + VAR23*(CONST033*VAR08*VAR17 + CONST058*VAR06 + CONST155*VAR15) + VAR25*(CONST032*VAR06*VAR17 + CONST042*VAR04 + CONST235*VAR08*VAR15 + CONST361*VAR13) + z*(CONST015*VAR02 + CONST155*VAR06*VAR15 + CONST253*VAR04*VAR17 - CONST312*VAR11 + CONST360*VAR08*VAR13)) + g_12*(-CONST140*VAR16*VAR22 - CONST244*VAR12*VAR26 + CONST293*VAR14*VAR24 + CONST343*VAR20*y - CONST344*VAR02*y + VAR04*(CONST140*VAR16 - CONST311*VAR26*y) + VAR06*(CONST139*VAR16*VAR26 - CONST295*VAR14) + VAR08*(-CONST140*VAR16*VAR24 + CONST244*VAR12 + CONST309*VAR22*y)) + g_13*(CONST009*VAR17*(CONST208*VAR06*VAR25 + CONST266*VAR04*z + CONST335*VAR08*VAR23 - CONST336*VAR21) + CONST010*VAR15*(CONST176*VAR08*VAR25 - CONST186*VAR06*z + CONST298*VAR23) + CONST011*VAR13*(CONST077*VAR25 + CONST290*VAR08*z) - CONST350*VAR04*VAR25 - CONST358*VAR06*VAR23 - CONST374*VAR02*z + CONST384*VAR19) + g_14*(CONST071*VAR02*y + CONST072*VAR20*y - CONST193*VAR14*VAR24 + CONST193*VAR16*VAR22 + VAR04*(CONST193*VAR16 + CONST274*VAR26*y) + VAR06*(CONST159*VAR24*y - CONST193*VAR14 + CONST272*VAR16*VAR26) + VAR08*(-CONST148*VAR16*VAR24 + CONST274*VAR22*y + CONST278*VAR14*VAR26)) + g_15*(CONST009*VAR17*(CONST241*VAR04*z - CONST241*VAR06*VAR25 + CONST242*VAR08*VAR23 + CONST347*VAR21) + CONST010*VAR15*(CONST083*VAR23 + CONST101*VAR08*VAR25 - CONST223*VAR06*z) + CONST046*VAR02*z + 
CONST197*VAR19 + CONST332*VAR06*VAR23 + CONST352*VAR08*VAR21) + g_16*(-CONST108*VAR06*VAR16*VAR26 - CONST280*VAR16*VAR22 - CONST354*VAR02*y + CONST354*VAR20*y + VAR04*(CONST135*VAR26*y + CONST280*VAR16) + VAR08*(CONST108*VAR16*VAR24 + CONST287*VAR22*y)) + g_17*(CONST009*VAR17*(CONST048*VAR21 + CONST125*VAR08*VAR23 - CONST256*VAR06*VAR25 + CONST277*VAR04*z) + CONST059*VAR02*z + CONST296*VAR04*VAR25 - CONST318*VAR08*VAR21 + CONST334*VAR06*VAR23 + CONST386*VAR19) + g_19*(CONST014*VAR19 + CONST062*VAR02*z + CONST203*VAR04*VAR25 + CONST281*VAR06*VAR23 + CONST292*VAR08*VAR21) + g_3*(CONST009*VAR17*(CONST144*VAR22*x + CONST256*VAR07*VAR24 + CONST294*VAR05*VAR26 + CONST366*VAR03) + CONST122*VAR07*VAR22 + CONST318*VAR03*VAR26 - CONST334*VAR05*VAR24 + CONST356*VAR20*x - CONST386*VAR01) + g_4*(CONST248*VAR03*y*z + VAR05*(CONST213*VAR16*z + CONST286*VAR25*y) + VAR07*(CONST287*VAR23*y + CONST323*VAR16*VAR25) + x*(CONST213*VAR16*VAR23 + CONST247*VAR21*y)) + g_5*(CONST009*VAR17*(-CONST241*VAR07*VAR24 + CONST241*VAR22*x + CONST243*VAR05*VAR26 + CONST347*VAR03) + CONST010*VAR15*(CONST083*VAR05 + CONST101*VAR07*VAR26 - CONST223*VAR24*x) + CONST046*VAR20*x + CONST197*VAR01 + CONST332*VAR05*VAR24 + CONST353*VAR03*VAR26) + g_6*(CONST275*VAR03*y*z + VAR05*(CONST274*VAR25*y - CONST302*VAR16*z) + VAR07*(CONST146*VAR23*y + CONST302*VAR14*z) + x*(CONST146*VAR21*y - CONST302*VAR14*VAR25 + CONST302*VAR16*VAR23)) + g_7*(CONST009*VAR17*(CONST087*VAR05*VAR26 - CONST209*VAR07*VAR24 - CONST266*VAR22*x + CONST336*VAR03) + CONST010*VAR15*(CONST186*VAR24*x + CONST237*VAR07*VAR26 - CONST298*VAR05) + CONST011*VAR13*(-CONST290*VAR26*x + CONST345*VAR07) + CONST340*VAR01 + CONST350*VAR07*VAR22 + CONST358*VAR05*VAR24 + CONST374*VAR20*x) + g_8*(CONST311*VAR03*y*z + VAR05*(CONST206*VAR16*z + CONST216*VAR25*y) + VAR07*(CONST028*VAR16*VAR25 + CONST216*VAR23*y + CONST226*VAR14*z) + x*(CONST206*VAR16*VAR23 + CONST226*VAR14*VAR25 + CONST259*VAR12*z + CONST311*VAR21*y)) + g_9*(CONST015*VAR01 + 
VAR03*(CONST042*VAR26 + CONST253*VAR17) + VAR05*(CONST033*VAR17*VAR26 + CONST058*VAR24 + CONST155*VAR15) + VAR07*(CONST032*VAR17*VAR24 + CONST042*VAR22 + CONST235*VAR15*VAR26 + CONST361*VAR13) + x*(CONST015*VAR20 + CONST155*VAR15*VAR24 + CONST253*VAR17*VAR22 - CONST314*VAR11 + CONST361*VAR13*VAR26)) +g_z = g_0*(CONST093*VAR20*x + CONST210*VAR03*VAR26 + CONST250*VAR05*VAR24 + CONST328*VAR07*VAR22 - CONST378*VAR01) + g_1*y*(-CONST018*VAR05*VAR25 + CONST018*VAR07*VAR23 + CONST224*VAR03*z - CONST224*VAR21*x) + g_10*(CONST095*VAR15*VAR23 + CONST132*VAR17*VAR21 + CONST265*VAR13*VAR25 + CONST333*VAR11*z + CONST391*VAR19 + CONST398*VAR02*z + VAR04*(CONST131*VAR17*z + CONST376*VAR25) + VAR06*(CONST094*VAR15*z + CONST246*VAR17*VAR25 + CONST369*VAR23) + VAR08*(CONST137*VAR15*VAR25 + CONST246*VAR17*VAR23 + CONST265*VAR13*z + CONST375*VAR21)) + g_11*(CONST009*VAR26*(CONST042*VAR04*y + CONST211*VAR08*VAR14 + CONST251*VAR06*VAR16 + CONST313*VAR12) + CONST010*VAR24*(CONST058*VAR06*y + CONST142*VAR14 + CONST252*VAR08*VAR16) + CONST011*VAR22*(CONST042*VAR08*y + CONST331*VAR16) + CONST015*VAR02*y + CONST026*VAR10 + CONST076*VAR20*y + CONST142*VAR06*VAR14 + CONST314*VAR08*VAR12 + CONST331*VAR04*VAR16) + g_12*(CONST050*VAR02*z + CONST082*VAR11*z + CONST097*VAR15*VAR23 + CONST120*VAR13*VAR25 + CONST262*VAR17*VAR21 - CONST385*VAR19 + VAR04*(CONST273*VAR25 - CONST311*VAR17*z) + VAR06*(CONST017*VAR23 + CONST238*VAR15*z) + VAR08*(CONST029*VAR21 - CONST140*VAR15*VAR25 + CONST217*VAR17*VAR23)) + g_13*(VAR12*(CONST290*VAR08 - CONST290*VAR26) + VAR14*(CONST049*VAR24 - CONST186*VAR06 - CONST307*VAR08*VAR26) + VAR16*(-CONST164*VAR22 + CONST209*VAR08*VAR24 + CONST219*VAR06*VAR26 + CONST266*VAR04) + y*(-CONST285*VAR06*VAR24 - CONST297*VAR04*VAR26 + CONST346*VAR20 - CONST374*VAR02)) + g_14*(CONST104*VAR02*z + CONST114*VAR15*VAR23 + CONST146*VAR17*VAR21 + CONST194*VAR19 - CONST239*VAR13*VAR25 + VAR04*(CONST274*VAR17*z - CONST362*VAR25) + VAR06*(CONST072*VAR23 + CONST171*VAR15*z + 
CONST240*VAR17*VAR25) + VAR08*(CONST030*VAR21 + CONST114*VAR17*VAR23 - CONST148*VAR15*VAR25 + CONST338*VAR13*z)) + g_15*(VAR14*(CONST185*VAR08*VAR26 - CONST222*VAR24 - CONST223*VAR06) + VAR16*(CONST079*VAR06*VAR26 + CONST134*VAR08*VAR24 + CONST202*VAR22 + CONST241*VAR04) + y*(CONST046*VAR02 + CONST073*VAR20 + CONST195*VAR06*VAR24 + CONST223*VAR08*VAR22)) + g_16*(CONST022*VAR19 + CONST035*VAR02*z + CONST175*VAR15*VAR23 + CONST291*VAR17*VAR21 + VAR04*(CONST057*VAR25 + CONST135*VAR17*z) + VAR06*(CONST341*VAR15*z + CONST346*VAR23) + VAR08*(CONST108*VAR15*VAR25 + CONST158*VAR17*VAR23 + CONST337*VAR21)) + g_17*(VAR16*(-CONST044*VAR06*VAR26 + CONST044*VAR08*VAR24 + CONST144*VAR22 + CONST277*VAR04) + y*(-CONST016*VAR08*VAR22 + CONST059*VAR02 + CONST180*VAR04*VAR26 + CONST205*VAR06*VAR24 + CONST351*VAR20)) + g_18*(CONST061*VAR02*z + CONST127*VAR08*VAR21 + CONST284*VAR06*VAR23 + CONST306*VAR04*VAR25 + CONST381*VAR19 + VAR17*(CONST039*VAR04*z + CONST081*VAR08*VAR23 + CONST316*VAR06*VAR25 - CONST319*VAR21)) + g_19*y*(CONST062*VAR02 + CONST063*VAR20 + CONST204*VAR04*VAR26 + CONST204*VAR08*VAR22 + CONST279*VAR06*VAR24) + g_2*(CONST151*VAR01 + CONST162*VAR07*VAR22 + CONST319*VAR03*VAR26 + CONST348*VAR20*x + VAR17*(-CONST040*VAR22*x - CONST081*VAR05*VAR26 + CONST103*VAR07*VAR24 + CONST319*VAR03)) + g_20*(-CONST163*VAR06*VAR23 + CONST212*VAR08*VAR21 - CONST327*VAR02*z + CONST329*VAR04*VAR25 - CONST378*VAR19) + g_3*(VAR16*(-CONST183*VAR23*x + CONST228*VAR05*z + CONST267*VAR07*VAR25) + y*(CONST116*VAR07*VAR23 - CONST234*VAR05*VAR25 + CONST234*VAR21*x + CONST268*VAR03*z)) + g_4*(CONST008*VAR01 + VAR03*(CONST303*VAR17 + CONST377*VAR26) + VAR05*(CONST175*VAR15 - CONST307*VAR17*VAR26 + CONST326*VAR24) + VAR07*(CONST108*VAR15*VAR26 + CONST341*VAR17*VAR24 + CONST359*VAR22) + x*(CONST053*VAR20 + CONST307*VAR17*VAR22 + CONST341*VAR15*VAR24)) + g_5*(VAR14*(CONST147*VAR07*z - CONST147*VAR25*x) + VAR16*(CONST154*VAR05*z + CONST190*VAR07*VAR25 + CONST310*VAR23*x) + y*(CONST156*VAR21*x + 
CONST222*VAR05*VAR25 + CONST325*VAR03*z)) + g_6*(CONST177*VAR01 + VAR03*(CONST030*VAR26 + CONST321*VAR17) + VAR05*(-CONST193*VAR15 + CONST229*VAR17*VAR26) + VAR07*(CONST239*VAR13 + CONST258*VAR17*VAR24 + CONST362*VAR22) + x*(CONST148*VAR15*VAR24 - CONST338*VAR13*VAR26 + CONST357*VAR17*VAR22 + CONST372*VAR20)) + g_7*(-CONST221*VAR12*x*z + VAR14*(CONST136*VAR07*z + CONST260*VAR25*x) + VAR16*(CONST119*VAR05*z - CONST145*VAR23*x + CONST342*VAR07*VAR25) + y*(CONST237*VAR07*VAR23 + CONST297*VAR05*VAR25 + CONST298*VAR21*x)) + g_8*(-CONST397*VAR01 + VAR03*(CONST031*VAR26 + CONST344*VAR17) + VAR05*(CONST055*VAR24 + CONST160*VAR17*VAR26 + CONST173*VAR15) + VAR07*(CONST051*VAR22 + CONST143*VAR15*VAR26 + CONST231*VAR13 + CONST322*VAR17*VAR24) + x*(CONST024*VAR20 + CONST082*VAR11 + CONST196*VAR17*VAR22 + CONST295*VAR13*VAR26 + CONST330*VAR15*VAR24)) + g_9*(CONST070*VAR03*y*z + VAR05*(CONST121*VAR25*y + CONST168*VAR16*z) + VAR07*(CONST121*VAR23*y + CONST261*VAR16*VAR25 - CONST361*VAR14*z) + x*(CONST070*VAR21*y + CONST167*VAR16*VAR23 + CONST264*VAR12*z - CONST361*VAR14*VAR25)) diff --git a/notebooks/bwd_implementations/bwd_2.py b/notebooks/bwd_implementations/bwd_2.py new file mode 100644 index 0000000..686e576 --- /dev/null +++ b/notebooks/bwd_implementations/bwd_2.py @@ -0,0 +1,36 @@ +# -------------------- variable and constant definitions +CONST000 = 3.87298334620742 +CONST001 = 4.47213595499958 +CONST002 = -2.23606797749979 +CONST003 = -3.87298334620742 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +g_x = CONST002*g_2*x - CONST003*g_0*z - CONST003*g_1*y + CONST003*g_4*x +g_y = CONST001*g_2*y - 
CONST003*g_1*x - CONST003*g_3*z +g_z = CONST002*g_2*z - CONST003*g_0*x - CONST003*g_3*y - CONST003*g_4*z diff --git a/notebooks/bwd_implementations/bwd_3.py b/notebooks/bwd_implementations/bwd_3.py new file mode 100644 index 0000000..837fbd2 --- /dev/null +++ b/notebooks/bwd_implementations/bwd_3.py @@ -0,0 +1,48 @@ +# -------------------- variable and constant definitions +CONST000 = 5.12347538297980 +CONST001 = 6.27495019900557 +CONST002 = 6.48074069840786 +CONST003 = 7.93725393319377 +CONST004 = 10.2469507659596 +CONST005 = 12.9614813968157 +CONST006 = 12.5499003980111 +CONST007 = -3.96862696659689 +CONST008 = -12.5499003980111 +CONST009 = -10.2469507659596 +CONST010 = -7.93725393319377 +CONST011 = -6.27495019900557 +CONST012 = -5.12347538297980 +CONST013 = -4.86055552380590 +CONST014 = -3.24037034920393 +CONST015 = -1.62018517460197 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +g_x = CONST008*g_6*x*z - CONST009*g_1*y*z + CONST009*g_5*x*y + CONST010*g_3*x*y + CONST014*g_4*x*z + g_0*(CONST011*VAR08 - CONST011*VAR26) + g_2*(CONST002*VAR17 + CONST013*VAR08 + CONST015*VAR26) +g_y = CONST005*g_2*x*y + CONST005*g_4*y*z - CONST009*g_1*x*z + g_3*(CONST007*VAR08 + CONST007*VAR26 - CONST010*VAR17) + g_5*(CONST012*VAR08 - CONST012*VAR26) +g_z = -CONST008*g_0*x*z - CONST009*g_1*x*y - CONST009*g_5*y*z + CONST010*g_3*y*z + CONST014*g_2*x*z + g_4*(CONST002*VAR17 + CONST013*VAR26 + CONST015*VAR08) + g_6*(CONST011*VAR08 - CONST011*VAR26) diff --git a/notebooks/bwd_implementations/bwd_4.py b/notebooks/bwd_implementations/bwd_4.py new file mode 100644 index 0000000..8886695 --- /dev/null +++ 
b/notebooks/bwd_implementations/bwd_4.py @@ -0,0 +1,61 @@ +# -------------------- variable and constant definitions +CONST000 = 2.00000000000000 +CONST001 = 4.50000000000000 +CONST002 = 2.25000000000000 +CONST003 = 6.70820393249937 +CONST004 = 6.27495019900557 +CONST005 = 8.87411967464942 +CONST006 = 9.48683298050514 +CONST007 = 10.0623058987491 +CONST008 = 12.0000000000000 +CONST009 = 18.8248505970167 +CONST010 = 20.1246117974981 +CONST011 = 26.6223590239483 +CONST012 = 28.4604989415154 +CONST013 = 37.6497011940334 +CONST014 = 40.2492235949962 +CONST015 = -37.6497011940334 +CONST016 = -6.70820393249937 +CONST017 = -26.6223590239483 +CONST018 = -21.3453742061366 +CONST019 = -20.1246117974981 +CONST020 = -18.8248505970167 +CONST021 = -18.0000000000000 +CONST022 = -14.2302494707577 +CONST023 = -10.0623058987491 +CONST024 = -9.00000000000000 +CONST025 = -8.87411967464942 +CONST026 = -7.11512473537885 +CONST027 = -6.27495019900557 +CONST028 = -3.35410196624968 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +g_x = CONST015*g_7*x*y*z + CONST022*g_5*x*y*z + g_0*(CONST017*VAR08*z - CONST025*VAR25) + g_1*y*(CONST020*VAR08 - CONST020*VAR26) + g_2*(-CONST019*VAR17*z + CONST023*VAR08*z + CONST028*VAR25) + g_3*(CONST006*VAR16 + CONST018*VAR08*y + CONST026*VAR26*y) + g_4*(CONST000*x*(CONST002*VAR26 + CONST024*VAR17) + CONST001*VAR07) + g_6*(-CONST016*VAR07 + CONST019*VAR17*x) + g_8*(CONST017*VAR26*x - CONST025*VAR07) +g_y = CONST000*g_6*y*(CONST023*VAR08 - CONST023*VAR26) + CONST014*g_2*x*y*z + g_1*(-CONST020*VAR26*x + CONST027*VAR07) + g_3*(CONST026*VAR07 + x*(CONST012*VAR17 + CONST026*VAR26)) + 
g_4*(CONST008*VAR16 + CONST021*VAR08*y + CONST021*VAR26*y) + g_5*(CONST026*VAR25 + z*(CONST012*VAR17 + CONST026*VAR08)) + g_7*(CONST020*VAR08*z - CONST027*VAR25) +g_z = -CONST015*g_1*x*y*z + CONST022*g_3*x*y*z + g_0*(-CONST017*VAR26*x + CONST025*VAR07) + g_2*(CONST028*VAR07 + x*(-CONST019*VAR17 + CONST023*VAR26)) + g_4*(CONST001*VAR08*z + CONST001*VAR25 + CONST021*VAR17*z) + g_5*(CONST006*VAR16 + CONST018*VAR26*y + CONST026*VAR08*y) + g_6*(CONST016*VAR25 - CONST019*VAR17*z) + g_7*y*(CONST020*VAR08 - CONST020*VAR26) + g_8*(CONST017*VAR08*z - CONST025*VAR25) diff --git a/notebooks/bwd_implementations/bwd_5.py b/notebooks/bwd_implementations/bwd_5.py new file mode 100644 index 0000000..fb26cc2 --- /dev/null +++ b/notebooks/bwd_implementations/bwd_5.py @@ -0,0 +1,97 @@ +# -------------------- variable and constant definitions +CONST000 = 1.60565407233314 +CONST001 = 3.00000000000000 +CONST002 = 3.21130814466628 +CONST003 = 1.60565407233314 +CONST004 = 6.42261628933256 +CONST005 = 6.42261628933256 +CONST006 = 8.67152307844476 +CONST007 = 8.02827036166571 +CONST008 = 6.93721846275580 +CONST009 = 11.6340690431164 +CONST010 = 12.8452325786651 +CONST011 = 6.21867148191637 +CONST012 = 6.21867148191637 +CONST013 = 16.5831239517770 +CONST014 = 12.4373429638327 +CONST015 = 16.9926454679664 +CONST016 = 20.8116553882674 +CONST017 = 12.8452325786651 +CONST018 = 13.8744369255116 +CONST019 = 24.8746859276655 +CONST020 = 24.8746859276655 +CONST021 = 27.7488738510232 +CONST022 = 5.20291384706685 +CONST023 = 29.4321253055229 +CONST024 = 29.4321253055229 +CONST025 = 33.9852909359329 +CONST026 = 33.9852909359329 +CONST027 = 7.35803132638072 +CONST028 = 41.6233107765348 +CONST029 = 46.5362761724657 +CONST030 = 51.3809303146605 +CONST031 = 51.3809303146605 +CONST032 = 83.2466215530696 +CONST033 = 88.2963759165686 +CONST034 = 101.955872807799 +CONST035 = 8.49632273398321 +CONST036 = -8.67152307844475 +CONST037 = 3.46860923137790 +CONST038 = -88.2963759165686 +CONST039 = -83.2466215530696 
+CONST040 = -69.8044142586986 +CONST041 = -50.9779364038993 +CONST042 = -50.9779364038993 +CONST043 = -46.5362761724657 +CONST044 = -44.1481879582843 +CONST045 = -41.6233107765348 +CONST046 = -38.5356977359954 +CONST047 = -38.5356977359954 +CONST048 = -33.1662479035540 +CONST049 = -33.9852909359329 +CONST050 = 6.42261628933257 +CONST051 = -33.9852909359329 +CONST052 = -29.4321253055229 +CONST053 = -27.7488738510232 +CONST054 = -20.8116553882674 +CONST055 = -19.2678488679977 +CONST056 = -19.2678488679977 +CONST057 = -16.9926454679664 +CONST058 = -16.9926454679664 +CONST059 = -13.8744369255116 +CONST060 = -16.5831239517770 +CONST061 = -8.49632273398321 +CONST062 = -6.93721846275580 +CONST063 = -5.20291384706685 +CONST064 = -3.46860923137790 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +g_x = g_0*(CONST009*VAR06 + CONST009*VAR24 + CONST040*VAR08*VAR26) + g_1*y*(CONST038*VAR08*z - CONST052*VAR25) + g_10*(CONST029*VAR07*z + CONST043*VAR25*x) + g_2*(CONST001*VAR08*(CONST059*VAR17 + CONST064*VAR26) + CONST006*VAR06 - CONST045*VAR17*VAR26 + CONST063*VAR24) + g_3*(CONST041*VAR08*y*z - CONST049*VAR16*z + CONST057*VAR25*y) + g_4*(CONST000*VAR24 + CONST001*VAR08*(CONST002*VAR26 + CONST055*VAR17) + CONST007*VAR06 + CONST010*VAR15 + CONST056*VAR17*VAR26) + g_5*(CONST048*VAR16*x + y*(CONST019*VAR07 + CONST019*VAR26*x)) + g_6*(CONST005*VAR25*x + z*(CONST004*VAR07 + CONST046*VAR17*x)) + g_7*(CONST049*VAR16*x - CONST051*VAR07*y) + g_8*(CONST008*VAR25*x + z*(CONST039*VAR17*x - CONST054*VAR07)) + g_9*y*(CONST024*VAR07 + CONST038*VAR26*x) +g_y = g_1*(CONST052*VAR07*z - CONST052*VAR25*x) + 
g_2*(-CONST039*VAR26*x*y + CONST053*VAR07*y) + g_3*(CONST058*VAR07*z + x*(CONST034*VAR17*z + CONST057*VAR25)) + g_4*(CONST047*VAR07*y + x*(CONST030*VAR16 + CONST046*VAR26*y)) + g_5*(CONST001*VAR17*(CONST060*VAR08 + CONST060*VAR26) + CONST011*VAR06 + CONST012*VAR24 + CONST014*VAR08*VAR26 - CONST060*VAR15) + g_6*(CONST046*VAR25*y + z*(CONST031*VAR16 + CONST046*VAR08*y)) + g_7*(CONST001*VAR17*(CONST057*VAR08 - CONST057*VAR26) - CONST061*VAR06 + CONST061*VAR24) + g_8*(CONST021*VAR25*y + CONST039*VAR08*y*z) + g_9*(CONST027*VAR06 + CONST027*VAR24 + CONST044*VAR08*VAR26) +g_z = g_0*(CONST029*VAR25*x + CONST043*VAR07*z) + g_1*y*(-CONST038*VAR26*x + CONST052*VAR07) + g_10*(CONST009*VAR06 + CONST009*VAR24 + CONST040*VAR08*VAR26) + g_2*(CONST062*VAR07*z + x*(-CONST039*VAR17*z + CONST054*VAR25)) + g_3*(CONST058*VAR07*y + x*(CONST042*VAR26*y - CONST049*VAR16)) + g_4*(CONST005*VAR07*z + x*(CONST046*VAR17*z + CONST050*VAR25)) + g_5*(CONST048*VAR16*z + y*(CONST019*VAR08*z + CONST020*VAR25)) + g_6*(CONST001*VAR26*(CONST002*VAR08 + CONST056*VAR17) + CONST003*VAR06 + CONST007*VAR24 + CONST017*VAR15 + CONST056*VAR08*VAR17) + g_7*(-CONST049*VAR16*z + CONST051*VAR25*y) + g_8*(CONST001*VAR26*(CONST018*VAR17 + CONST037*VAR08) + CONST036*VAR24 + CONST045*VAR08*VAR17 - CONST063*VAR06) + g_9*y*(CONST024*VAR25 + CONST038*VAR08*z) diff --git a/notebooks/bwd_implementations/bwd_6.py b/notebooks/bwd_implementations/bwd_6.py new file mode 100644 index 0000000..b8610bf --- /dev/null +++ b/notebooks/bwd_implementations/bwd_6.py @@ -0,0 +1,120 @@ +# -------------------- variable and constant definitions +CONST000 = 2.00000000000000 +CONST001 = 3.26558761940328 +CONST002 = 4.00000000000000 +CONST003 = 3.00000000000000 +CONST004 = 6.53117523880657 +CONST005 = 1.63279380970164 +CONST006 = 8.94318001328386 +CONST007 = 8.38944649544891 +CONST008 = 10.3266947761614 +CONST009 = 9.79676285820985 +CONST010 = 7.15454401062709 +CONST011 = 14.5309475774982 +CONST012 = 9.79676285820985 +CONST013 = 
16.3279380970164 +CONST014 = 17.8863600265677 +CONST015 = 16.5227116418583 +CONST016 = 20.6533895523229 +CONST017 = 20.2812259244849 +CONST018 = 21.6333076527839 +CONST019 = 19.5935257164197 +CONST020 = 17.8863600265677 +CONST021 = 26.1247009552263 +CONST022 = 29.3902885746295 +CONST023 = 35.7727200531355 +CONST024 = 35.7727200531355 +CONST025 = 39.1870514328394 +CONST026 = 40.5624518489699 +CONST027 = 41.3067791046458 +CONST028 = 41.9472324772445 +CONST029 = 48.9838142910493 +CONST030 = 51.6334738808072 +CONST031 = 52.2494019104525 +CONST032 = 58.7805771492591 +CONST033 = 71.5454401062709 +CONST034 = 72.6547378874909 +CONST035 = 71.5454401062709 +CONST036 = 78.3741028656788 +CONST037 = 81.1249036979398 +CONST038 = 82.6135582092915 +CONST039 = 82.6135582092915 +CONST040 = -3.26558761940328 +CONST041 = 104.498803820905 +CONST042 = 117.561154298518 +CONST043 = 145.309475774982 +CONST044 = 156.748205731358 +CONST045 = 167.788929908978 +CONST046 = 208.997607641810 +CONST047 = 214.636320318813 +CONST048 = -251.683394863467 +CONST049 = -214.636320318813 +CONST050 = -214.636320318813 +CONST051 = 16.5227116418583 +CONST052 = -167.788929908978 +CONST053 = -156.748205731358 +CONST054 = -145.309475774982 +CONST055 = -123.920337313937 +CONST056 = -117.561154298518 +CONST057 = 3.26558761940328 +CONST058 = -108.166538263920 +CONST059 = -107.318160159406 +CONST060 = -104.498803820905 +CONST061 = -104.498803820905 +CONST062 = -83.8944649544891 +CONST063 = -82.6135582092915 +CONST064 = -78.3741028656788 +CONST065 = -72.6547378874909 +CONST066 = -71.5454401062709 +CONST067 = -58.7805771492591 +CONST068 = -54.0832691319598 +CONST069 = -52.2494019104525 +CONST070 = -52.2494019104525 +CONST071 = -48.9838142910492 +CONST072 = -41.3067791046458 +CONST073 = -39.1870514328394 +CONST074 = -35.7727200531355 +CONST075 = -29.3902885746295 +CONST076 = -27.0416345659799 +CONST077 = -26.1247009552263 +CONST078 = -26.1247009552263 +CONST079 = -19.5935257164197 +CONST080 = -14.5309475774982 
+CONST081 = -13.5208172829900 +CONST082 = -10.7318160159406 +CONST083 = -9.79676285820985 +CONST084 = -7.15454401062709 +CONST085 = -6.76040864149498 +CONST086 = -3.38020432074749 +CONST087 = -1.63279380970164 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +g_x = g_0*(CONST054*VAR08*VAR25 - CONST065*VAR06*z - CONST080*VAR23) + g_1*y*(CONST028*VAR06 + CONST028*VAR24 + CONST048*VAR08*VAR26) + g_10*(CONST000*x*(CONST006*VAR24 + CONST059*VAR17*VAR26) + CONST002*VAR07*(CONST006*VAR26 + CONST014*VAR17) + CONST082*VAR05) + g_11*y*(-CONST052*VAR07*z + CONST052*VAR25*x) + g_12*(-CONST054*VAR07*VAR26 + CONST065*VAR24*x + CONST080*VAR05) + g_2*(-CONST074*VAR06*z + CONST084*VAR23 + VAR17*(CONST049*VAR08*z - CONST066*VAR25)) + g_3*(VAR16*(CONST064*VAR08 - CONST064*VAR26) + y*(CONST029*VAR06 + CONST067*VAR08*VAR26 + CONST075*VAR24)) + g_4*(CONST003*VAR08*(CONST004*VAR25 + CONST069*VAR17*z) + CONST013*VAR06*z - CONST040*VAR23 - CONST070*VAR15*z + CONST070*VAR17*VAR25) + g_5*(CONST003*VAR08*(CONST016*VAR26*y + CONST072*VAR16) + CONST008*VAR24*y + CONST015*VAR14 + CONST030*VAR06*y + CONST072*VAR16*VAR26) + g_6*(CONST000*x*(CONST026*VAR17*VAR26 + CONST076*VAR15 + CONST086*VAR24) + CONST002*VAR07*(CONST017*VAR17 + CONST086*VAR26) + CONST085*VAR05) + g_7*(-CONST072*VAR25*x*y + z*(CONST063*VAR16*x - CONST072*VAR07*y)) + g_8*(CONST000*x*(CONST077*VAR15 - CONST087*VAR24) + CONST002*VAR07*(-CONST077*VAR17 + CONST087*VAR26) + CONST083*VAR05) + g_9*(CONST053*VAR16*x*z + y*(CONST042*VAR07*z - CONST073*VAR25*x)) +g_y = CONST000*g_2*y*(CONST066*VAR07*z - CONST066*VAR25*x) + g_1*(CONST007*VAR05 + CONST028*VAR24*x + 
CONST062*VAR07*VAR26) + g_10*(CONST024*VAR06*y + CONST050*VAR08*VAR26*y - CONST074*VAR24*y) + g_11*(CONST007*VAR23 + CONST028*VAR06*z + CONST062*VAR08*VAR25) + g_3*(CONST003*VAR17*(-CONST064*VAR26*x + CONST078*VAR07) + CONST009*VAR05 + CONST075*VAR24*x + CONST079*VAR07*VAR26) + g_4*(CONST061*VAR07*y*z + x*(CONST046*VAR16*z + CONST060*VAR25*y)) + g_5*(CONST008*VAR05 + VAR07*(CONST016*VAR26 + CONST055*VAR17) + x*(CONST008*VAR24 + CONST055*VAR17*VAR26 - CONST063*VAR15)) + g_6*(CONST018*VAR14 + CONST026*VAR06*y + CONST026*VAR24*y + CONST058*VAR16*VAR26 + VAR08*(CONST037*VAR26*y + CONST058*VAR16)) + g_7*(CONST008*VAR23 + VAR25*(CONST016*VAR08 + CONST055*VAR17) + z*(CONST008*VAR06 + CONST039*VAR15 + CONST055*VAR08*VAR17)) + g_8*(CONST060*VAR08*VAR16 - CONST060*VAR16*VAR26 + CONST069*VAR24*y - CONST070*VAR06*y) + g_9*(CONST003*VAR17*(CONST064*VAR08*z - CONST077*VAR25) + CONST022*VAR06*z - CONST079*VAR08*VAR25 + CONST083*VAR23) +g_z = g_0*(CONST054*VAR07*VAR26 - CONST065*VAR24*x - CONST080*VAR05) + g_1*y*(CONST052*VAR07*z - CONST052*VAR25*x) + g_10*(CONST020*VAR06*z + CONST035*VAR17*VAR25 + CONST082*VAR23 + VAR08*(CONST050*VAR17*z - CONST074*VAR25)) + g_11*y*(CONST028*VAR06 + CONST028*VAR24 + CONST048*VAR08*VAR26) + g_12*(CONST054*VAR08*VAR25 - CONST065*VAR06*z - CONST080*VAR23) + g_2*(CONST074*VAR24*x - CONST084*VAR05 + VAR17*(-CONST049*VAR26*x + CONST066*VAR07)) + g_3*(-CONST053*VAR16*x*z + y*(CONST056*VAR25*x + CONST073*VAR07*z)) + g_4*(CONST057*VAR05 + VAR07*(CONST069*VAR17 - CONST079*VAR26) + x*(CONST013*VAR24 + CONST053*VAR17*VAR26 - CONST070*VAR15)) + g_5*(-CONST072*VAR07*y*z + x*(CONST063*VAR16*z - CONST072*VAR25*y)) + g_6*(CONST037*VAR17*VAR25 + CONST068*VAR15*z + CONST085*VAR06*z + CONST085*VAR23 + VAR08*(CONST037*VAR17*z + CONST081*VAR25)) + g_7*(CONST003*VAR26*(CONST016*VAR08*y + CONST072*VAR16) + CONST008*VAR06*y + CONST030*VAR24*y + CONST051*VAR14 + CONST072*VAR08*VAR16) + g_8*(CONST004*VAR08*VAR25 + CONST040*VAR06*z + CONST061*VAR17*VAR25 - CONST070*VAR15*z 
- CONST083*VAR23) + g_9*(VAR16*(CONST064*VAR08 - CONST064*VAR26) + y*(CONST022*VAR06 - CONST067*VAR08*VAR26 + CONST071*VAR24)) diff --git a/notebooks/bwd_implementations/bwd_7.py b/notebooks/bwd_implementations/bwd_7.py new file mode 100644 index 0000000..8b42d50 --- /dev/null +++ b/notebooks/bwd_implementations/bwd_7.py @@ -0,0 +1,189 @@ +# -------------------- variable and constant definitions +CONST000 = 1.66389743899677 +CONST001 = 3.00000000000000 +CONST002 = 4.99169231699030 +CONST003 = 5.00000000000000 +CONST004 = 3.32779487799353 +CONST005 = 8.31948719498384 +CONST006 = 9.19753915797590 +CONST007 = 9.37968632871057 +CONST008 = 11.7655316231354 +CONST009 = 11.7655316231354 +CONST010 = 11.6472820729774 +CONST011 = 9.19753915797590 +CONST012 = 16.5555704843566 +CONST013 = 17.5477863187212 +CONST014 = 20.4939015319192 +CONST015 = 532.447180478965 +CONST016 = 22.0740939791422 +CONST017 = 23.5310632462709 +CONST018 = 23.5310632462709 +CONST019 = 20.4939015319192 +CONST020 = 27.1108834234519 +CONST021 = 29.9501539019418 +CONST022 = 33.1111409687132 +CONST023 = 33.2779487799353 +CONST024 = 36.7901566319036 +CONST025 = 36.7901566319036 +CONST026 = 38.4260653723485 +CONST027 = 38.4260653723485 +CONST028 = 37.6497011940334 +CONST029 = 38.4260653723485 +CONST030 = 44.1481879582843 +CONST031 = 44.1481879582843 +CONST032 = -4.99169231699030 +CONST033 = 47.0621264925418 +CONST034 = 562.781179722634 +CONST035 = 50.8329064189723 +CONST036 = 44.3705983732471 +CONST037 = 47.0621264925417 +CONST038 = 55.1852349478554 +CONST039 = 56.2781179722634 +CONST040 = 56.2781179722634 +CONST041 = 62.7495019900557 +CONST042 = 66.5558975598707 +CONST043 = 70.5931897388126 +CONST044 = -441.481879582843 +CONST045 = -441.481879582843 +CONST046 = 75.2994023880668 +CONST047 = 76.8521307446970 +CONST048 = 76.8521307446970 +CONST049 = 76.8521307446970 +CONST050 = -8.47215106982872 +CONST051 = 99.8338463398060 +CONST052 = 101.665812837945 +CONST053 = 105.286717912327 +CONST054 = 110.370469895711 
+CONST055 = 110.370469895711 +CONST056 = -399.335385359224 +CONST057 = 117.655316231354 +CONST058 = 122.963409191515 +CONST059 = 122.963409191515 +CONST060 = 133.111795119741 +CONST061 = -376.497011940334 +CONST062 = -376.497011940334 +CONST063 = 140.695294930659 +CONST064 = 141.186379477625 +CONST065 = 147.160626527614 +CONST066 = 147.160626527614 +CONST067 = 153.704261489394 +CONST068 = 153.704261489394 +CONST069 = -350.955726374425 +CONST070 = 177.482393492989 +CONST071 = 199.667692679612 +CONST072 = 203.331625675889 +CONST073 = 203.331625675889 +CONST074 = -307.408522978788 +CONST075 = -9.60651634308713 +CONST076 = -9.37968632871057 +CONST077 = 220.740939791422 +CONST078 = 220.740939791422 +CONST079 = -281.390589861317 +CONST080 = -1.66389743899677 +CONST081 = -266.223590239483 +CONST082 = -263.216794780819 +CONST083 = 250.998007960223 +CONST084 = -263.216794780818 +CONST085 = -250.998007960223 +CONST086 = 263.216794780818 +CONST087 = 263.216794780819 +CONST088 = 266.223590239483 +CONST089 = 281.390589861317 +CONST090 = 281.390589861317 +CONST091 = -220.740939791422 +CONST092 = -220.740939791422 +CONST093 = -199.667692679612 +CONST094 = -1.60108605718119 +CONST095 = -187.593726574211 +CONST096 = -177.482393492989 +CONST097 = -9.60651634308712 +CONST098 = -9.19753915797590 +CONST099 = 350.955726374425 +CONST100 = -153.704261489394 +CONST101 = -147.160626527614 +CONST102 = -140.695294930659 +CONST103 = 376.497011940334 +CONST104 = -133.111795119741 +CONST105 = -133.111795119741 +CONST106 = -125.499003980111 +CONST107 = -125.499003980111 +CONST108 = 399.335385359224 +CONST109 = -105.286717912327 +CONST110 = -101.665812837945 +CONST111 = -99.8338463398060 +CONST112 = -101.665812837945 +CONST113 = -4.80325817154356 +CONST114 = -81.3326502703558 +CONST115 = -81.3326502703557 +CONST116 = -76.8521307446970 +CONST117 = -75.2994023880668 +CONST118 = 441.481879582843 +CONST119 = -70.5931897388126 +CONST120 = 441.481879582843 +CONST121 = -66.2222819374265 +CONST122 = 
-66.5558975598707 +CONST123 = -66.5558975598707 +CONST124 = -62.7495019900557 +CONST125 = -56.2781179722634 +CONST126 = -55.1852349478554 +CONST127 = -55.1852349478554 +CONST128 = -50.8329064189723 +CONST129 = -50.8329064189723 +CONST130 = -562.781179722634 +CONST131 = -47.0621264925418 +CONST132 = -50.8329064189724 +CONST133 = -44.1481879582843 +CONST134 = -44.3705983732471 +CONST135 = -40.6663251351779 +CONST136 = -40.6663251351779 +CONST137 = -8.31948719498384 +CONST138 = -37.6497011940334 +CONST139 = -33.2779487799353 +CONST140 = -29.9501539019418 +CONST141 = -25.4164532094862 +CONST142 = -25.4164532094862 +CONST143 = -23.5310632462709 +CONST144 = -532.447180478965 +CONST145 = -19.2130326861743 +CONST146 = -17.5477863187212 +CONST147 = -12.8765548211663 +CONST148 = -11.6472820729774 +CONST149 = -11.2076024002683 +CONST150 = -9.19753915797590 +CONST151 = -11.0370469895711 +CONST152 = -11.7655316231354 +CONST153 = -12.8765548211663 +CONST154 = -4.80325817154356 +CONST155 = -3.32779487799353 +CONST156 = -1.60108605718119 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +g_x = g_0*(CONST082*VAR08*VAR24 - CONST084*VAR06*VAR26 + CONST146*VAR04 - CONST146*VAR22) + g_1*y*(CONST039*VAR23 + CONST089*VAR06*z + CONST130*VAR08*VAR25) + g_10*(CONST155*VAR23*x + VAR25*(-CONST105*VAR17*x + CONST139*VAR07) + z*(-CONST056*VAR07*VAR17 + CONST081*VAR15*x + CONST140*VAR05)) + g_11*(VAR16*(CONST044*VAR26*x - CONST101*VAR07) + y*(CONST054*VAR24*x - CONST091*VAR07*VAR26 + CONST121*VAR05)) + g_12*(CONST022*VAR23*x + VAR25*(CONST024*VAR07 + CONST045*VAR17*x) + z*(-CONST044*VAR07*VAR17 + CONST126*VAR05)) + 
g_13*y*(CONST079*VAR24*x + CONST125*VAR05 - CONST130*VAR07*VAR26) + g_14*(-CONST069*VAR07*VAR25 + CONST109*VAR05*z + CONST109*VAR23*x) + g_2*(CONST001*VAR08*(CONST091*VAR17*VAR26 - CONST150*VAR24) + CONST003*VAR06*(CONST012*VAR26 + CONST016*VAR17) + CONST055*VAR17*VAR24 + CONST147*VAR04 + CONST150*VAR22) + g_3*(VAR16*(CONST044*VAR08*z + CONST066*VAR25) + y*(-CONST091*VAR06*z + CONST133*VAR23)) + g_4*(CONST001*VAR08*(CONST122*VAR17*VAR26 + CONST134*VAR15 - CONST137*VAR24) + CONST003*VAR06*(CONST000*VAR26 - CONST139*VAR17) - CONST032*VAR22 - CONST105*VAR15*VAR26 + CONST111*VAR17*VAR24 + CONST148*VAR04) + g_5*(CONST001*VAR08*(CONST106*VAR16*z - CONST131*VAR25*y) + CONST057*VAR06*y*z + CONST107*VAR16*VAR25 - CONST117*VAR14*z - CONST143*VAR23*y) + g_6*(CONST001*VAR08*(CONST116*VAR15 - CONST116*VAR17*VAR26 + CONST154*VAR24) + CONST003*VAR06*(CONST026*VAR17 + CONST113*VAR26) + CONST014*VAR13 + CONST027*VAR17*VAR24 + CONST116*VAR15*VAR26 + CONST149*VAR04 + CONST156*VAR22) + g_7*(CONST114*VAR14*x + VAR16*(CONST072*VAR07 + CONST073*VAR26*x) + y*(CONST110*VAR07*VAR26 + CONST128*VAR05 + CONST129*VAR24*x)) + g_8*(CONST075*VAR23*x + VAR25*(-CONST100*VAR17*x + CONST145*VAR07) + z*(CONST067*VAR07*VAR17 + CONST097*VAR05 + CONST100*VAR15*x)) + g_9*(-CONST085*VAR07*VAR16 + CONST117*VAR14*x + y*(CONST018*VAR24*x + CONST119*VAR05 + CONST131*VAR07*VAR26)) +g_y = g_1*(CONST039*VAR23*x + CONST095*VAR07*VAR25 - CONST125*VAR05*z) + g_10*(CONST123*VAR23*y + VAR25*(-CONST096*VAR16 - CONST105*VAR08*y) + z*(-CONST093*VAR06*y + CONST144*VAR08*VAR16)) + g_11*(CONST001*VAR17*(CONST025*VAR06 + CONST025*VAR24 + CONST092*VAR08*VAR26) - CONST126*VAR06*VAR26 - CONST126*VAR08*VAR24 + CONST151*VAR04 + CONST151*VAR22) + g_12*(CONST030*VAR23*y + CONST045*VAR08*VAR25*y - CONST092*VAR06*y*z) + g_13*(CONST076*VAR04 - CONST076*VAR22 - CONST102*VAR06*VAR26 + CONST102*VAR08*VAR24) + g_2*(CONST030*VAR05*y + CONST045*VAR07*VAR26*y - CONST092*VAR24*x*y) + g_3*(CONST001*VAR17*(CONST066*VAR25*x + CONST101*VAR07*z) - 
CONST133*VAR05*z + CONST133*VAR23*x) + g_4*(-CONST123*VAR05*y + VAR07*(CONST096*VAR16 + CONST104*VAR26*y) + x*(CONST093*VAR24*y - CONST144*VAR16*VAR26)) + g_5*(-CONST143*VAR05*z + VAR07*(CONST062*VAR17*z - CONST131*VAR25) + x*(CONST061*VAR17*VAR25 - CONST062*VAR15*z - CONST143*VAR23)) + g_6*(CONST048*VAR05*y + VAR07*(CONST074*VAR16 - CONST100*VAR26*y) + x*(CONST058*VAR14 + CONST074*VAR16*VAR26 - CONST116*VAR24*y)) + g_7*(CONST001*VAR17*(-CONST112*VAR08*VAR26 - CONST128*VAR06 - CONST128*VAR24) + CONST003*VAR15*(CONST135*VAR08 + CONST136*VAR26) + CONST020*VAR13 + CONST050*VAR04 + CONST050*VAR22 + CONST141*VAR06*VAR26 + CONST142*VAR08*VAR24) + g_8*(CONST048*VAR23*y + VAR25*(CONST074*VAR16 - CONST100*VAR08*y) + z*(CONST049*VAR06*y + CONST059*VAR14 + CONST074*VAR08*VAR16)) + g_9*(CONST001*VAR17*(-CONST124*VAR06 + CONST124*VAR24) + CONST003*VAR15*(CONST138*VAR08 - CONST138*VAR26) + CONST009*VAR08*VAR24 + CONST152*VAR04 + CONST152*VAR06*VAR26 - CONST152*VAR22) +g_z = g_0*(CONST069*VAR07*VAR25 - CONST109*VAR05*z - CONST109*VAR23*x) + g_1*y*(-CONST079*VAR24*x - CONST125*VAR05 + CONST130*VAR07*VAR26) + g_10*(CONST001*VAR26*(-CONST123*VAR08*VAR17 - CONST134*VAR15 + CONST137*VAR06) + CONST003*VAR24*(CONST080*VAR08 + CONST139*VAR17) + CONST032*VAR04 + CONST105*VAR08*VAR15 - CONST111*VAR06*VAR17 - CONST148*VAR22) + g_11*(VAR16*(CONST044*VAR08*z - CONST101*VAR25) + y*(CONST054*VAR06*z - CONST091*VAR08*VAR25 + CONST121*VAR23)) + g_12*(CONST001*VAR26*(CONST091*VAR08*VAR17 - CONST098*VAR06) + CONST003*VAR24*(CONST012*VAR08 + CONST016*VAR17) + CONST055*VAR06*VAR17 + CONST098*VAR04 + CONST153*VAR22) + g_13*y*(-CONST079*VAR06*z - CONST125*VAR23 + CONST130*VAR08*VAR25) + g_14*(-CONST082*VAR06*VAR26 + CONST084*VAR08*VAR24 + CONST146*VAR04 - CONST146*VAR22) + g_2*(CONST022*VAR05*z + VAR07*(CONST025*VAR25 + CONST045*VAR17*z) + x*(-CONST044*VAR17*VAR25 + CONST127*VAR23)) + g_3*(VAR16*(-CONST045*VAR26*x + CONST101*VAR07) + y*(CONST091*VAR24*x - CONST133*VAR05)) + g_4*(CONST004*VAR05*z + 
VAR07*(CONST104*VAR17*z - CONST139*VAR25) + x*(CONST056*VAR17*VAR25 - CONST081*VAR15*z - CONST140*VAR23)) + g_5*(-CONST143*VAR05*y + VAR07*(CONST064*VAR26*y + CONST106*VAR16) + x*(CONST057*VAR24*y + CONST061*VAR16*VAR26 - CONST117*VAR14)) + g_6*(CONST097*VAR05*z + VAR07*(-CONST100*VAR17*z + CONST145*VAR25) + x*(CONST075*VAR23 + CONST100*VAR15*z - CONST100*VAR17*VAR25)) + g_7*(CONST115*VAR14*z + VAR16*(CONST072*VAR25 + CONST073*VAR08*z) + y*(CONST112*VAR08*VAR25 + CONST128*VAR23 + CONST132*VAR06*z)) + g_8*(CONST001*VAR26*(-CONST116*VAR08*VAR17 + CONST116*VAR15 + CONST154*VAR06) + CONST003*VAR24*(CONST026*VAR17 + CONST154*VAR08) + CONST019*VAR13 + CONST029*VAR06*VAR17 + CONST094*VAR04 + CONST116*VAR08*VAR15 + CONST149*VAR22) + g_9*(CONST085*VAR16*VAR25 - CONST117*VAR14*z + y*(CONST037*VAR08*VAR25 - CONST119*VAR23 + CONST143*VAR06*z)) diff --git a/notebooks/bwd_implementations/bwd_8.py b/notebooks/bwd_implementations/bwd_8.py new file mode 100644 index 0000000..53794e3 --- /dev/null +++ b/notebooks/bwd_implementations/bwd_8.py @@ -0,0 +1,259 @@ +# -------------------- variable and constant definitions +CONST000 = 2.00000000000000 +CONST001 = 3.00000000000000 +CONST002 = 4.50964677801932 +CONST003 = 517.445649319810 +CONST004 = 5.00000000000000 +CONST005 = 6.78376969317208 +CONST006 = 4.00000000000000 +CONST007 = 9.01929355603863 +CONST008 = 6.76447016702898 +CONST009 = 6.00000000000000 +CONST010 = 12.9361412329953 +CONST011 = 13.5675393863442 +CONST012 = 15.0965641786467 +CONST013 = 13.1367135230810 +CONST014 = 10.3359109268366 +CONST015 = 13.1367135230810 +CONST016 = 20.6718218536732 +CONST017 = 19.4042118494929 +CONST018 = 525.468540923241 +CONST019 = -489.184589393411 +CONST020 = 24.7386337537060 +CONST021 = 1050.93708184648 +CONST022 = 26.4189873126318 +CONST023 = 26.2734270461621 +CONST024 = 27.0578806681159 +CONST025 = 24.7386337537060 +CONST026 = 32.9848450049413 +CONST027 = 33.9188484658604 +CONST028 = 550.332663067587 +CONST029 = 39.4101405692431 +CONST030 = 
-978.369178786822 +CONST031 = 48.5105296237322 +CONST032 = 1585.13923875791 +CONST033 = 51.7445649319810 +CONST034 = 52.8379746252636 +CONST035 = 48.9184589393411 +CONST036 = 47.4863878522046 +CONST037 = 1085.27064731784 +CONST038 = 61.1480736741764 +CONST039 = 61.1480736741764 +CONST040 = 1085.40315090753 +CONST041 = 65.6835676154051 +CONST042 = 67.8376969317208 +CONST043 = -1467.55376818023 +CONST044 = 70.0624721230988 +CONST045 = -12.2296147348353 +CONST046 = 72.3513764878561 +CONST047 = 582.126355484786 +CONST048 = -437.890450769368 +CONST049 = -434.108258927137 +CONST050 = -434.108258927137 +CONST051 = 79.2569619378954 +CONST052 = -432.926090689854 +CONST053 = 87.5780901538735 +CONST054 = -1447.02752975712 +CONST055 = 91.9569946615672 +CONST056 = -420.374832738593 +CONST057 = 6.46807061649763 +CONST058 = 97.0210592474644 +CONST059 = 97.0210592474644 +CONST060 = 103.489129863962 +CONST061 = 103.489129863962 +CONST062 = -407.026181590325 +CONST063 = 108.231522672464 +CONST064 = 108.231522672464 +CONST065 = 110.066532613517 +CONST066 = 110.066532613517 +CONST067 = 620.934779183772 +CONST068 = -396.284809689477 +CONST069 = 129.361412329953 +CONST070 = 132.094936563159 +CONST071 = 434.108258927137 +CONST072 = 649.389136034782 +CONST073 = 649.389136034781 +CONST074 = 434.108258927137 +CONST075 = 144.702752975712 +CONST076 = -366.888442045058 +CONST077 = -366.888442045058 +CONST078 = -361.756882439281 +CONST079 = 158.513923875791 +CONST080 = -6.78376969317208 +CONST081 = 162.810472636130 +CONST082 = -350.312360615494 +CONST083 = -346.340872551883 +CONST084 = -346.340872551883 +CONST085 = 173.170436275942 +CONST086 = 173.170436275942 +CONST087 = 175.156180307747 +CONST088 = 183.444221022529 +CONST089 = 183.444221022529 +CONST090 = -325.620945272260 +CONST091 = -13.5289403340579 +CONST092 = -13.5675393863442 +CONST093 = 194.042118494929 +CONST094 = 194.042118494929 +CONST095 = 197.050702846215 +CONST096 = -11.3224231339851 +CONST097 = 203.513090795162 +CONST098 = 
-814.052363180650 +CONST099 = 723.513764878561 +CONST100 = 210.187416369296 +CONST101 = 210.187416369296 +CONST102 = -814.052363180650 +CONST103 = 216.463045344927 +CONST104 = 217.054129463568 +CONST105 = 216.463045344927 +CONST106 = 220.133065227035 +CONST107 = -291.063177742393 +CONST108 = 220.133065227035 +CONST109 = -792.569619378954 +CONST110 = 236.460843415458 +CONST111 = -271.350787726883 +CONST112 = 244.592294696705 +CONST113 = 244.592294696706 +CONST114 = 244.592294696706 +CONST115 = -776.168473979715 +CONST116 = -262.734270461621 +CONST117 = -259.755654413913 +CONST118 = -258.722824659905 +CONST119 = 262.734270461621 +CONST120 = 262.734270461621 +CONST121 = -244.215708954195 +CONST122 = 271.350787726883 +CONST123 = 271.350787726883 +CONST124 = -236.460843415458 +CONST125 = 792.569619378954 +CONST126 = 291.063177742393 +CONST127 = -217.054129463568 +CONST128 = -216.463045344927 +CONST129 = -216.463045344927 +CONST130 = -216.463045344927 +CONST131 = -723.513764878561 +CONST132 = 814.052363180650 +CONST133 = -210.187416369296 +CONST134 = -210.187416369296 +CONST135 = 814.052363180650 +CONST136 = -197.050702846215 +CONST137 = 317.027847751582 +CONST138 = -194.042118494929 +CONST139 = -13.1367135230810 +CONST140 = 324.694568017391 +CONST141 = 325.620945272260 +CONST142 = 324.694568017391 +CONST143 = -175.156180307747 +CONST144 = 1085.27064731784 +CONST145 = 350.312360615494 +CONST146 = -162.810472636130 +CONST147 = -162.347284008695 +CONST148 = 865.852181379709 +CONST149 = -158.513923875791 +CONST150 = 361.756882439281 +CONST151 = -144.702752975712 +CONST152 = -649.389136034782 +CONST153 = -129.877827206956 +CONST154 = -129.361412329953 +CONST155 = 388.084236989858 +CONST156 = 396.284809689477 +CONST157 = -115.446957517294 +CONST158 = -108.231522672464 +CONST159 = -108.231522672464 +CONST160 = 407.026181590325 +CONST161 = -103.489129863962 +CONST162 = -97.0210592474644 +CONST163 = -94.7025823384056 +CONST164 = 420.374832738593 +CONST165 = -91.9569946615672 
+CONST166 = 1447.02752975712 +CONST167 = -87.5780901538735 +CONST168 = -85.6073031438469 +CONST169 = -85.6073031438469 +CONST170 = -81.1736420043477 +CONST171 = 432.926090689854 +CONST172 = -79.2569619378954 +CONST173 = -81.1736420043477 +CONST174 = 432.926090689854 +CONST175 = 437.890450769368 +CONST176 = 434.108258927137 +CONST177 = -79.2569619378954 +CONST178 = -72.3513764878561 +CONST179 = -72.1543484483091 +CONST180 = -70.0624721230988 +CONST181 = -72.1543484483091 +CONST182 = -67.8376969317208 +CONST183 = -65.6835676154052 +CONST184 = -61.1480736741764 +CONST185 = -1085.27064731784 +CONST186 = -61.1480736741764 +CONST187 = -1085.40315090753 +CONST188 = -57.7234787586472 +CONST189 = -12.9361412329953 +CONST190 = -1085.27064731784 +CONST191 = -52.8379746252636 +CONST192 = -51.7445649319810 +CONST193 = -1585.13923875791 +CONST194 = -48.5105296237322 +CONST195 = -47.4863878522046 +CONST196 = 978.369178786822 +CONST197 = 978.369178786822 +CONST198 = -517.445649319810 +CONST199 = -40.7026181590325 +CONST200 = -40.5868210021738 +CONST201 = -39.4101405692431 +CONST202 = -40.7026181590325 +CONST203 = -36.0771742241545 +CONST204 = -1056.75949250527 +CONST205 = -29.1063177742393 +CONST206 = 485.105296237322 +CONST207 = -26.2734270461621 +CONST208 = -26.4189873126318 +CONST209 = -1050.93708184648 +CONST210 = -22.6382471577417 +CONST211 = -20.6718218536732 +CONST212 = -19.4042118494929 +CONST213 = -20.3513090795162 +CONST214 = -528.379746252636 +CONST215 = -15.0965641786467 +CONST216 = -13.5675393863442 +CONST217 = -525.468540923241 +CONST218 = -11.3224231339851 +CONST219 = -13.5289403340579 +CONST220 = -9.70210592474644 +CONST221 = -10.3359109268366 +CONST222 = -6.46807061649763 +CONST223 = -13.1367135230810 +CONST224 = -12.2296147348353 +CONST225 = -3.23403530824881 +CONST226 = -1034.89129863962 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 
= y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +g_x = g_0*(CONST049*VAR08*VAR23 - CONST131*VAR06*VAR25 + CONST151*VAR04*z - CONST211*VAR21) + g_1*y*(CONST178*VAR04 - CONST178*VAR22 + CONST185*VAR08*VAR24 - CONST190*VAR06*VAR26) + g_10*(CONST017*VAR05*VAR26 + CONST161*VAR13*x - CONST189*VAR03 - CONST198*VAR07*VAR15 + CONST222*VAR22*x + VAR17*(CONST058*VAR24*x + CONST107*VAR05 + CONST138*VAR07*VAR26)) + g_11*(CONST056*VAR14*x*z + VAR16*(-CONST082*VAR25*x - CONST209*VAR07*z) + y*(CONST116*VAR07*VAR25 + CONST124*VAR05*z + CONST207*VAR23*x)) + g_12*(CONST011*VAR03 + CONST182*VAR07*VAR24 + CONST199*VAR05*VAR26 + CONST216*VAR22*x + VAR15*(CONST098*VAR26*x + CONST122*VAR07) + VAR17*(-CONST102*VAR07*VAR26 + CONST121*VAR05 + CONST160*VAR24*x)) + g_13*(VAR16*(-CONST030*VAR07*z + CONST030*VAR25*x) + y*(CONST076*VAR05*z + CONST106*VAR23*x + CONST112*VAR07*VAR25)) + g_14*(CONST012*VAR03 + CONST149*VAR05*VAR26 - CONST191*VAR22*x + VAR17*(CONST109*VAR24*x + CONST149*VAR05 - CONST193*VAR07*VAR26)) + g_15*y*(CONST050*VAR05*z + CONST050*VAR23*x - CONST054*VAR07*VAR25) + g_16*(CONST050*VAR05*VAR26 - CONST131*VAR07*VAR24 + CONST151*VAR22*x - CONST211*VAR03) + g_2*(CONST001*VAR08*(-CONST208*VAR23 + CONST214*VAR17*VAR25) + CONST004*VAR06*(-CONST149*VAR17*z - CONST208*VAR25) - CONST149*VAR17*VAR23 + CONST172*VAR04*z + CONST218*VAR21) + g_3*(VAR16*(CONST043*VAR08*VAR26 + CONST113*VAR06 + CONST114*VAR24) + y*(CONST028*VAR06*VAR26 + CONST088*VAR08*VAR24 + CONST168*VAR04 + CONST184*VAR22)) + g_4*(CONST001*VAR08*(CONST005*VAR23 + CONST111*VAR15*z) + CONST004*VAR06*(CONST080*VAR25 - CONST146*VAR17*z) + CONST005*VAR21 - CONST111*VAR15*VAR25 + CONST146*VAR17*VAR23 + CONST195*VAR04*z) + g_5*(VAR14*(CONST133*VAR08 - CONST134*VAR26) + VAR16*(-CONST048*VAR06 + CONST116*VAR24 + CONST217*VAR08*VAR26) + 
y*(CONST041*VAR06*VAR26 + CONST095*VAR08*VAR24 + CONST165*VAR04 - CONST201*VAR22)) + g_6*(CONST001*VAR08*(CONST093*VAR17*VAR25 + CONST118*VAR15*z + CONST220*VAR23) + CONST004*VAR06*(-CONST162*VAR17*z + CONST220*VAR25) + CONST118*VAR15*VAR25 - CONST161*VAR13*z - CONST162*VAR17*VAR23 + CONST210*VAR04*z + CONST225*VAR21) + g_7*(CONST001*VAR08*(-CONST128*VAR16*VAR26 + CONST153*VAR14 + CONST200*VAR24*y) + CONST004*VAR06*(CONST063*VAR16 + CONST200*VAR26*y) + CONST020*VAR12 + CONST153*VAR14*VAR26 - CONST158*VAR16*VAR24 + CONST163*VAR04*y + CONST219*VAR22*y) + g_8*(CONST000*x*(CONST002*VAR22 - CONST128*VAR15*VAR26 + CONST158*VAR17*VAR24 + CONST188*VAR13) + CONST006*VAR07*(CONST008*VAR24 - CONST158*VAR15 + CONST159*VAR17*VAR26) + CONST007*VAR03 + CONST009*VAR05*(CONST002*VAR26 + CONST203*VAR17)) + g_9*(CONST173*VAR23*x*y + VAR25*(CONST147*VAR07*y + CONST171*VAR16*x) + z*(CONST117*VAR14*x + CONST170*VAR05*y + CONST171*VAR07*VAR16)) +g_y = CONST000*g_14*y*(-CONST068*VAR06*VAR26 + CONST068*VAR08*VAR24 + CONST208*VAR04 - CONST208*VAR22) + g_1*(CONST078*VAR07*VAR24 + CONST104*VAR05*VAR26 - CONST178*VAR22*x + CONST221*VAR03) + g_10*(CONST000*y*(CONST031*VAR08*VAR24 + CONST031*VAR22 + CONST194*VAR04 + CONST194*VAR06*VAR26) + CONST006*VAR16*(-CONST154*VAR06 + CONST154*VAR24) + CONST009*VAR14*(CONST033*VAR26 + CONST192*VAR08)) + g_11*(CONST001*VAR17*(-CONST116*VAR06*z - CONST143*VAR08*VAR25 + CONST167*VAR23) + CONST004*VAR15*(CONST134*VAR08*z - CONST180*VAR25) + CONST013*VAR21 + CONST183*VAR06*VAR25 + CONST201*VAR04*z + CONST223*VAR08*VAR23) + g_12*(CONST000*y*(CONST097*VAR06*VAR26 + CONST097*VAR08*VAR24 + CONST199*VAR04 + CONST199*VAR22) + CONST006*VAR16*(CONST062*VAR08*VAR26 - CONST182*VAR06 - CONST182*VAR24)) + g_13*(CONST001*VAR17*(CONST019*VAR08*VAR25 + CONST035*VAR23 + CONST113*VAR06*z) + CONST065*VAR08*VAR23 - CONST184*VAR06*VAR25 + CONST186*VAR04*z + CONST224*VAR21) + g_15*(-CONST078*VAR06*VAR25 + CONST127*VAR08*VAR23 + CONST178*VAR04*z - CONST221*VAR21) + 
g_2*(CONST137*VAR05*y*z + CONST137*VAR23*x*y + CONST204*VAR07*VAR25*y) + g_3*(CONST001*VAR17*(CONST019*VAR07*VAR26 + CONST035*VAR05 + CONST114*VAR24*x) + CONST045*VAR03 + CONST066*VAR05*VAR26 + CONST184*VAR22*x - CONST186*VAR07*VAR24) + g_4*(-CONST090*VAR05*y*z + CONST187*VAR07*VAR16*z + x*(CONST090*VAR23*y - CONST187*VAR16*VAR25)) + g_5*(CONST001*VAR17*(CONST116*VAR24*x + CONST143*VAR07*VAR26 - CONST167*VAR05) + CONST004*VAR15*(-CONST134*VAR26*x + CONST180*VAR07) + CONST015*VAR05*VAR26 + CONST041*VAR07*VAR24 + CONST139*VAR03 - CONST201*VAR22*x) + g_6*(-CONST138*VAR05*y*z + VAR07*(CONST155*VAR25*y + CONST226*VAR16*z) + x*(CONST067*VAR14*z - CONST138*VAR23*y + CONST226*VAR16*VAR25)) + g_7*(CONST219*VAR03 + VAR05*(CONST142*VAR17 + CONST200*VAR26) + VAR07*(CONST152*VAR15 - CONST152*VAR17*VAR26 + CONST200*VAR24) + x*(CONST085*VAR13 + CONST140*VAR17*VAR24 + CONST152*VAR15*VAR26 + CONST219*VAR22)) + g_8*(CONST026*VAR12 - CONST052*VAR16*VAR24 + CONST084*VAR14*VAR26 + CONST179*VAR04*y + CONST181*VAR22*y + VAR06*(-CONST052*VAR16 + CONST129*VAR26*y) + VAR08*(CONST083*VAR14 + CONST128*VAR24*y + CONST148*VAR16*VAR26)) + g_9*(CONST219*VAR21 + VAR23*(CONST142*VAR17 + CONST200*VAR08) + VAR25*(CONST073*VAR08*VAR17 + CONST152*VAR15 + CONST200*VAR06) + z*(CONST086*VAR13 + CONST091*VAR04 + CONST142*VAR06*VAR17 + CONST152*VAR08*VAR15)) +g_z = g_0*(-CONST049*VAR05*VAR26 + CONST131*VAR07*VAR24 - CONST151*VAR22*x + CONST211*VAR03) + g_1*y*(-CONST050*VAR23*x + CONST054*VAR07*VAR25 + CONST071*VAR05*z) + g_10*(CONST057*VAR04*z + CONST061*VAR13*z + CONST189*VAR21 + CONST198*VAR15*VAR25 + CONST212*VAR08*VAR23 + VAR17*(CONST093*VAR08*VAR25 - CONST107*VAR23 + CONST162*VAR06*z)) + g_11*(VAR14*(-CONST133*VAR26 + CONST134*VAR08) + VAR16*(CONST048*VAR24 - CONST116*VAR06 - CONST217*VAR08*VAR26) + y*(CONST055*VAR22 + CONST136*VAR06*VAR26 + CONST183*VAR08*VAR24 + CONST201*VAR04)) + g_12*(CONST011*VAR21 + CONST092*VAR04*z + CONST182*VAR06*VAR25 + CONST202*VAR08*VAR23 + VAR15*(CONST098*VAR08*z + 
CONST122*VAR25) + VAR17*(-CONST102*VAR08*VAR25 + CONST121*VAR23 + CONST160*VAR06*z)) + g_13*(VAR16*(CONST043*VAR08*VAR26 + CONST113*VAR06 + CONST113*VAR24) + y*(CONST028*VAR08*VAR24 + CONST089*VAR06*VAR26 + CONST169*VAR22 + CONST186*VAR04)) + g_14*(-CONST149*VAR08*VAR23 + CONST191*VAR04*z + CONST215*VAR21 + VAR17*(-CONST109*VAR06*z - CONST149*VAR23 + CONST193*VAR08*VAR25)) + g_15*y*(CONST178*VAR04 - CONST178*VAR22 - CONST185*VAR06*VAR26 + CONST190*VAR08*VAR24) + g_16*(CONST050*VAR08*VAR23 - CONST131*VAR06*VAR25 + CONST151*VAR04*z - CONST211*VAR21) + g_2*(CONST096*VAR03 + VAR05*(-CONST149*VAR17 - CONST177*VAR26) + VAR07*(CONST070*VAR24 + CONST193*VAR17*VAR26) + x*(-CONST109*VAR17*VAR24 + CONST177*VAR22)) + g_3*(VAR16*(CONST030*VAR07*z + CONST197*VAR25*x) + y*(CONST077*VAR23*x + CONST108*VAR05*z + CONST114*VAR07*VAR25)) + g_4*(CONST080*VAR03 + VAR05*(-CONST146*VAR17 + CONST213*VAR26) + VAR07*(CONST027*VAR24 + CONST111*VAR15) + x*(CONST102*VAR17*VAR24 + CONST135*VAR15*VAR26 - CONST195*VAR22)) + g_5*(-CONST056*VAR14*x*z + VAR16*(CONST082*VAR07*z + CONST209*VAR25*x) + y*(CONST023*VAR05*z + CONST120*VAR07*VAR25 - CONST124*VAR23*x)) + g_6*(CONST225*VAR03 + VAR05*(-CONST162*VAR17 + CONST205*VAR26) + VAR07*(CONST047*VAR17*VAR26 + CONST118*VAR15 + CONST194*VAR24) + x*(CONST115*VAR15*VAR26 - CONST161*VAR13 + CONST206*VAR17*VAR24 + CONST210*VAR22)) + g_7*(CONST173*VAR05*y*z + VAR07*(-CONST052*VAR16*z + CONST147*VAR25*y) + x*(-CONST052*VAR16*VAR25 + CONST117*VAR14*z + CONST173*VAR23*y)) + g_8*(CONST007*VAR04*z + CONST007*VAR21 - CONST052*VAR15*VAR25 + CONST130*VAR17*VAR23 + CONST157*VAR13*z + VAR06*(CONST024*VAR25 + CONST129*VAR17*z) + VAR08*(CONST024*VAR23 - CONST052*VAR15*z + CONST052*VAR17*VAR25)) + g_9*(CONST001*VAR26*(CONST105*VAR08*VAR16 + CONST153*VAR14 + CONST200*VAR06*y) + CONST004*VAR24*(CONST063*VAR16 + CONST200*VAR08*y) + CONST025*VAR12 + CONST063*VAR06*VAR16 + CONST091*VAR04*y + CONST153*VAR08*VAR14 + CONST163*VAR22*y) diff --git 
a/notebooks/bwd_implementations/bwd_9.py b/notebooks/bwd_implementations/bwd_9.py new file mode 100644 index 0000000..5f7fb30 --- /dev/null +++ b/notebooks/bwd_implementations/bwd_9.py @@ -0,0 +1,336 @@ +# -------------------- variable and constant definitions +CONST000 = 1.59908344719522 +CONST001 = 2.00000000000000 +CONST002 = 3.00000000000000 +CONST003 = 4.00000000000000 +CONST004 = 5.00000000000000 +CONST005 = 6.39633378878088 +CONST006 = 7.00000000000000 +CONST007 = 8.63855507530412 +CONST008 = 9.59450068317133 +CONST009 = 6.39633378878088 +CONST010 = 9.82028453158308 +CONST011 = 12.7926675775618 +CONST012 = 12.7926675775618 +CONST013 = 14.7304267973746 +CONST014 = 15.5493991355474 +CONST015 = 14.3917510247570 +CONST016 = 17.3847567381802 +CONST017 = 15.0007324039945 +CONST018 = 14.4550674370400 +CONST019 = 14.4550674370400 +CONST020 = 13.3827919767794 +CONST021 = 23.8930627690618 +CONST022 = 23.8930627690618 +CONST023 = 27.0429549260581 +CONST024 = 29.2403830344269 +CONST025 = 30.0014648079890 +CONST026 = 30.9062342012093 +CONST027 = 29.2403830344269 +CONST028 = 38.3780027326853 +CONST029 = 39.2811381263323 +CONST030 = 39.2811381263323 +CONST031 = 39.2300904918661 +CONST032 = 42.9079114754785 +CONST033 = 10.7269778688696 +CONST034 = 54.0859098521163 +CONST035 = 57.8202697481601 +CONST036 = 58.9217071894985 +CONST037 = 57.8202697481601 +CONST038 = 60.0029296159779 +CONST039 = 62.4530292249704 +CONST040 = 64.3618672132178 +CONST041 = 68.5747767039748 +CONST042 = 69.1084406024329 +CONST043 = 77.2655855030233 +CONST044 = 78.5622762526647 +CONST045 = 85.8158229509570 +CONST046 = 85.8158229509570 +CONST047 = 90.1063824390370 +CONST048 = 96.7518168434061 +CONST049 = 104.749701670220 +CONST050 = 107.062335814235 +CONST051 = 108.171819704233 +CONST052 = 108.171819704233 +CONST053 = -1935.03633686812 +CONST054 = 115.640539496320 +CONST055 = 115.640539496320 +CONST056 = 117.843414378997 +CONST057 = 117.843414378997 +CONST058 = 115.640539496320 +CONST059 = 
120.005859231956 +CONST060 = 2176.91587897664 +CONST061 = 2176.91587897664 +CONST062 = 137.149553407950 +CONST063 = 150.007324039945 +CONST064 = 150.007324039945 +CONST065 = -1892.23403121978 +CONST066 = -1885.49463006395 +CONST067 = 173.460809244480 +CONST068 = -1873.59087674911 +CONST069 = 176.765121568496 +CONST070 = 10.7269778688696 +CONST071 = 180.008788847934 +CONST072 = 187.359087674911 +CONST073 = 191.144502152495 +CONST074 = 13.5214774630291 +CONST075 = 196.405690631662 +CONST076 = 205.957975082297 +CONST077 = 216.343639408465 +CONST078 = 216.343639408465 +CONST079 = 4326.87278816930 +CONST080 = 233.923064275415 +CONST081 = 233.923064275415 +CONST082 = 240.011718463912 +CONST083 = 241.879542108515 +CONST084 = 241.879542108515 +CONST085 = 255.853351551235 +CONST086 = 255.853351551235 +CONST087 = 257.447468852871 +CONST088 = 257.447468852871 +CONST089 = 2312.81078992641 +CONST090 = 270.429549260581 +CONST091 = 289.101348740801 +CONST092 = 294.608535947493 +CONST093 = 300.014648079890 +CONST094 = 300.014648079890 +CONST095 = 2356.86828757994 +CONST096 = 314.249105010659 +CONST097 = 13.0937127087774 +CONST098 = 324.515459112698 +CONST099 = -3747.18175349822 +CONST100 = 6.39633378878088 +CONST101 = 353.530243136991 +CONST102 = 374.718175349822 +CONST103 = 374.718175349822 +CONST104 = 392.811381263323 +CONST105 = 404.741888237121 +CONST106 = 411.915950164594 +CONST107 = 412.451950326490 +CONST108 = 432.687278816930 +CONST109 = 435.383175795328 +CONST110 = 435.383175795327 +CONST111 = 462.562157985281 +CONST112 = 462.562157985281 +CONST113 = -1571.24552505329 +CONST114 = 483.759084217031 +CONST115 = 511.706703102471 +CONST116 = 562.077263024733 +CONST117 = 578.202697481601 +CONST118 = 589.217071894985 +CONST119 = -1451.27725265109 +CONST120 = 4.91014226579154 +CONST121 = -1451.27725265109 +CONST122 = 600.029296159779 +CONST123 = 600.029296159779 +CONST124 = -1440.07031078347 +CONST125 = 628.498210021317 +CONST126 = 628.498210021318 +CONST127 = 630.744677073259 
+CONST128 = 649.030918225395 +CONST129 = -1387.68647395584 +CONST130 = -1387.68647395584 +CONST131 = -1373.05316721531 +CONST132 = -1338.01151506746 +CONST133 = 725.638626325546 +CONST134 = -1298.06183645079 +CONST135 = 785.622762526647 +CONST136 = 785.622762526647 +CONST137 = 788.430846341574 +CONST138 = -1249.06058449941 +CONST139 = -1228.09608744593 +CONST140 = -1228.09608744593 +CONST141 = 823.831900329187 +CONST142 = -3245.15459112698 +CONST143 = -1178.43414378997 +CONST144 = 870.766351590655 +CONST145 = 870.766351590655 +CONST146 = 900.043944239669 +CONST147 = -1124.15452604947 +CONST148 = 936.795438374555 +CONST149 = -3153.72338536630 +CONST150 = 960.046873855647 +CONST151 = 960.046873855647 +CONST152 = 967.518168434061 +CONST153 = -1081.71819704233 +CONST154 = 967.518168434061 +CONST155 = -1060.59072941097 +CONST156 = 1023.41340620494 +CONST157 = 1023.41340620494 +CONST158 = 1060.59072941097 +CONST159 = -967.518168434061 +CONST160 = 1081.71819704233 +CONST161 = -960.046873855647 +CONST162 = 3153.72338536630 +CONST163 = -936.795438374555 +CONST164 = 1124.15452604947 +CONST165 = -900.043944239669 +CONST166 = 1156.40539496320 +CONST167 = 1178.43414378997 +CONST168 = -2902.55450530218 +CONST169 = 3245.15459112698 +CONST170 = 11.2632978048796 +CONST171 = -785.622762526647 +CONST172 = -785.622762526647 +CONST173 = -767.560054653706 +CONST174 = 1298.06183645079 +CONST175 = 1338.01151506746 +CONST176 = -693.843236977922 +CONST177 = -693.843236977921 +CONST178 = -686.526583607656 +CONST179 = -669.005757533731 +CONST180 = -669.005757533731 +CONST181 = 1387.68647395584 +CONST182 = -649.030918225395 +CONST183 = -630.744677073259 +CONST184 = -628.498210021318 +CONST185 = -628.498210021317 +CONST186 = -600.029296159779 +CONST187 = -589.217071894985 +CONST188 = -578.202697481601 +CONST189 = 15.5493991355474 +CONST190 = -562.077263024733 +CONST191 = 1500.07324039945 +CONST192 = -480.023436927823 +CONST193 = -480.023436927823 +CONST194 = 1571.24552505329 +CONST195 = 
-462.562157985281 +CONST196 = -450.021972119834 +CONST197 = -412.451950326490 +CONST198 = -409.365362481977 +CONST199 = -409.365362481976 +CONST200 = -404.741888237121 +CONST201 = -392.811381263323 +CONST202 = -383.780027326853 +CONST203 = -383.780027326853 +CONST204 = 1672.51439383433 +CONST205 = -374.718175349822 +CONST206 = -353.530243136991 +CONST207 = -2400.11718463912 +CONST208 = 3747.18175349822 +CONST209 = -346.921618488961 +CONST210 = -346.921618488961 +CONST211 = -343.263291803828 +CONST212 = -338.631358951921 +CONST213 = -338.631358951921 +CONST214 = -324.515459112698 +CONST215 = -315.372338536630 +CONST216 = -314.249105010659 +CONST217 = -2356.86828757994 +CONST218 = -300.014648079890 +CONST219 = -294.608535947493 +CONST220 = -289.101348740801 +CONST221 = -270.013183271901 +CONST222 = -2312.81078992641 +CONST223 = 1800.08788847934 +CONST224 = -241.879542108515 +CONST225 = -240.011718463912 +CONST226 = -241.879542108515 +CONST227 = -4326.87278816930 +CONST228 = -216.343639408465 +CONST229 = -210.010253655923 +CONST230 = -204.682681240988 +CONST231 = -204.682681240988 +CONST232 = -204.682681240988 +CONST233 = -196.405690631662 +CONST234 = -191.144502152495 +CONST235 = -191.890013663426 +CONST236 = -191.890013663427 +CONST237 = -187.359087674911 +CONST238 = -180.008788847934 +CONST239 = -176.765121568496 +CONST240 = 1873.59087674911 +CONST241 = 1873.59087674911 +CONST242 = -173.460809244480 +CONST243 = 1885.49463006395 +CONST244 = -162.257729556349 +CONST245 = -156.920361967464 +CONST246 = -156.920361967464 +CONST247 = 1892.23403121978 +CONST248 = -150.007324039945 +CONST249 = -144.550674370400 +CONST250 = -137.149553407950 +CONST251 = -135.214774630291 +CONST252 = -127.926675775618 +CONST253 = -127.926675775618 +CONST254 = -120.939771054258 +CONST255 = -120.005859231956 +CONST256 = -120.939771054258 +CONST257 = -117.843414378997 +CONST258 = -117.843414378997 +CONST259 = -115.640539496320 +CONST260 = -115.640539496320 +CONST261 = 1935.03633686812 +CONST262 
= -2163.43639408465 +CONST263 = -114.421097267943 +CONST264 = -108.171819704233 +CONST265 = -107.062335814235 +CONST266 = -108.171819704233 +CONST267 = -104.749701670220 +CONST268 = -96.7518168434061 +CONST269 = -96.7518168434061 +CONST270 = -90.0043944239669 +CONST271 = -90.1063824390370 +CONST272 = -80.2967518606762 +CONST273 = -78.4601809837321 +CONST274 = -78.4601809837321 +CONST275 = -77.2655855030233 +CONST276 = -78.5622762526647 +CONST277 = -68.5747767039748 +CONST278 = -63.9633378878088 +CONST279 = -62.4530292249704 +CONST280 = -61.8124684024186 +CONST281 = -60.0029296159779 +CONST282 = -63.9633378878088 +CONST283 = -58.9217071894985 +CONST284 = -57.8202697481601 +CONST285 = -57.8202697481601 +CONST286 = -48.3759084217030 +CONST287 = -48.3759084217031 +CONST288 = -39.2811381263323 +CONST289 = -38.6327927515116 +CONST290 = -39.2811381263323 +CONST291 = -30.9062342012093 +CONST292 = -30.0014648079890 +CONST293 = -30.0014648079890 +CONST294 = -27.6433762409732 +CONST295 = -17.3847567381802 +CONST296 = -15.0007324039945 +CONST297 = -14.7304267973746 +CONST298 = -13.5214774630291 +CONST299 = -13.0937127087774 +CONST300 = -13.3827919767794 +CONST301 = -9.82028453158308 +CONST302 = -4.91014226579154 +CONST303 = 2046.82681240988 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +g_x = g_0*(CONST021*VAR20 + CONST022*VAR02 + CONST179*VAR04*VAR26 + CONST180*VAR08*VAR22 + CONST204*VAR06*VAR24) + g_1*y*(CONST065*VAR08*VAR23 - CONST149*VAR06*VAR25 + CONST183*VAR04*z - CONST271*VAR21) + g_10*(CONST012*VAR21*x + VAR23*(CONST028*VAR07 + CONST203*VAR17*x) + VAR25*(CONST028*VAR05 + CONST157*VAR15*x + 
CONST173*VAR07*VAR17) + z*(CONST011*VAR03 + CONST157*VAR07*VAR15 + CONST198*VAR13*x + CONST202*VAR05*VAR17)) + g_11*(CONST150*VAR07*VAR14 + CONST250*VAR12*x + VAR16*(CONST093*VAR24*x + CONST165*VAR05 + CONST186*VAR07*VAR26) + y*(CONST059*VAR03 + CONST071*VAR05*VAR26 + CONST281*VAR22*x)) + g_12*(VAR23*(CONST257*VAR17*x - CONST290*VAR07) + VAR25*(CONST044*VAR05 + CONST143*VAR07*VAR17 - CONST172*VAR15*x) + z*(CONST155*VAR05*VAR17 + CONST184*VAR13*x - CONST217*VAR07*VAR15 - CONST288*VAR03)) + g_13*(VAR14*(CONST129*VAR26*x - CONST195*VAR07) + VAR16*(CONST166*VAR24*x + CONST176*VAR05 - CONST222*VAR07*VAR26) + y*(CONST188*VAR07*VAR24 + CONST209*VAR05*VAR26 - CONST259*VAR03 + CONST259*VAR22*x)) + g_14*(CONST042*VAR03*z + CONST268*VAR07*VAR23 + CONST294*VAR21*x + VAR15*(CONST053*VAR25*x + CONST261*VAR07*z) + VAR17*(CONST119*VAR05*z + CONST144*VAR23*x + CONST152*VAR07*VAR25)) + g_15*(VAR16*(CONST068*VAR24*x - CONST099*VAR07*VAR26 + CONST205*VAR05) + y*(CONST050*VAR03 + CONST147*VAR05*VAR26 - CONST205*VAR22*x)) + g_16*(CONST214*VAR05*VAR25 - CONST264*VAR03*z + CONST264*VAR07*VAR23 - CONST275*VAR21*x + VAR17*(CONST079*VAR07*VAR25 + CONST134*VAR05*z + CONST134*VAR23*x)) + g_17*y*(CONST065*VAR05*VAR26 - CONST149*VAR07*VAR24 + CONST183*VAR22*x - CONST271*VAR03) + g_18*(CONST132*VAR05*VAR25 + CONST175*VAR07*VAR23 - CONST234*VAR03*z + CONST234*VAR21*x) + g_2*(CONST002*VAR08*(CONST034*VAR22 + CONST153*VAR17*VAR24) + CONST004*VAR06*(CONST023*VAR24 - CONST182*VAR17*VAR26) + CONST006*VAR04*(CONST289*VAR26 + CONST291*VAR17) - CONST228*VAR17*VAR22 - CONST295*VAR02 + CONST298*VAR20) + g_3*(VAR16*(-CONST068*VAR06*z + CONST099*VAR08*VAR25 + CONST103*VAR23) + y*(CONST116*VAR08*VAR23 - CONST163*VAR06*VAR25 + CONST190*VAR04*z + CONST272*VAR21)) + g_4*(CONST007*VAR20 + CONST014*VAR02 + CONST254*VAR06*VAR24 + CONST269*VAR04*VAR26 + VAR15*(CONST114*VAR06 + CONST114*VAR24 + CONST168*VAR08*VAR26) + VAR17*(CONST060*VAR06*VAR26 + CONST133*VAR08*VAR24 + CONST212*VAR04 + CONST224*VAR22)) + 
g_5*(VAR14*(CONST130*VAR08*z - CONST195*VAR25) + VAR16*(CONST195*VAR23 - CONST222*VAR06*z) + y*(CONST067*VAR08*VAR23 + CONST200*VAR04*z + CONST220*VAR06*VAR25 - CONST284*VAR21)) + g_6*(CONST002*VAR08*(CONST201*VAR15*VAR26 - CONST219*VAR17*VAR24 + CONST267*VAR13 + CONST299*VAR22) + CONST004*VAR06*(CONST036*VAR17*VAR26 - CONST233*VAR15 + CONST301*VAR24) + CONST187*VAR15*VAR24 + CONST197*VAR04*VAR17 - CONST216*VAR13*VAR26 - CONST239*VAR17*VAR22 - CONST297*VAR02 + CONST302*VAR20) + g_7*(CONST002*VAR08*(-CONST186*VAR16*VAR25 + CONST192*VAR14*z + CONST270*VAR23*y) + CONST004*VAR06*(-CONST218*VAR16*z + CONST270*VAR25*y) + CONST193*VAR14*VAR25 - CONST218*VAR16*VAR23 + CONST229*VAR04*y*z - CONST250*VAR12*z + CONST292*VAR21*y) + g_8*(CONST000*VAR20 + CONST002*VAR08*(CONST005*VAR22 + CONST115*VAR15*VAR26 + CONST230*VAR13 + CONST235*VAR17*VAR24) + CONST004*VAR06*(CONST008*VAR24 + CONST085*VAR15 + CONST235*VAR17*VAR26) + CONST006*VAR04*(CONST009*VAR26 + CONST278*VAR17) + CONST015*VAR02 + CONST024*VAR11 + CONST085*VAR15*VAR24 + CONST231*VAR13*VAR26 + CONST278*VAR17*VAR22) + g_9*(CONST245*VAR12*x + VAR14*(CONST141*VAR07 + CONST141*VAR26*x) + VAR16*(CONST131*VAR07*VAR26 + CONST178*VAR05 + CONST178*VAR24*x) + y*(CONST045*VAR03 + CONST046*VAR22*x + CONST087*VAR05*VAR26 + CONST088*VAR07*VAR24)) +g_y = CONST001*g_16*y*(CONST160*VAR06*VAR25 + CONST182*VAR08*VAR23 + CONST228*VAR04*z - CONST291*VAR21) + g_1*(-CONST183*VAR05*VAR25 + CONST183*VAR07*VAR23 + CONST271*VAR03*z - CONST271*VAR21*x) + g_10*(CONST252*VAR21*y + VAR23*(CONST157*VAR16 + CONST203*VAR08*y) + VAR25*(CONST140*VAR14 + CONST202*VAR06*y + CONST303*VAR08*VAR16) + z*(CONST080*VAR12 + CONST139*VAR08*VAR14 + CONST157*VAR06*VAR16 + CONST252*VAR04*y)) + g_11*(CONST002*VAR17*(CONST064*VAR08*VAR24 + CONST248*VAR04 + CONST248*VAR06*VAR26 - CONST248*VAR22) + CONST004*VAR15*(CONST082*VAR06 + CONST225*VAR24) + CONST006*VAR13*(CONST277*VAR08 - CONST277*VAR26) + CONST017*VAR02 + CONST025*VAR04*VAR26 + CONST293*VAR08*VAR22 + 
CONST296*VAR20) + g_12*(CONST056*VAR21*y + VAR23*(CONST171*VAR16 + CONST257*VAR08*y) + VAR25*(-CONST113*VAR08*VAR16 - CONST185*VAR14 + CONST187*VAR06*y) + z*(CONST066*VAR08*VAR14 + CONST206*VAR04*y - CONST217*VAR06*VAR16)) + g_13*(CONST002*VAR17*(CONST117*VAR06*VAR26 + CONST117*VAR08*VAR24 + CONST259*VAR04 + CONST260*VAR22) + CONST004*VAR15*(CONST055*VAR06 + CONST055*VAR24 + CONST176*VAR08*VAR26) + CONST018*VAR20 + CONST019*VAR02 + CONST249*VAR06*VAR24 + CONST284*VAR04*VAR26 + CONST285*VAR08*VAR22) + g_14*(CONST001*y*(CONST083*VAR06*VAR25 + CONST109*VAR08*VAR23 + CONST226*VAR04*z + CONST286*VAR21) + CONST003*VAR16*(CONST114*VAR06*z + CONST159*VAR08*VAR25 - CONST269*VAR23)) + g_15*(CONST002*VAR17*(CONST039*VAR22 - CONST163*VAR06*VAR26 + CONST163*VAR08*VAR24 + CONST279*VAR04) + CONST020*VAR02 + CONST237*VAR04*VAR26 - CONST237*VAR08*VAR22 + CONST300*VAR20) + g_17*(CONST137*VAR06*VAR24 + CONST170*VAR02 + CONST170*VAR20 + CONST215*VAR04*VAR26 + CONST215*VAR08*VAR22) + g_2*(CONST108*VAR22*x*y - CONST134*VAR05*VAR26*y + CONST262*VAR07*VAR24*y + CONST280*VAR03*y) + g_3*(CONST002*VAR17*(CONST103*VAR23*x + CONST138*VAR07*VAR25 - CONST205*VAR05*z) - CONST237*VAR05*VAR25 - CONST237*VAR07*VAR23 + CONST272*VAR03*z + CONST272*VAR21*x) + g_4*(CONST001*y*(CONST110*VAR05*VAR26 - CONST224*VAR07*VAR24 + CONST224*VAR22*x + CONST287*VAR03) + CONST003*VAR16*(CONST114*VAR24*x + CONST159*VAR07*VAR26 - CONST269*VAR05)) + g_5*(CONST002*VAR17*(CONST112*VAR05*z + CONST195*VAR23*x) + CONST004*VAR15*(CONST195*VAR07*z - CONST195*VAR25*x) + CONST037*VAR07*VAR23 + CONST284*VAR05*VAR25 - CONST284*VAR21*x + CONST285*VAR03*z) + g_6*(CONST258*VAR03*y + VAR05*(CONST057*VAR26*y - CONST171*VAR16) + VAR07*(CONST113*VAR16*VAR26 + CONST185*VAR14 - CONST187*VAR24*y) + x*(-CONST066*VAR14*VAR26 - CONST206*VAR22*y + CONST217*VAR16*VAR24)) + g_7*(CONST292*VAR03*z + VAR05*(-CONST165*VAR17*z + CONST270*VAR25) + VAR07*(CONST207*VAR15*z + CONST223*VAR17*VAR25 + CONST270*VAR23) + x*(CONST151*VAR13*z - 
CONST165*VAR17*VAR23 + CONST207*VAR15*VAR25 + CONST292*VAR21)) + g_8*(CONST253*VAR03*y + VAR05*(CONST156*VAR16 + CONST202*VAR26*y) + VAR07*(CONST139*VAR14 + CONST202*VAR24*y + CONST303*VAR16*VAR26) + x*(CONST081*VAR12 + CONST140*VAR14*VAR26 + CONST156*VAR16*VAR24 + CONST253*VAR22*y)) + g_9*(CONST002*VAR17*(CONST211*VAR06*VAR26 + CONST211*VAR08*VAR24 + CONST263*VAR04 + CONST263*VAR22) + CONST004*VAR15*(CONST076*VAR06 + CONST076*VAR24 + CONST106*VAR08*VAR26) + CONST006*VAR13*(CONST273*VAR26 + CONST274*VAR08) + CONST031*VAR11 + CONST032*VAR04*VAR26 + CONST032*VAR08*VAR22 + CONST033*VAR20 + CONST040*VAR06*VAR24 + CONST070*VAR02) +g_z = g_0*(CONST132*VAR07*VAR23 + CONST175*VAR05*VAR25 + CONST234*VAR03*z - CONST234*VAR21*x) + g_1*y*(-CONST065*VAR05*VAR26 + CONST149*VAR07*VAR24 - CONST183*VAR22*x + CONST271*VAR03) + g_10*(CONST000*VAR02 + CONST002*VAR26*(CONST100*VAR04 + CONST115*VAR08*VAR15 + CONST231*VAR13 + CONST235*VAR06*VAR17) + CONST004*VAR24*(CONST008*VAR06 + CONST086*VAR15 + CONST236*VAR08*VAR17) + CONST006*VAR22*(CONST005*VAR08 + CONST282*VAR17) + CONST015*VAR20 + CONST027*VAR11 + CONST086*VAR06*VAR15 + CONST232*VAR08*VAR13 + CONST282*VAR04*VAR17) + g_11*(CONST161*VAR14*VAR25 - CONST250*VAR12*z + VAR16*(CONST123*VAR08*VAR25 - CONST165*VAR23 + CONST218*VAR06*z) + y*(CONST038*VAR04*z + CONST238*VAR08*VAR23 + CONST255*VAR21)) + g_12*(CONST002*VAR26*(CONST097*VAR04 - CONST201*VAR08*VAR15 + CONST219*VAR06*VAR17 - CONST267*VAR13) + CONST004*VAR24*(CONST233*VAR15 + CONST283*VAR08*VAR17 - CONST301*VAR06) + CONST107*VAR17*VAR22 - CONST187*VAR06*VAR15 + CONST216*VAR08*VAR13 + CONST239*VAR04*VAR17 + CONST297*VAR20 - CONST302*VAR02) + g_13*(VAR14*(CONST129*VAR08*z - CONST195*VAR25) + VAR16*(CONST166*VAR06*z + CONST177*VAR23 - CONST222*VAR08*VAR25) + y*(CONST188*VAR06*VAR25 + CONST210*VAR08*VAR23 + CONST260*VAR04*z - CONST260*VAR21)) + g_14*(CONST007*VAR02 + CONST189*VAR20 + CONST256*VAR06*VAR24 + CONST269*VAR08*VAR22 + VAR15*(CONST114*VAR06 + CONST114*VAR24 + 
CONST168*VAR08*VAR26) + VAR17*(CONST061*VAR08*VAR24 + CONST133*VAR06*VAR26 + CONST213*VAR22 + CONST226*VAR04)) + g_15*(VAR16*(-CONST068*VAR06*z + CONST099*VAR08*VAR25 + CONST103*VAR23) + y*(-CONST147*VAR08*VAR23 + CONST205*VAR04*z + CONST265*VAR21)) + g_16*(CONST074*VAR02 + CONST090*VAR08*VAR22 + CONST244*VAR04*VAR26 + CONST251*VAR06*VAR24 + CONST295*VAR20 + VAR17*(CONST078*VAR22 - CONST142*VAR06*VAR26 + CONST142*VAR08*VAR24 + CONST228*VAR04)) + g_17*y*(CONST065*VAR08*VAR23 - CONST149*VAR06*VAR25 + CONST183*VAR04*z - CONST271*VAR21) + g_18*(CONST021*VAR02 + CONST022*VAR20 + CONST179*VAR08*VAR22 + CONST180*VAR04*VAR26 + CONST204*VAR06*VAR24) + g_2*(CONST275*VAR03*z + VAR05*(CONST052*VAR25 - CONST134*VAR17*z) + VAR07*(-CONST214*VAR23 + CONST227*VAR17*VAR25) + x*(-CONST134*VAR17*VAR23 + CONST266*VAR21)) + g_3*(VAR16*(CONST099*VAR07*VAR26 - CONST205*VAR05 + CONST241*VAR24*x) + y*(CONST116*VAR05*VAR26 - CONST163*VAR07*VAR24 + CONST190*VAR22*x + CONST272*VAR03)) + g_4*(CONST042*VAR21*x + CONST269*VAR05*VAR25 + CONST294*VAR03*z + VAR15*(CONST053*VAR07*z + CONST261*VAR25*x) + VAR17*(CONST121*VAR23*x + CONST145*VAR05*z + CONST154*VAR07*VAR25)) + g_5*(VAR14*(-CONST130*VAR26*x + CONST195*VAR07) + VAR16*(CONST112*VAR05 + CONST222*VAR24*x) + y*(CONST091*VAR07*VAR24 + CONST105*VAR22*x + CONST242*VAR05*VAR26 + CONST285*VAR03)) + g_6*(VAR05*(CONST057*VAR17*z + CONST290*VAR25) + VAR07*(-CONST143*VAR17*VAR25 + CONST172*VAR15*z + CONST276*VAR23) + x*(-CONST155*VAR17*VAR23 - CONST184*VAR13*z + CONST217*VAR15*VAR25 + CONST288*VAR21)) + g_7*(CONST292*VAR03*y + VAR05*(-CONST218*VAR16 + CONST221*VAR26*y) + VAR07*(CONST192*VAR14 + CONST196*VAR24*y + CONST223*VAR16*VAR26) + x*(CONST124*VAR14*VAR26 + CONST191*VAR16*VAR24 + CONST229*VAR22*y - CONST250*VAR12)) + g_8*(CONST011*VAR03*z + VAR05*(CONST028*VAR25 + CONST202*VAR17*z) + VAR07*(CONST028*VAR23 + CONST157*VAR15*z + CONST173*VAR17*VAR25) + x*(CONST011*VAR21 + CONST156*VAR15*VAR25 + CONST199*VAR13*z + CONST202*VAR17*VAR23)) + 
g_9*(CONST246*VAR12*z + VAR14*(CONST141*VAR08*z + CONST141*VAR25) + VAR16*(CONST131*VAR08*VAR25 + CONST178*VAR06*z + CONST178*VAR23) + y*(CONST046*VAR04*z + CONST046*VAR21 + CONST087*VAR08*VAR23 + CONST088*VAR06*VAR25)) diff --git a/notebooks/direct_sph_harm/l_10.json b/notebooks/direct_sph_harm/l_10.json new file mode 100644 index 0000000..fd9a7dc --- /dev/null +++ b/notebooks/direct_sph_harm/l_10.json @@ -0,0 +1,30 @@ +{ + "fwd": [ + "27.2034486491732*x**9*z - 326.441383790078*x**7*z**3 + 685.526905959165*x**5*z**5 - 326.441383790078*x**3*z**7 + 27.2034486491732*x*z**9", + "y*(12.1657520803952*x**9 - 437.967074894228*x**7*z**2 + 1532.8847621298*x**5*z**4 - 1021.9231747532*x**3*z**6 + 109.491768723557*x*z**8)", + "15.7883647328499*x**9*z - 94.7301883970997*x**7*z**3 + 94.7301883970997*x**3*z**7 - 15.7883647328499*x*z**9 + y**2*(-284.190565191299*x**7*z + 1989.33395633909*x**5*z**3 - 1989.33395633909*x**3*z**5 + 284.190565191299*x*z**7)", + "y**3*(-77.3468749368713*x**7 + 1624.2843736743*x**5*z**2 - 2707.14062279049*x**3*z**4 + 541.428124558099*x*z**6) + y*(14.5025390506634*x**9 - 290.050781013267*x**7*z**2 + 203.035546709287*x**5*z**4 + 406.071093418574*x**3*z**6 - 101.517773354644*x*z**8)", + "10.5521471197994*x**9*z + x**7*(-337.668707833581*y**2*z - 14.0695294930659*z**3) + x**5*(787.893651611688*y**4*z + 787.893651611688*y**2*z**3 - 49.2433532257305*z**5) + x**3*(-2626.31217203896*y**4*z**3 + 787.893651611688*y**2*z**5 - 14.0695294930659*z**7) + x*(787.893651611688*y**4*z**5 - 337.668707833581*y**2*z**7 + 10.5521471197994*z**9)", + "y**5*(176.178376404427*x**5 - 1761.78376404427*x**3*z**2 + 880.891882022136*x*z**4) + y**3*(-146.815313670356*x**7 + 1321.3378230332*x**5*z**2 + 734.07656835178*x**3*z**4 - 734.07656835178*x*z**6) + y*(15.7302121789667*x**9 - 125.841697431734*x**7*z**2 - 220.222970505534*x**5*z**4 + 78.6510608948335*x*z**8)", + "6.632439808434*x**9*z + x**7*(-278.562471954228*y**2*z + 13.264879616868*z**3) + x**5*(1114.24988781691*y**4*z - 
278.562471954228*y**2*z**3) + x**3*(-742.833258544608*y**6*z + 278.562471954228*y**2*z**5 - 13.264879616868*z**7) + x*(742.833258544608*y**6*z**3 - 1114.24988781691*y**4*z**5 + 278.562471954228*y**2*z**7 - 6.632439808434*z**9)", + "y**7*(-150.074981259369*x**3 + 450.224943778107*x*z**2) + y**5*(393.946825805844*x**5 - 787.893651611688*x**3*z**2 - 1181.84047741753*x*z**4) + y**3*(-196.973412902922*x**7 + 196.973412902922*x**5*z**2 + 984.86706451461*x**3*z**4 + 590.920238708766*x*z**6) + y*(16.4144510752435*x**9 - 98.486706451461*x**5*z**4 - 131.315608601948*x**3*z**6 - 49.2433532257305*x*z**8)", + "3.21913870529156*x**9*z + x**7*(-154.518657853995*y**2*z + 12.8765548211663*z**3) + x**5*(772.593289269975*y**4*z - 463.555973561985*y**2*z**3 + 19.3148322317494*z**5) + x**3*(-824.099508554641*y**6*z + 1545.18657853995*y**4*z**3 - 463.555973561985*y**2*z**5 + 12.8765548211663*z**7) + x*(176.592751833137*y**8*z - 824.099508554641*y**6*z**3 + 772.593289269975*y**4*z**5 - 154.518657853995*y**2*z**7 + 3.21913870529156*z**9)", + "16.7271353825295*x**9*y + x**7*(-223.028471767059*y**3 + 66.9085415301178*y*z**2) + x**5*(535.268332240943*y**5 - 669.085415301178*y**3*z**2 + 100.362812295177*y*z**4) + x**3*(-305.867618423396*y**7 + 1070.53666448189*y**5*z**2 - 669.085415301178*y**3*z**4 + 66.9085415301178*y*z**6) + x*(33.9852909359329*y**9 - 305.867618423396*y**7*z**2 + 535.268332240943*y**5*z**4 - 223.028471767059*y**3*z**6 + 16.7271353825295*y*z**8)", + "-1.12774323743054*x**10 + x**8*(56.3871618715269*y**2 - 5.63871618715269*z**2) + x**6*(-300.731529981477*y**4 + 225.548647486108*y**2*z**2 - 11.2774323743054*z**4) + x**4*(360.877835977772*y**6 - 902.194589944431*y**4*z**2 + 338.322971229162*y**2*z**4 - 11.2774323743054*z**6) + x**2*(-103.107953136506*y**8 + 721.755671955545*y**6*z**2 - 902.194589944431*y**4*z**4 + 225.548647486108*y**2*z**6 - 5.63871618715269*z**8) + 4.58257569495584*y**10 - 103.107953136506*y**8*z**2 + 360.877835977772*y**6*z**4 - 300.731529981477*y**4*z**6 + 
56.3871618715269*y**2*z**8 - 1.12774323743054*z**10", + "16.7271353825295*y*z**9 + z**7*(66.9085415301178*x**2*y - 223.028471767059*y**3) + z**5*(100.362812295177*x**4*y - 669.085415301178*x**2*y**3 + 535.268332240943*y**5) + z**3*(66.9085415301178*x**6*y - 669.085415301178*x**4*y**3 + 1070.53666448189*x**2*y**5 - 305.867618423396*y**7) + z*(16.7271353825295*x**8*y - 223.028471767059*x**6*y**3 + 535.268332240943*x**4*y**5 - 305.867618423396*x**2*y**7 + 33.9852909359329*y**9)", + "-1.60956935264578*x**10 + x**8*(77.2593289269976*y**2 - 4.82870805793735*z**2) + x**6*(-386.296644634988*y**4 + 154.518657853995*y**2*z**2 - 3.21913870529156*z**4) + x**4*(412.04975427732*y**6 - 386.296644634988*y**4*z**2 + 3.21913870529156*z**6) + x**2*(-88.2963759165686*y**8 + 386.296644634988*y**4*z**4 - 154.518657853995*y**2*z**6 + 4.82870805793735*z**8) + 88.2963759165686*y**8*z**2 - 412.04975427732*y**6*z**4 + 386.296644634988*y**4*z**6 - 77.2593289269975*y**2*z**8 + 1.60956935264578*z**10", + "y**7*(-450.224943778107*x**2*z + 150.074981259369*z**3) + y**5*(1181.84047741753*x**4*z + 787.893651611688*x**2*z**3 - 393.946825805844*z**5) + y**3*(-590.920238708766*x**6*z - 984.86706451461*x**4*z**3 - 196.973412902922*x**2*z**5 + 196.973412902922*z**7) + y*(49.2433532257305*x**8*z + 131.315608601948*x**6*z**3 + 98.486706451461*x**4*z**5 - 16.4144510752435*z**9)", + "-1.6581099521085*x**10 + x**8*(69.640617988557*y**2 + 4.9743298563255*z**2) + x**6*(-278.562471954228*y**4 - 278.562471954228*y**2*z**2 + 23.213539329519*z**4) + x**4*(185.708314636152*y**6 + 1392.81235977114*y**4*z**2 - 696.40617988557*y**2*z**4 + 23.213539329519*z**6) + x**2*(-1114.24988781691*y**6*z**2 + 1392.81235977114*y**4*z**4 - 278.562471954228*y**2*z**6 + 4.9743298563255*z**8) + 185.708314636152*y**6*z**4 - 278.562471954228*y**4*z**6 + 69.640617988557*y**2*z**8 - 1.6581099521085*z**10", + "y**5*(880.891882022136*x**4*z - 1761.78376404427*x**2*z**3 + 176.178376404427*z**5) + y**3*(-734.07656835178*x**6*z + 
734.07656835178*x**4*z**3 + 1321.3378230332*x**2*z**5 - 146.815313670356*z**7) + y*(78.6510608948335*x**8*z - 220.222970505534*x**4*z**5 - 125.841697431734*x**2*z**7 + 15.7302121789667*z**9)", + "-1.75869118663323*x**10 + x**8*(56.2781179722634*y**2 + 22.862985426232*z**2) + x**6*(-131.315608601948*y**4 - 787.893651611688*y**2*z**2 + 24.6216766128653*z**4) + x**4*(1969.73412902922*y**4*z**2 - 24.6216766128653*z**6) + x**2*(-1969.73412902922*y**4*z**4 + 787.893651611688*y**2*z**6 - 22.862985426232*z**8) + 131.315608601948*y**4*z**6 - 56.2781179722634*y**2*z**8 + 1.75869118663323*z**10", + "y**3*(-541.428124558099*x**6*z + 2707.14062279049*x**4*z**3 - 1624.2843736743*x**2*z**5 + 77.3468749368712*z**7) + y*(101.517773354644*x**8*z - 406.071093418574*x**6*z**3 - 203.035546709287*x**4*z**5 + 290.050781013267*x**2*z**7 - 14.5025390506634*z**9)", + "-1.97354559160624*x**10 + 53.2857309733686*x**8*z**2 - 82.8889148474622*x**6*z**4 - 82.8889148474622*x**4*z**6 + 53.2857309733686*x**2*z**8 + y**2*(35.5238206489124*x**8 - 994.666978169547*x**6*z**2 + 2486.66744542387*x**4*z**4 - 994.666978169547*x**2*z**6 + 35.5238206489124*z**8) - 1.97354559160624*z**10", + "y*(109.491768723557*x**8*z - 1021.9231747532*x**6*z**3 + 1532.8847621298*x**4*z**5 - 437.967074894228*x**2*z**7 + 12.1657520803952*z**9)", + "-2.72034486491732*x**10 + 122.415518921279*x**8*z**2 - 571.272421632637*x**6*z**4 + 571.272421632637*x**4*z**6 - 122.415518921279*x**2*z**8 + 2.72034486491732*z**10" + ], + "bwd": { + "x": "g_0*(244.831037842559*x**8*z - 2285.08968653055*x**6*z**3 + 3427.63452979582*x**4*z**5 - 979.324151370235*x**2*z**7 + 27.2034486491732*z**9) + g_1*y*(109.491768723557*x**8 - 3065.7695242596*x**6*z**2 + 7664.42381064899*x**4*z**4 - 3065.7695242596*x**2*z**6 + 109.491768723557*z**8) + g_10*(-11.2774323743054*x**9 + 8.0*x**7*(56.3871618715269*y**2 - 5.63871618715269*z**2) + 6.0*x**5*(-300.731529981477*y**4 + 225.548647486108*y**2*z**2 - 11.2774323743054*z**4) + 4.0*x**3*(360.877835977772*y**6 - 
902.194589944431*y**4*z**2 + 338.322971229162*y**2*z**4 - 11.2774323743054*z**6) + 2.0*x*(-103.107953136506*y**8 + 721.755671955545*y**6*z**2 - 902.194589944431*y**4*z**4 + 225.548647486108*y**2*z**6 - 5.63871618715269*z**8)) + g_11*(133.817083060236*x*y*z**7 + z**5*(401.451249180707*x**3*y - 1338.17083060236*x*y**3) + z**3*(401.451249180707*x**5*y - 2676.34166120471*x**3*y**3 + 2141.07332896377*x*y**5) + z*(133.817083060236*x**7*y - 1338.17083060236*x**5*y**3 + 2141.07332896377*x**3*y**5 - 611.735236846792*x*y**7)) + g_12*(-16.0956935264578*x**9 + 8.0*x**7*(77.2593289269976*y**2 - 4.82870805793735*z**2) + 6.0*x**5*(-386.296644634988*y**4 + 154.518657853995*y**2*z**2 - 3.21913870529156*z**4) + 4.0*x**3*(412.04975427732*y**6 - 386.296644634988*y**4*z**2 + 3.21913870529156*z**6) + 2.0*x*(-88.2963759165686*y**8 + 386.296644634988*y**4*z**4 - 154.518657853995*y**2*z**6 + 4.82870805793735*z**8)) + g_13*(-900.449887556215*x*y**7*z + y**5*(4727.36190967013*x**3*z + 1575.78730322338*x*z**3) + y**3*(-3545.5214322526*x**5*z - 3939.46825805844*x**3*z**3 - 393.946825805844*x*z**5) + y*(393.946825805844*x**7*z + 787.893651611688*x**5*z**3 + 393.946825805844*x**3*z**5)) + g_14*(-16.581099521085*x**9 + 8.0*x**7*(69.640617988557*y**2 + 4.9743298563255*z**2) + 6.0*x**5*(-278.562471954228*y**4 - 278.562471954228*y**2*z**2 + 23.213539329519*z**4) + 4.0*x**3*(185.708314636152*y**6 + 1392.81235977114*y**4*z**2 - 696.40617988557*y**2*z**4 + 23.213539329519*z**6) + 2.0*x*(-1114.24988781691*y**6*z**2 + 1392.81235977114*y**4*z**4 - 278.562471954228*y**2*z**6 + 4.9743298563255*z**8)) + g_15*(y**5*(3523.56752808854*x**3*z - 3523.56752808854*x*z**3) + y**3*(-4404.45941011068*x**5*z + 2936.30627340712*x**3*z**3 + 2642.67564606641*x*z**5) + y*(629.208487158668*x**7*z - 880.891882022136*x**3*z**5 - 251.683394863467*x*z**7)) + g_16*(-17.5869118663323*x**9 + 8.0*x**7*(56.2781179722634*y**2 + 22.862985426232*z**2) + 6.0*x**5*(-131.315608601948*y**4 - 787.893651611688*y**2*z**2 + 
24.6216766128653*z**4) + 4.0*x**3*(1969.73412902922*y**4*z**2 - 24.6216766128653*z**6) + 2.0*x*(-1969.73412902922*y**4*z**4 + 787.893651611688*y**2*z**6 - 22.862985426232*z**8)) + g_17*(y**3*(-3248.56874734859*x**5*z + 10828.562491162*x**3*z**3 - 3248.56874734859*x*z**5) + y*(812.142186837148*x**7*z - 2436.42656051144*x**5*z**3 - 812.142186837148*x**3*z**5 + 580.101562026534*x*z**7)) + g_18*(-19.7354559160624*x**9 + 426.285847786949*x**7*z**2 - 497.333489084773*x**5*z**4 - 331.555659389849*x**3*z**6 + 106.571461946737*x*z**8 + y**2*(284.190565191299*x**7 - 5968.00186901728*x**5*z**2 + 9946.66978169547*x**3*z**4 - 1989.33395633909*x*z**6)) + g_19*y*(875.934149788456*x**7*z - 6131.53904851919*x**5*z**3 + 6131.53904851919*x**3*z**5 - 875.934149788456*x*z**7) + g_2*(142.09528259565*x**8*z - 663.111318779698*x**6*z**3 + 284.190565191299*x**2*z**7 + y**2*(-1989.33395633909*x**6*z + 9946.66978169547*x**4*z**3 - 5968.00186901728*x**2*z**5 + 284.190565191299*z**7) - 15.7883647328499*z**9) + g_20*(-27.2034486491732*x**9 + 979.324151370235*x**7*z**2 - 3427.63452979582*x**5*z**4 + 2285.08968653055*x**3*z**6 - 244.831037842559*x*z**8) + g_3*(y**3*(-541.428124558099*x**6 + 8121.42186837148*x**4*z**2 - 8121.42186837148*x**2*z**4 + 541.428124558099*z**6) + y*(130.52285145597*x**8 - 2030.35546709287*x**6*z**2 + 1015.17773354644*x**4*z**4 + 1218.21328025572*x**2*z**6 - 101.517773354644*z**8)) + g_4*(94.9693240781945*x**8*z + 7.0*x**6*(-337.668707833581*y**2*z - 14.0695294930659*z**3) + 5.0*x**4*(787.893651611688*y**4*z + 787.893651611688*y**2*z**3 - 49.2433532257305*z**5) + 3.0*x**2*(-2626.31217203896*y**4*z**3 + 787.893651611688*y**2*z**5 - 14.0695294930659*z**7) + 787.893651611688*y**4*z**5 - 337.668707833581*y**2*z**7 + 10.5521471197994*z**9) + g_5*(y**5*(880.891882022136*x**4 - 5285.35129213281*x**2*z**2 + 880.891882022136*z**4) + y**3*(-1027.70719569249*x**6 + 6606.68911516602*x**4*z**2 + 2202.22970505534*x**2*z**4 - 734.07656835178*z**6) + y*(141.5719096107*x**8 - 
880.891882022136*x**6*z**2 - 1101.11485252767*x**4*z**4 + 78.6510608948335*z**8)) + g_6*(59.691958275906*x**8*z + 7.0*x**6*(-278.562471954228*y**2*z + 13.264879616868*z**3) + 5.0*x**4*(1114.24988781691*y**4*z - 278.562471954228*y**2*z**3) + 3.0*x**2*(-742.833258544608*y**6*z + 278.562471954228*y**2*z**5 - 13.264879616868*z**7) + 742.833258544608*y**6*z**3 - 1114.24988781691*y**4*z**5 + 278.562471954228*y**2*z**7 - 6.632439808434*z**9) + g_7*(y**7*(-450.224943778107*x**2 + 450.224943778107*z**2) + y**5*(1969.73412902922*x**4 - 2363.68095483506*x**2*z**2 - 1181.84047741753*z**4) + y**3*(-1378.81389032045*x**6 + 984.86706451461*x**4*z**2 + 2954.60119354383*x**2*z**4 + 590.920238708766*z**6) + y*(147.730059677192*x**8 - 492.433532257305*x**4*z**4 - 393.946825805844*x**2*z**6 - 49.2433532257305*z**8)) + g_8*(28.9722483476241*x**8*z + 7.0*x**6*(-154.518657853995*y**2*z + 12.8765548211663*z**3) + 5.0*x**4*(772.593289269975*y**4*z - 463.555973561985*y**2*z**3 + 19.3148322317494*z**5) + 3.0*x**2*(-824.099508554641*y**6*z + 1545.18657853995*y**4*z**3 - 463.555973561985*y**2*z**5 + 12.8765548211663*z**7) + 176.592751833137*y**8*z - 824.099508554641*y**6*z**3 + 772.593289269975*y**4*z**5 - 154.518657853995*y**2*z**7 + 3.21913870529156*z**9) + g_9*(150.544218442765*x**8*y + 7.0*x**6*(-223.028471767059*y**3 + 66.9085415301178*y*z**2) + 5.0*x**4*(535.268332240943*y**5 - 669.085415301178*y**3*z**2 + 100.362812295177*y*z**4) + 3.0*x**2*(-305.867618423396*y**7 + 1070.53666448189*y**5*z**2 - 669.085415301178*y**3*z**4 + 66.9085415301178*y*z**6) + 33.9852909359329*y**9 - 305.867618423396*y**7*z**2 + 535.268332240943*y**5*z**4 - 223.028471767059*y**3*z**6 + 16.7271353825295*y*z**8)", + "y": "g_1*(12.1657520803952*x**9 - 437.967074894228*x**7*z**2 + 1532.8847621298*x**5*z**4 - 1021.9231747532*x**3*z**6 + 109.491768723557*x*z**8) + g_10*(112.774323743054*x**8*y + x**6*(-1202.92611992591*y**3 + 451.097294972216*y*z**2) + x**4*(2165.26701586663*y**5 - 3608.77835977772*y**3*z**2 + 
676.645942458323*y*z**4) + x**2*(-824.863625092051*y**7 + 4330.53403173327*y**5*z**2 - 3608.77835977772*y**3*z**4 + 451.097294972216*y*z**6) + 45.8257569495584*y**9 - 824.863625092051*y**7*z**2 + 2165.26701586663*y**5*z**4 - 1202.92611992591*y**3*z**6 + 112.774323743054*y*z**8) + g_11*(16.7271353825295*z**9 + z**7*(66.9085415301178*x**2 - 669.085415301178*y**2) + z**5*(100.362812295177*x**4 - 2007.25624590353*x**2*y**2 + 2676.34166120471*y**4) + z**3*(66.9085415301178*x**6 - 2007.25624590353*x**4*y**2 + 5352.68332240943*x**2*y**4 - 2141.07332896377*y**6) + z*(16.7271353825295*x**8 - 669.085415301178*x**6*y**2 + 2676.34166120471*x**4*y**4 - 2141.07332896377*x**2*y**6 + 305.867618423396*y**8)) + g_12*(154.518657853995*x**8*y + x**6*(-1545.18657853995*y**3 + 309.03731570799*y*z**2) + x**4*(2472.29852566392*y**5 - 1545.18657853995*y**3*z**2) + x**2*(-706.371007332549*y**7 + 1545.18657853995*y**3*z**4 - 309.03731570799*y*z**6) + 706.371007332549*y**7*z**2 - 2472.29852566392*y**5*z**4 + 1545.18657853995*y**3*z**6 - 154.518657853995*y*z**8) + g_13*(49.2433532257305*x**8*z + 131.315608601948*x**6*z**3 + 98.486706451461*x**4*z**5 + 7.0*y**6*(-450.224943778107*x**2*z + 150.074981259369*z**3) + 5.0*y**4*(1181.84047741753*x**4*z + 787.893651611688*x**2*z**3 - 393.946825805844*z**5) + 3.0*y**2*(-590.920238708766*x**6*z - 984.86706451461*x**4*z**3 - 196.973412902922*x**2*z**5 + 196.973412902922*z**7) - 16.4144510752435*z**9) + g_14*(139.281235977114*x**8*y + x**6*(-1114.24988781691*y**3 - 557.124943908456*y*z**2) + x**4*(1114.24988781691*y**5 + 5571.24943908456*y**3*z**2 - 1392.81235977114*y*z**4) + x**2*(-6685.49932690147*y**5*z**2 + 5571.24943908456*y**3*z**4 - 557.124943908456*y*z**6) + 1114.24988781691*y**5*z**4 - 1114.24988781691*y**3*z**6 + 139.281235977114*y*z**8) + g_15*(78.6510608948335*x**8*z - 220.222970505534*x**4*z**5 - 125.841697431734*x**2*z**7 + 5.0*y**4*(880.891882022136*x**4*z - 1761.78376404427*x**2*z**3 + 176.178376404427*z**5) + 
3.0*y**2*(-734.07656835178*x**6*z + 734.07656835178*x**4*z**3 + 1321.3378230332*x**2*z**5 - 146.815313670356*z**7) + 15.7302121789667*z**9) + g_16*(112.556235944527*x**8*y + x**6*(-525.262434407792*y**3 - 1575.78730322338*y*z**2) + 7878.93651611688*x**4*y**3*z**2 + x**2*(-7878.93651611688*y**3*z**4 + 1575.78730322338*y*z**6) + 525.262434407792*y**3*z**6 - 112.556235944527*y*z**8) + g_17*(101.517773354644*x**8*z - 406.071093418574*x**6*z**3 - 203.035546709287*x**4*z**5 + 290.050781013267*x**2*z**7 + 3.0*y**2*(-541.428124558099*x**6*z + 2707.14062279049*x**4*z**3 - 1624.2843736743*x**2*z**5 + 77.3468749368712*z**7) - 14.5025390506634*z**9) + 2.0*g_18*y*(35.5238206489124*x**8 - 994.666978169547*x**6*z**2 + 2486.66744542387*x**4*z**4 - 994.666978169547*x**2*z**6 + 35.5238206489124*z**8) + g_19*(109.491768723557*x**8*z - 1021.9231747532*x**6*z**3 + 1532.8847621298*x**4*z**5 - 437.967074894228*x**2*z**7 + 12.1657520803952*z**9) + 2.0*g_2*y*(-284.190565191299*x**7*z + 1989.33395633909*x**5*z**3 - 1989.33395633909*x**3*z**5 + 284.190565191299*x*z**7) + g_3*(14.5025390506634*x**9 - 290.050781013267*x**7*z**2 + 203.035546709287*x**5*z**4 + 406.071093418574*x**3*z**6 - 101.517773354644*x*z**8 + 3.0*y**2*(-77.3468749368713*x**7 + 1624.2843736743*x**5*z**2 - 2707.14062279049*x**3*z**4 + 541.428124558099*x*z**6)) + g_4*(-675.337415667161*x**7*y*z + x**5*(3151.57460644675*y**3*z + 1575.78730322338*y*z**3) + x**3*(-10505.2486881558*y**3*z**3 + 1575.78730322338*y*z**5) + x*(3151.57460644675*y**3*z**5 - 675.337415667161*y*z**7)) + g_5*(15.7302121789667*x**9 - 125.841697431734*x**7*z**2 - 220.222970505534*x**5*z**4 + 78.6510608948335*x*z**8 + 5.0*y**4*(176.178376404427*x**5 - 1761.78376404427*x**3*z**2 + 880.891882022136*x*z**4) + 3.0*y**2*(-146.815313670356*x**7 + 1321.3378230332*x**5*z**2 + 734.07656835178*x**3*z**4 - 734.07656835178*x*z**6)) + g_6*(-557.124943908456*x**7*y*z + x**5*(4456.99955126765*y**3*z - 557.124943908456*y*z**3) + x**3*(-4456.99955126765*y**5*z + 
557.124943908456*y*z**5) + x*(4456.99955126765*y**5*z**3 - 4456.99955126765*y**3*z**5 + 557.124943908456*y*z**7)) + g_7*(16.4144510752435*x**9 - 98.486706451461*x**5*z**4 - 131.315608601948*x**3*z**6 - 49.2433532257305*x*z**8 + 7.0*y**6*(-150.074981259369*x**3 + 450.224943778107*x*z**2) + 5.0*y**4*(393.946825805844*x**5 - 787.893651611688*x**3*z**2 - 1181.84047741753*x*z**4) + 3.0*y**2*(-196.973412902922*x**7 + 196.973412902922*x**5*z**2 + 984.86706451461*x**3*z**4 + 590.920238708766*x*z**6)) + g_8*(-309.03731570799*x**7*y*z + x**5*(3090.3731570799*y**3*z - 927.111947123971*y*z**3) + x**3*(-4944.59705132784*y**5*z + 6180.7463141598*y**3*z**3 - 927.111947123971*y*z**5) + x*(1412.7420146651*y**7*z - 4944.59705132784*y**5*z**3 + 3090.3731570799*y**3*z**5 - 309.03731570799*y*z**7)) + g_9*(16.7271353825295*x**9 + x**7*(-669.085415301178*y**2 + 66.9085415301178*z**2) + x**5*(2676.34166120471*y**4 - 2007.25624590353*y**2*z**2 + 100.362812295177*z**4) + x**3*(-2141.07332896377*y**6 + 5352.68332240943*y**4*z**2 - 2007.25624590353*y**2*z**4 + 66.9085415301178*z**6) + x*(305.867618423396*y**8 - 2141.07332896377*y**6*z**2 + 2676.34166120471*y**4*z**4 - 669.085415301178*y**2*z**6 + 16.7271353825295*z**8))", + "z": "g_0*(27.2034486491732*x**9 - 979.324151370235*x**7*z**2 + 3427.63452979582*x**5*z**4 - 2285.08968653055*x**3*z**6 + 244.831037842559*x*z**8) + g_1*y*(-875.934149788456*x**7*z + 6131.53904851919*x**5*z**3 - 6131.53904851919*x**3*z**5 + 875.934149788456*x*z**7) + g_10*(-11.2774323743054*x**8*z + x**6*(451.097294972216*y**2*z - 45.1097294972216*z**3) + x**4*(-1804.38917988886*y**4*z + 1353.29188491665*y**2*z**3 - 67.6645942458323*z**5) + x**2*(1443.51134391109*y**6*z - 3608.77835977772*y**4*z**3 + 1353.29188491665*y**2*z**5 - 45.1097294972216*z**7) - 206.215906273013*y**8*z + 1443.51134391109*y**6*z**3 - 1804.38917988886*y**4*z**5 + 451.097294972215*y**2*z**7 - 11.2774323743054*z**9) + g_11*(16.7271353825295*x**8*y - 223.028471767059*x**6*y**3 + 
535.268332240943*x**4*y**5 - 305.867618423396*x**2*y**7 + 33.9852909359329*y**9 + 150.544218442765*y*z**8 + 7.0*z**6*(66.9085415301178*x**2*y - 223.028471767059*y**3) + 5.0*z**4*(100.362812295177*x**4*y - 669.085415301178*x**2*y**3 + 535.268332240943*y**5) + 3.0*z**2*(66.9085415301178*x**6*y - 669.085415301178*x**4*y**3 + 1070.53666448189*x**2*y**5 - 305.867618423396*y**7)) + g_12*(-9.65741611587469*x**8*z + x**6*(309.03731570799*y**2*z - 12.8765548211663*z**3) + x**4*(-772.593289269976*y**4*z + 19.3148322317494*z**5) + x**2*(1545.18657853995*y**4*z**3 - 927.11194712397*y**2*z**5 + 38.6296644634988*z**7) + 176.592751833137*y**8*z - 1648.19901710928*y**6*z**3 + 2317.77986780993*y**4*z**5 - 618.07463141598*y**2*z**7 + 16.0956935264578*z**9) + g_13*(y**7*(-450.224943778107*x**2 + 450.224943778107*z**2) + y**5*(1181.84047741753*x**4 + 2363.68095483506*x**2*z**2 - 1969.73412902922*z**4) + y**3*(-590.920238708766*x**6 - 2954.60119354383*x**4*z**2 - 984.86706451461*x**2*z**4 + 1378.81389032045*z**6) + y*(49.2433532257305*x**8 + 393.946825805844*x**6*z**2 + 492.433532257305*x**4*z**4 - 147.730059677191*z**8)) + g_14*(9.948659712651*x**8*z + x**6*(-557.124943908456*y**2*z + 92.854157318076*z**3) + x**4*(2785.62471954228*y**4*z - 2785.62471954228*y**2*z**3 + 139.281235977114*z**5) + x**2*(-2228.49977563382*y**6*z + 5571.24943908456*y**4*z**3 - 1671.37483172537*y**2*z**5 + 39.794638850604*z**7) + 742.833258544608*y**6*z**3 - 1671.37483172537*y**4*z**5 + 557.124943908456*y**2*z**7 - 16.581099521085*z**9) + g_15*(y**5*(880.891882022136*x**4 - 5285.35129213281*x**2*z**2 + 880.891882022136*z**4) + y**3*(-734.07656835178*x**6 + 2202.22970505534*x**4*z**2 + 6606.68911516602*x**2*z**4 - 1027.70719569249*z**6) + y*(78.6510608948335*x**8 - 1101.11485252767*x**4*z**4 - 880.891882022136*x**2*z**6 + 141.5719096107*z**8)) + g_16*(45.725970852464*x**8*z + x**6*(-1575.78730322338*y**2*z + 98.486706451461*z**3) + x**4*(3939.46825805844*y**4*z - 147.730059677191*z**5) + 
x**2*(-7878.93651611688*y**4*z**3 + 4727.36190967013*y**2*z**5 - 182.903883409856*z**7) + 787.893651611688*y**4*z**5 - 450.224943778108*y**2*z**7 + 17.5869118663323*z**9) + g_17*(y**3*(-541.428124558099*x**6 + 8121.42186837148*x**4*z**2 - 8121.42186837148*x**2*z**4 + 541.428124558099*z**6) + y*(101.517773354644*x**8 - 1218.21328025572*x**6*z**2 - 1015.17773354644*x**4*z**4 + 2030.35546709287*x**2*z**6 - 130.52285145597*z**8)) + g_18*(106.571461946737*x**8*z - 331.555659389849*x**6*z**3 - 497.333489084773*x**4*z**5 + 426.285847786948*x**2*z**7 + y**2*(-1989.33395633909*x**6*z + 9946.66978169547*x**4*z**3 - 5968.00186901728*x**2*z**5 + 284.190565191299*z**7) - 19.7354559160624*z**9) + g_19*y*(109.491768723557*x**8 - 3065.7695242596*x**6*z**2 + 7664.42381064899*x**4*z**4 - 3065.7695242596*x**2*z**6 + 109.491768723557*z**8) + g_2*(15.7883647328499*x**9 - 284.190565191299*x**7*z**2 + 663.111318779698*x**3*z**6 - 142.09528259565*x*z**8 + y**2*(-284.190565191299*x**7 + 5968.00186901728*x**5*z**2 - 9946.66978169547*x**3*z**4 + 1989.33395633909*x*z**6)) + g_20*(244.831037842559*x**8*z - 2285.08968653055*x**6*z**3 + 3427.63452979582*x**4*z**5 - 979.324151370235*x**2*z**7 + 27.2034486491732*z**9) + g_3*(y**3*(3248.56874734859*x**5*z - 10828.562491162*x**3*z**3 + 3248.56874734859*x*z**5) + y*(-580.101562026534*x**7*z + 812.142186837148*x**5*z**3 + 2436.42656051144*x**3*z**5 - 812.142186837148*x*z**7)) + g_4*(10.5521471197994*x**9 + x**7*(-337.668707833581*y**2 - 42.2085884791976*z**2) + x**5*(787.893651611688*y**4 + 2363.68095483506*y**2*z**2 - 246.216766128653*z**4) + x**3*(-7878.93651611688*y**4*z**2 + 3939.46825805844*y**2*z**4 - 98.486706451461*z**6) + x*(3939.46825805844*y**4*z**4 - 2363.68095483506*y**2*z**6 + 94.9693240781945*z**8)) + g_5*(y**5*(-3523.56752808854*x**3*z + 3523.56752808854*x*z**3) + y**3*(2642.67564606641*x**5*z + 2936.30627340712*x**3*z**3 - 4404.45941011068*x*z**5) + y*(-251.683394863467*x**7*z - 880.891882022136*x**5*z**3 + 629.208487158668*x*z**7)) + 
g_6*(6.632439808434*x**9 + x**7*(-278.562471954228*y**2 + 39.794638850604*z**2) + x**5*(1114.24988781691*y**4 - 835.687415862684*y**2*z**2) + x**3*(-742.833258544608*y**6 + 1392.81235977114*y**2*z**4 - 92.854157318076*z**6) + x*(2228.49977563382*y**6*z**2 - 5571.24943908456*y**4*z**4 + 1949.9373036796*y**2*z**6 - 59.691958275906*z**8)) + g_7*(900.449887556215*x*y**7*z + y**5*(-1575.78730322338*x**3*z - 4727.36190967013*x*z**3) + y**3*(393.946825805844*x**5*z + 3939.46825805844*x**3*z**3 + 3545.5214322526*x*z**5) + y*(-393.946825805844*x**5*z**3 - 787.893651611688*x**3*z**5 - 393.946825805844*x*z**7)) + g_8*(3.21913870529156*x**9 + x**7*(-154.518657853995*y**2 + 38.6296644634988*z**2) + x**5*(772.593289269975*y**4 - 1390.66792068596*y**2*z**2 + 96.5741611587469*z**4) + x**3*(-824.099508554641*y**6 + 4635.55973561985*y**4*z**2 - 2317.77986780993*y**2*z**4 + 90.1358837481638*z**6) + x*(176.592751833137*y**8 - 2472.29852566392*y**6*z**2 + 3862.96644634988*y**4*z**4 - 1081.63060497797*y**2*z**6 + 28.9722483476241*z**8)) + g_9*(133.817083060236*x**7*y*z + x**5*(-1338.17083060236*y**3*z + 401.451249180707*y*z**3) + x**3*(2141.07332896377*y**5*z - 2676.34166120471*y**3*z**3 + 401.451249180707*y*z**5) + x*(-611.735236846792*y**7*z + 2141.07332896377*y**5*z**3 - 1338.17083060236*y**3*z**5 + 133.817083060236*y*z**7))" + } +} \ No newline at end of file diff --git a/notebooks/direct_sph_harm/l_2.json b/notebooks/direct_sph_harm/l_2.json new file mode 100644 index 0000000..a0e3082 --- /dev/null +++ b/notebooks/direct_sph_harm/l_2.json @@ -0,0 +1,14 @@ +{ + "fwd": [ + "3.87298334620742*x*z", + "3.87298334620742*x*y", + "-1.11803398874989*x**2 + 2.23606797749979*y**2 - 1.11803398874989*z**2", + "3.87298334620742*y*z", + "-1.93649167310371*x**2 + 1.93649167310371*z**2" + ], + "bwd": { + "x": "3.87298334620742*g_0*z + 3.87298334620742*g_1*y - 2.23606797749979*g_2*x - 3.87298334620742*g_4*x", + "y": "3.87298334620742*g_1*x + 4.47213595499958*g_2*y + 3.87298334620742*g_3*z", + "z": 
"3.87298334620742*g_0*x - 2.23606797749979*g_2*z + 3.87298334620742*g_3*y + 3.87298334620742*g_4*z" + } +} \ No newline at end of file diff --git a/notebooks/direct_sph_harm/l_3.json b/notebooks/direct_sph_harm/l_3.json new file mode 100644 index 0000000..271a64f --- /dev/null +++ b/notebooks/direct_sph_harm/l_3.json @@ -0,0 +1,16 @@ +{ + "fwd": [ + "-2.09165006633519*x**3 + 6.27495019900557*x*z**2", + "10.2469507659596*x*y*z", + "-1.62018517460197*x**3 + x*(6.48074069840786*y**2 - 1.62018517460197*z**2)", + "-3.96862696659689*x**2*y + 2.64575131106459*y**3 - 3.96862696659689*y*z**2", + "-1.62018517460197*z**3 + z*(-1.62018517460197*x**2 + 6.48074069840786*y**2)", + "5.1234753829798*y*(-x**2 + z**2)", + "-6.27495019900557*x**2*z + 2.09165006633519*z**3" + ], + "bwd": { + "x": "g_0*(-6.27495019900557*x**2 + 6.27495019900557*z**2) + 10.2469507659596*g_1*y*z + g_2*(-4.8605555238059*x**2 + 6.48074069840786*y**2 - 1.62018517460197*z**2) - 7.93725393319377*g_3*x*y - 3.24037034920393*g_4*x*z - 10.2469507659596*g_5*x*y - 12.5499003980111*g_6*x*z", + "y": "10.2469507659596*g_1*x*z + 12.9614813968157*g_2*x*y + g_3*(-3.96862696659689*x**2 + 7.93725393319377*y**2 - 3.96862696659689*z**2) + 12.9614813968157*g_4*y*z + g_5*(-5.1234753829798*x**2 + 5.1234753829798*z**2)", + "z": "12.5499003980111*g_0*x*z + 10.2469507659596*g_1*x*y - 3.24037034920393*g_2*x*z - 7.93725393319377*g_3*y*z + g_4*(-1.62018517460197*x**2 + 6.48074069840786*y**2 - 4.8605555238059*z**2) + 10.2469507659596*g_5*y*z + g_6*(-6.27495019900557*x**2 + 6.27495019900557*z**2)" + } +} \ No newline at end of file diff --git a/notebooks/direct_sph_harm/l_4.json b/notebooks/direct_sph_harm/l_4.json new file mode 100644 index 0000000..9e41f26 --- /dev/null +++ b/notebooks/direct_sph_harm/l_4.json @@ -0,0 +1,18 @@ +{ + "fwd": [ + "-8.87411967464942*x**3*z + 8.87411967464942*x*z**3", + "y*(-6.27495019900557*x**3 + 18.8248505970167*x*z**2)", + "-3.35410196624968*x**3*z + x*(20.1246117974981*y**2*z - 3.35410196624968*z**3)", 
+ "-7.11512473537885*x**3*y + x*(9.48683298050514*y**3 - 7.11512473537885*y*z**2)", + "1.125*x**4 + x**2*(-9.0*y**2 + 2.25*z**2) + 3.0*y**4 - 9.0*y**2*z**2 + 1.125*z**4", + "-7.11512473537885*y*z**3 + z*(-7.11512473537885*x**2*y + 9.48683298050514*y**3)", + "1.67705098312484*x**4 + y**2*(-10.0623058987491*x**2 + 10.0623058987491*z**2) - 1.67705098312484*z**4", + "y*(-18.8248505970167*x**2*z + 6.27495019900557*z**3)", + "2.21852991866236*x**4 - 13.3111795119741*x**2*z**2 + 2.21852991866236*z**4" + ], + "bwd": { + "x": "g_0*(-26.6223590239483*x**2*z + 8.87411967464942*z**3) + g_1*y*(-18.8248505970167*x**2 + 18.8248505970167*z**2) + g_2*(-10.0623058987491*x**2*z + 20.1246117974981*y**2*z - 3.35410196624968*z**3) + g_3*(-21.3453742061366*x**2*y + 9.48683298050514*y**3 - 7.11512473537885*y*z**2) + g_4*(4.5*x**3 + 2.0*x*(-9.0*y**2 + 2.25*z**2)) - 14.2302494707577*g_5*x*y*z + g_6*(6.70820393249937*x**3 - 20.1246117974981*x*y**2) - 37.6497011940334*g_7*x*y*z + g_8*(8.87411967464942*x**3 - 26.6223590239483*x*z**2)", + "y": "g_1*(-6.27495019900557*x**3 + 18.8248505970167*x*z**2) + 40.2492235949962*g_2*x*y*z + g_3*(-7.11512473537885*x**3 + x*(28.4604989415154*y**2 - 7.11512473537885*z**2)) + g_4*(-18.0*x**2*y + 12.0*y**3 - 18.0*y*z**2) + g_5*(-7.11512473537885*z**3 + z*(-7.11512473537885*x**2 + 28.4604989415154*y**2)) + 2.0*g_6*y*(-10.0623058987491*x**2 + 10.0623058987491*z**2) + g_7*(-18.8248505970167*x**2*z + 6.27495019900557*z**3)", + "z": "g_0*(-8.87411967464942*x**3 + 26.6223590239483*x*z**2) + 37.6497011940334*g_1*x*y*z + g_2*(-3.35410196624968*x**3 + x*(20.1246117974981*y**2 - 10.0623058987491*z**2)) - 14.2302494707577*g_3*x*y*z + g_4*(4.5*x**2*z - 18.0*y**2*z + 4.5*z**3) + g_5*(-7.11512473537885*x**2*y + 9.48683298050514*y**3 - 21.3453742061366*y*z**2) + g_6*(20.1246117974981*y**2*z - 6.70820393249937*z**3) + g_7*y*(-18.8248505970167*x**2 + 18.8248505970167*z**2) + g_8*(-26.6223590239483*x**2*z + 8.87411967464942*z**3)" + } +} \ No newline at end of file diff --git 
a/notebooks/direct_sph_harm/l_5.json b/notebooks/direct_sph_harm/l_5.json new file mode 100644 index 0000000..dd04bb3 --- /dev/null +++ b/notebooks/direct_sph_harm/l_5.json @@ -0,0 +1,20 @@ +{ + "fwd": [ + "2.32681380862329*x**5 - 23.2681380862329*x**3*z**2 + 11.6340690431164*x*z**4", + "y*(-29.4321253055229*x**3*z + 29.4321253055229*x*z**3)", + "1.73430461568895*x**5 + x**3*(-13.8744369255116*y**2 - 3.4686092313779*z**2) + x*(41.6233107765348*y**2*z**2 - 5.20291384706685*z**4)", + "-16.9926454679664*x**3*y*z + x*(33.9852909359329*y**3*z - 16.9926454679664*y*z**3)", + "1.60565407233314*x**5 + x**3*(-19.2678488679977*y**2 + 3.21130814466628*z**2) + x*(12.8452325786651*y**4 - 19.2678488679977*y**2*z**2 + 1.60565407233314*z**4)", + "3.3166247903554*y**5 + y**3*(-16.583123951777*x**2 - 16.583123951777*z**2) + y*(6.21867148191637*x**4 + 12.4373429638327*x**2*z**2 + 6.21867148191637*z**4)", + "1.60565407233314*z**5 + z**3*(3.21130814466628*x**2 - 19.2678488679977*y**2) + z*(1.60565407233314*x**4 - 19.2678488679977*x**2*y**2 + 12.8452325786651*y**4)", + "y**3*(-16.9926454679664*x**2 + 16.9926454679664*z**2) + y*(8.49632273398321*x**4 - 8.49632273398321*z**4)", + "-1.73430461568895*z**5 + z**3*(3.4686092313779*x**2 + 13.8744369255116*y**2) + z*(5.20291384706685*x**4 - 41.6233107765348*x**2*y**2)", + "y*(7.35803132638072*x**4 - 44.1481879582843*x**2*z**2 + 7.35803132638072*z**4)", + "11.6340690431164*x**4*z - 23.2681380862329*x**2*z**3 + 2.32681380862329*z**5" + ], + "bwd": { + "x": "g_0*(11.6340690431164*x**4 - 69.8044142586986*x**2*z**2 + 11.6340690431164*z**4) + g_1*y*(-88.2963759165686*x**2*z + 29.4321253055229*z**3) + g_10*(46.5362761724657*x**3*z - 46.5362761724657*x*z**3) + g_2*(8.67152307844476*x**4 + 3.0*x**2*(-13.8744369255116*y**2 - 3.4686092313779*z**2) + 41.6233107765348*y**2*z**2 - 5.20291384706685*z**4) + g_3*(-50.9779364038993*x**2*y*z + 33.9852909359329*y**3*z - 16.9926454679664*y*z**3) + g_4*(8.02827036166571*x**4 + 3.0*x**2*(-19.2678488679977*y**2 + 
3.21130814466628*z**2) + 12.8452325786651*y**4 - 19.2678488679977*y**2*z**2 + 1.60565407233314*z**4) + g_5*(-33.166247903554*x*y**3 + y*(24.8746859276655*x**3 + 24.8746859276655*x*z**2)) + g_6*(6.42261628933256*x*z**3 + z*(6.42261628933256*x**3 - 38.5356977359954*x*y**2)) + g_7*(33.9852909359329*x**3*y - 33.9852909359329*x*y**3) + g_8*(6.9372184627558*x*z**3 + z*(20.8116553882674*x**3 - 83.2466215530696*x*y**2)) + g_9*y*(29.4321253055229*x**3 - 88.2963759165686*x*z**2)", + "y": "g_1*(-29.4321253055229*x**3*z + 29.4321253055229*x*z**3) + g_2*(-27.7488738510232*x**3*y + 83.2466215530696*x*y*z**2) + g_3*(-16.9926454679664*x**3*z + x*(101.955872807799*y**2*z - 16.9926454679664*z**3)) + g_4*(-38.5356977359954*x**3*y + x*(51.3809303146605*y**3 - 38.5356977359954*y*z**2)) + g_5*(6.21867148191637*x**4 + 12.4373429638327*x**2*z**2 + 16.583123951777*y**4 + 3.0*y**2*(-16.583123951777*x**2 - 16.583123951777*z**2) + 6.21867148191637*z**4) + g_6*(-38.5356977359954*y*z**3 + z*(-38.5356977359954*x**2*y + 51.3809303146605*y**3)) + g_7*(8.49632273398321*x**4 + 3.0*y**2*(-16.9926454679664*x**2 + 16.9926454679664*z**2) - 8.49632273398321*z**4) + g_8*(-83.2466215530696*x**2*y*z + 27.7488738510232*y*z**3) + g_9*(7.35803132638072*x**4 - 44.1481879582843*x**2*z**2 + 7.35803132638072*z**4)", + "z": "g_0*(-46.5362761724657*x**3*z + 46.5362761724657*x*z**3) + g_1*y*(-29.4321253055229*x**3 + 88.2963759165686*x*z**2) + g_10*(11.6340690431164*x**4 - 69.8044142586986*x**2*z**2 + 11.6340690431164*z**4) + g_2*(-6.9372184627558*x**3*z + x*(83.2466215530696*y**2*z - 20.8116553882674*z**3)) + g_3*(-16.9926454679664*x**3*y + x*(33.9852909359329*y**3 - 50.9779364038993*y*z**2)) + g_4*(6.42261628933256*x**3*z + x*(-38.5356977359954*y**2*z + 6.42261628933257*z**3)) + g_5*(-33.166247903554*y**3*z + y*(24.8746859276655*x**2*z + 24.8746859276655*z**3)) + g_6*(1.60565407233314*x**4 - 19.2678488679977*x**2*y**2 + 12.8452325786651*y**4 + 8.02827036166571*z**4 + 3.0*z**2*(3.21130814466628*x**2 - 
19.2678488679977*y**2)) + g_7*(33.9852909359329*y**3*z - 33.9852909359329*y*z**3) + g_8*(5.20291384706685*x**4 - 41.6233107765348*x**2*y**2 - 8.67152307844475*z**4 + 3.0*z**2*(3.4686092313779*x**2 + 13.8744369255116*y**2)) + g_9*y*(-88.2963759165686*x**2*z + 29.4321253055229*z**3)" + } +} \ No newline at end of file diff --git a/notebooks/direct_sph_harm/l_6.json b/notebooks/direct_sph_harm/l_6.json new file mode 100644 index 0000000..6144889 --- /dev/null +++ b/notebooks/direct_sph_harm/l_6.json @@ -0,0 +1,22 @@ +{ + "fwd": [ + "14.5309475774982*x**5*z - 48.4364919249939*x**3*z**3 + 14.5309475774982*x*z**5", + "y*(8.38944649544891*x**5 - 83.8944649544891*x**3*z**2 + 41.9472324772445*x*z**4)", + "7.15454401062709*x**5*z - 7.15454401062709*x*z**5 + y**2*(-71.5454401062709*x**3*z + 71.5454401062709*x*z**3)", + "y**3*(-26.1247009552263*x**3 + 78.3741028656788*x*z**2) + y*(9.79676285820985*x**5 - 19.5935257164197*x**3*z**2 - 29.3902885746295*x*z**4)", + "3.26558761940328*x**5*z + x**3*(-52.2494019104525*y**2*z + 6.53117523880657*z**3) + x*(52.2494019104525*y**4*z - 52.2494019104525*y**2*z**3 + 3.26558761940328*z**5)", + "10.3266947761614*x**5*y + x**3*(-41.3067791046458*y**3 + 20.6533895523229*y*z**2) + x*(16.5227116418583*y**5 - 41.3067791046458*y**3*z**2 + 10.3266947761614*y*z**4)", + "-1.1267347735825*x**6 + x**4*(20.2812259244849*y**2 - 3.38020432074749*z**2) + x**2*(-27.0416345659799*y**4 + 40.5624518489699*y**2*z**2 - 3.38020432074749*z**4) + 3.60555127546399*y**6 - 27.0416345659799*y**4*z**2 + 20.2812259244849*y**2*z**4 - 1.1267347735825*z**6", + "10.3266947761614*y*z**5 + z**3*(20.6533895523229*x**2*y - 41.3067791046458*y**3) + z*(10.3266947761614*x**4*y - 41.3067791046458*x**2*y**3 + 16.5227116418583*y**5)", + "-1.63279380970164*x**6 + x**4*(26.1247009552263*y**2 - 1.63279380970164*z**2) + x**2*(-26.1247009552263*y**4 + 1.63279380970164*z**4) + 26.1247009552263*y**4*z**2 - 26.1247009552263*y**2*z**4 + 1.63279380970164*z**6", + "y**3*(-78.3741028656788*x**2*z + 
26.1247009552263*z**3) + y*(29.3902885746295*x**4*z + 19.5935257164197*x**2*z**3 - 9.79676285820985*z**5)", + "-1.78863600265677*x**6 + x**4*(17.8863600265677*y**2 + 8.94318001328386*z**2) + x**2*(-107.318160159406*y**2*z**2 + 8.94318001328386*z**4) + 17.8863600265677*y**2*z**4 - 1.78863600265677*z**6", + "y*(41.9472324772445*x**4*z - 83.8944649544891*x**2*z**3 + 8.38944649544891*z**5)", + "-2.4218245962497*x**6 + 36.3273689437454*x**4*z**2 - 36.3273689437454*x**2*z**4 + 2.4218245962497*z**6" + ], + "bwd": { + "x": "g_0*(72.6547378874909*x**4*z - 145.309475774982*x**2*z**3 + 14.5309475774982*z**5) + g_1*y*(41.9472324772445*x**4 - 251.683394863467*x**2*z**2 + 41.9472324772445*z**4) + g_10*(-10.7318160159406*x**5 + 4.0*x**3*(17.8863600265677*y**2 + 8.94318001328386*z**2) + 2.0*x*(-107.318160159406*y**2*z**2 + 8.94318001328386*z**4)) + g_11*y*(167.788929908978*x**3*z - 167.788929908978*x*z**3) + g_12*(-14.5309475774982*x**5 + 145.309475774982*x**3*z**2 - 72.6547378874909*x*z**4) + g_2*(35.7727200531355*x**4*z + y**2*(-214.636320318813*x**2*z + 71.5454401062709*z**3) - 7.15454401062709*z**5) + g_3*(y**3*(-78.3741028656788*x**2 + 78.3741028656788*z**2) + y*(48.9838142910493*x**4 - 58.7805771492591*x**2*z**2 - 29.3902885746295*z**4)) + g_4*(16.3279380970164*x**4*z + 3.0*x**2*(-52.2494019104525*y**2*z + 6.53117523880657*z**3) + 52.2494019104525*y**4*z - 52.2494019104525*y**2*z**3 + 3.26558761940328*z**5) + g_5*(51.6334738808072*x**4*y + 3.0*x**2*(-41.3067791046458*y**3 + 20.6533895523229*y*z**2) + 16.5227116418583*y**5 - 41.3067791046458*y**3*z**2 + 10.3266947761614*y*z**4) + g_6*(-6.76040864149498*x**5 + 4.0*x**3*(20.2812259244849*y**2 - 3.38020432074749*z**2) + 2.0*x*(-27.0416345659799*y**4 + 40.5624518489699*y**2*z**2 - 3.38020432074749*z**4)) + g_7*(41.3067791046458*x*y*z**3 + z*(41.3067791046458*x**3*y - 82.6135582092915*x*y**3)) + g_8*(-9.79676285820985*x**5 + 4.0*x**3*(26.1247009552263*y**2 - 1.63279380970164*z**2) + 2.0*x*(-26.1247009552263*y**4 + 
1.63279380970164*z**4)) + g_9*(-156.748205731358*x*y**3*z + y*(117.561154298518*x**3*z + 39.1870514328394*x*z**3))", + "y": "g_1*(8.38944649544891*x**5 - 83.8944649544891*x**3*z**2 + 41.9472324772445*x*z**4) + g_10*(35.7727200531355*x**4*y - 214.636320318813*x**2*y*z**2 + 35.7727200531355*y*z**4) + g_11*(41.9472324772445*x**4*z - 83.8944649544891*x**2*z**3 + 8.38944649544891*z**5) + 2.0*g_2*y*(-71.5454401062709*x**3*z + 71.5454401062709*x*z**3) + g_3*(9.79676285820985*x**5 - 19.5935257164197*x**3*z**2 - 29.3902885746295*x*z**4 + 3.0*y**2*(-26.1247009552263*x**3 + 78.3741028656788*x*z**2)) + g_4*(-104.498803820905*x**3*y*z + x*(208.99760764181*y**3*z - 104.498803820905*y*z**3)) + g_5*(10.3266947761614*x**5 + x**3*(-123.920337313937*y**2 + 20.6533895523229*z**2) + x*(82.6135582092915*y**4 - 123.920337313937*y**2*z**2 + 10.3266947761614*z**4)) + g_6*(40.5624518489699*x**4*y + x**2*(-108.16653826392*y**3 + 81.1249036979398*y*z**2) + 21.6333076527839*y**5 - 108.16653826392*y**3*z**2 + 40.5624518489699*y*z**4) + g_7*(10.3266947761614*z**5 + z**3*(20.6533895523229*x**2 - 123.920337313937*y**2) + z*(10.3266947761614*x**4 - 123.920337313937*x**2*y**2 + 82.6135582092915*y**4)) + g_8*(52.2494019104525*x**4*y - 104.498803820905*x**2*y**3 + 104.498803820905*y**3*z**2 - 52.2494019104525*y*z**4) + g_9*(29.3902885746295*x**4*z + 19.5935257164197*x**2*z**3 + 3.0*y**2*(-78.3741028656788*x**2*z + 26.1247009552263*z**3) - 9.79676285820985*z**5)", + "z": "g_0*(14.5309475774982*x**5 - 145.309475774982*x**3*z**2 + 72.6547378874909*x*z**4) + g_1*y*(-167.788929908978*x**3*z + 167.788929908978*x*z**3) + g_10*(17.8863600265677*x**4*z + x**2*(-214.636320318813*y**2*z + 35.7727200531355*z**3) + 71.5454401062709*y**2*z**3 - 10.7318160159406*z**5) + g_11*y*(41.9472324772445*x**4 - 251.683394863467*x**2*z**2 + 41.9472324772445*z**4) + g_12*(72.6547378874909*x**4*z - 145.309475774982*x**2*z**3 + 14.5309475774982*z**5) + g_2*(7.15454401062709*x**5 - 35.7727200531355*x*z**4 + 
y**2*(-71.5454401062709*x**3 + 214.636320318813*x*z**2)) + g_3*(156.748205731358*x*y**3*z + y*(-39.1870514328394*x**3*z - 117.561154298518*x*z**3)) + g_4*(3.26558761940328*x**5 + x**3*(-52.2494019104525*y**2 + 19.5935257164197*z**2) + x*(52.2494019104525*y**4 - 156.748205731358*y**2*z**2 + 16.3279380970164*z**4)) + g_5*(41.3067791046458*x**3*y*z + x*(-82.6135582092915*y**3*z + 41.3067791046458*y*z**3)) + g_6*(-6.76040864149498*x**4*z + x**2*(81.1249036979398*y**2*z - 13.52081728299*z**3) - 54.0832691319598*y**4*z + 81.1249036979398*y**2*z**3 - 6.76040864149498*z**5) + g_7*(10.3266947761614*x**4*y - 41.3067791046458*x**2*y**3 + 16.5227116418583*y**5 + 51.6334738808072*y*z**4 + 3.0*z**2*(20.6533895523229*x**2*y - 41.3067791046458*y**3)) + g_8*(-3.26558761940328*x**4*z + 6.53117523880657*x**2*z**3 + 52.2494019104525*y**4*z - 104.498803820905*y**2*z**3 + 9.79676285820985*z**5) + g_9*(y**3*(-78.3741028656788*x**2 + 78.3741028656788*z**2) + y*(29.3902885746295*x**4 + 58.7805771492591*x**2*z**2 - 48.9838142910492*z**4))" + } +} \ No newline at end of file diff --git a/notebooks/direct_sph_harm/l_7.json b/notebooks/direct_sph_harm/l_7.json new file mode 100644 index 0000000..8169a0b --- /dev/null +++ b/notebooks/direct_sph_harm/l_7.json @@ -0,0 +1,24 @@ +{ + "fwd": [ + "-2.50682661696018*x**7 + 52.6433589561637*x**5*z**2 - 87.7389315936062*x**3*z**4 + 17.5477863187212*x*z**6", + "y*(56.2781179722634*x**5*z - 187.593726574211*x**3*z**3 + 56.2781179722634*x*z**5)", + "-1.83950783159518*x**7 + x**5*(22.0740939791422*y**2 + 16.5555704843566*z**2) + x**3*(-220.740939791422*y**2*z**2 + 9.1975391579759*z**4) + x*(110.370469895711*y**2*z**4 - 9.1975391579759*z**6)", + "y**3*(-147.160626527614*x**3*z + 147.160626527614*x*z**3) + y*(44.1481879582843*x**5*z - 44.1481879582843*x*z**5)", + "-1.66389743899677*x**7 + x**5*(33.2779487799353*y**2 + 1.66389743899677*z**2) + x**3*(-44.3705983732471*y**4 - 66.5558975598707*y**2*z**2 + 8.31948719498384*z**4) + x*(133.111795119741*y**4*z**2 - 
99.833846339806*y**2*z**4 + 4.9916923169903*z**6)", + "23.5310632462709*x**5*y*z + x**3*(-125.499003980111*y**3*z + 47.0621264925418*y*z**3) + x*(75.2994023880668*y**5*z - 125.499003980111*y**3*z**3 + 23.5310632462709*y*z**5)", + "-1.60108605718119*x**7 + x**5*(38.4260653723485*y**2 - 4.80325817154356*z**2) + x**3*(-76.852130744697*y**4 + 76.852130744697*y**2*z**2 - 4.80325817154356*z**4) + x*(20.4939015319192*y**6 - 76.852130744697*y**4*z**2 + 38.4260653723485*y**2*z**4 - 1.60108605718119*z**6)", + "3.87298334620742*y**7 + y**5*(-40.6663251351779*x**2 - 40.6663251351779*z**2) + y**3*(50.8329064189723*x**4 + 101.665812837945*x**2*z**2 + 50.8329064189723*z**4) + y*(-8.47215106982872*x**6 - 25.4164532094862*x**4*z**2 - 25.4164532094862*x**2*z**4 - 8.47215106982872*z**6)", + "-1.60108605718119*z**7 + z**5*(-4.80325817154356*x**2 + 38.4260653723485*y**2) + z**3*(-4.80325817154356*x**4 + 76.852130744697*x**2*y**2 - 76.852130744697*y**4) + z*(-1.60108605718119*x**6 + 38.4260653723485*x**4*y**2 - 76.852130744697*x**2*y**4 + 20.4939015319192*y**6)", + "y**5*(-37.6497011940334*x**2 + 37.6497011940334*z**2) + y**3*(62.7495019900557*x**4 - 62.7495019900557*z**4) + y*(-11.7655316231354*x**6 - 11.7655316231354*x**4*z**2 + 11.7655316231354*x**2*z**4 + 11.7655316231354*z**6)", + "1.66389743899677*z**7 + z**5*(-1.66389743899677*x**2 - 33.2779487799353*y**2) + z**3*(-8.31948719498384*x**4 + 66.5558975598707*x**2*y**2 + 44.3705983732471*y**4) + z*(-4.9916923169903*x**6 + 99.833846339806*x**4*y**2 - 133.111795119741*x**2*y**4)", + "y**3*(36.7901566319036*x**4 - 220.740939791422*x**2*z**2 + 36.7901566319036*z**4) + y*(-11.0370469895711*x**6 + 55.1852349478554*x**4*z**2 + 55.1852349478554*x**2*z**4 - 11.0370469895711*z**6)", + "-1.83950783159518*z**7 + z**5*(16.5555704843566*x**2 + 22.0740939791422*y**2) + z**3*(9.1975391579759*x**4 - 220.740939791422*x**2*y**2) + z*(-9.1975391579759*x**6 + 110.370469895711*x**4*y**2)", + "y*(-9.37968632871057*x**6 + 140.695294930659*x**4*z**2 - 
140.695294930659*x**2*z**4 + 9.37968632871057*z**6)", + "-17.5477863187212*x**6*z + 87.7389315936062*x**4*z**3 - 52.6433589561637*x**2*z**5 + 2.50682661696018*z**7" + ], + "bwd": { + "x": "g_0*(-17.5477863187212*x**6 + 263.216794780818*x**4*z**2 - 263.216794780819*x**2*z**4 + 17.5477863187212*z**6) + g_1*y*(281.390589861317*x**4*z - 562.781179722634*x**2*z**3 + 56.2781179722634*z**5) + g_10*(-3.32779487799353*x*z**5 + z**3*(-33.2779487799353*x**3 + 133.111795119741*x*y**2) + z*(-29.9501539019418*x**5 + 399.335385359224*x**3*y**2 - 266.223590239483*x*y**4)) + g_11*(y**3*(147.160626527614*x**3 - 441.481879582843*x*z**2) + y*(-66.2222819374265*x**5 + 220.740939791422*x**3*z**2 + 110.370469895711*x*z**4)) + g_12*(33.1111409687132*x*z**5 + z**3*(36.7901566319036*x**3 - 441.481879582843*x*y**2) + z*(-55.1852349478554*x**5 + 441.481879582843*x**3*y**2)) + g_13*y*(-56.2781179722634*x**5 + 562.781179722634*x**3*z**2 - 281.390589861317*x*z**4) + g_14*(-105.286717912327*x**5*z + 350.955726374425*x**3*z**3 - 105.286717912327*x*z**5) + g_2*(-12.8765548211663*x**6 + 5.0*x**4*(22.0740939791422*y**2 + 16.5555704843566*z**2) + 3.0*x**2*(-220.740939791422*y**2*z**2 + 9.1975391579759*z**4) + 110.370469895711*y**2*z**4 - 9.1975391579759*z**6) + g_3*(y**3*(-441.481879582843*x**2*z + 147.160626527614*z**3) + y*(220.740939791422*x**4*z - 44.1481879582843*z**5)) + g_4*(-11.6472820729774*x**6 + 5.0*x**4*(33.2779487799353*y**2 + 1.66389743899677*z**2) + 3.0*x**2*(-44.3705983732471*y**4 - 66.5558975598707*y**2*z**2 + 8.31948719498384*z**4) + 133.111795119741*y**4*z**2 - 99.833846339806*y**2*z**4 + 4.9916923169903*z**6) + g_5*(117.655316231354*x**4*y*z + 3.0*x**2*(-125.499003980111*y**3*z + 47.0621264925418*y*z**3) + 75.2994023880668*y**5*z - 125.499003980111*y**3*z**3 + 23.5310632462709*y*z**5) + g_6*(-11.2076024002683*x**6 + 5.0*x**4*(38.4260653723485*y**2 - 4.80325817154356*z**2) + 3.0*x**2*(-76.852130744697*y**4 + 76.852130744697*y**2*z**2 - 4.80325817154356*z**4) + 20.4939015319192*y**6 
- 76.852130744697*y**4*z**2 + 38.4260653723485*y**2*z**4 - 1.60108605718119*z**6) + g_7*(-81.3326502703558*x*y**5 + y**3*(203.331625675889*x**3 + 203.331625675889*x*z**2) + y*(-50.8329064189723*x**5 - 101.665812837945*x**3*z**2 - 50.8329064189723*x*z**4)) + g_8*(-9.60651634308713*x*z**5 + z**3*(-19.2130326861743*x**3 + 153.704261489394*x*y**2) + z*(-9.60651634308712*x**5 + 153.704261489394*x**3*y**2 - 153.704261489394*x*y**4)) + g_9*(250.998007960223*x**3*y**3 - 75.2994023880668*x*y**5 + y*(-70.5931897388126*x**5 - 47.0621264925418*x**3*z**2 + 23.5310632462709*x*z**4))", + "y": "g_1*(56.2781179722634*x**5*z - 187.593726574211*x**3*z**3 + 56.2781179722634*x*z**5) + g_10*(-66.5558975598707*y*z**5 + z**3*(133.111795119741*x**2*y + 177.482393492989*y**3) + z*(199.667692679612*x**4*y - 532.447180478965*x**2*y**3)) + g_11*(-11.0370469895711*x**6 + 55.1852349478554*x**4*z**2 + 55.1852349478554*x**2*z**4 + 3.0*y**2*(36.7901566319036*x**4 - 220.740939791422*x**2*z**2 + 36.7901566319036*z**4) - 11.0370469895711*z**6) + g_12*(220.740939791422*x**4*y*z - 441.481879582843*x**2*y*z**3 + 44.1481879582843*y*z**5) + g_13*(-9.37968632871057*x**6 + 140.695294930659*x**4*z**2 - 140.695294930659*x**2*z**4 + 9.37968632871057*z**6) + g_2*(44.1481879582843*x**5*y - 441.481879582843*x**3*y*z**2 + 220.740939791422*x*y*z**4) + g_3*(44.1481879582843*x**5*z - 44.1481879582843*x*z**5 + 3.0*y**2*(-147.160626527614*x**3*z + 147.160626527614*x*z**3)) + g_4*(66.5558975598707*x**5*y + x**3*(-177.482393492989*y**3 - 133.111795119741*y*z**2) + x*(532.447180478965*y**3*z**2 - 199.667692679612*y*z**4)) + g_5*(23.5310632462709*x**5*z + x**3*(-376.497011940334*y**2*z + 47.0621264925418*z**3) + x*(376.497011940334*y**4*z - 376.497011940334*y**2*z**3 + 23.5310632462709*z**5)) + g_6*(76.852130744697*x**5*y + x**3*(-307.408522978788*y**3 + 153.704261489394*y*z**2) + x*(122.963409191515*y**5 - 307.408522978788*y**3*z**2 + 76.852130744697*y*z**4)) + g_7*(-8.47215106982872*x**6 - 25.4164532094862*x**4*z**2 - 
25.4164532094862*x**2*z**4 + 27.1108834234519*y**6 + 5.0*y**4*(-40.6663251351779*x**2 - 40.6663251351779*z**2) + 3.0*y**2*(50.8329064189723*x**4 + 101.665812837945*x**2*z**2 + 50.8329064189723*z**4) - 8.47215106982872*z**6) + g_8*(76.852130744697*y*z**5 + z**3*(153.704261489394*x**2*y - 307.408522978788*y**3) + z*(76.852130744697*x**4*y - 307.408522978788*x**2*y**3 + 122.963409191515*y**5)) + g_9*(-11.7655316231354*x**6 - 11.7655316231354*x**4*z**2 + 11.7655316231354*x**2*z**4 + 5.0*y**4*(-37.6497011940334*x**2 + 37.6497011940334*z**2) + 3.0*y**2*(62.7495019900557*x**4 - 62.7495019900557*z**4) + 11.7655316231354*z**6)", + "z": "g_0*(105.286717912327*x**5*z - 350.955726374425*x**3*z**3 + 105.286717912327*x*z**5) + g_1*y*(56.2781179722634*x**5 - 562.781179722634*x**3*z**2 + 281.390589861317*x*z**4) + g_10*(-4.9916923169903*x**6 + 99.833846339806*x**4*y**2 - 133.111795119741*x**2*y**4 + 11.6472820729774*z**6 + 5.0*z**4*(-1.66389743899677*x**2 - 33.2779487799353*y**2) + 3.0*z**2*(-8.31948719498384*x**4 + 66.5558975598707*x**2*y**2 + 44.3705983732471*y**4)) + g_11*(y**3*(-441.481879582843*x**2*z + 147.160626527614*z**3) + y*(110.370469895711*x**4*z + 220.740939791422*x**2*z**3 - 66.2222819374265*z**5)) + g_12*(-9.1975391579759*x**6 + 110.370469895711*x**4*y**2 - 12.8765548211663*z**6 + 5.0*z**4*(16.5555704843566*x**2 + 22.0740939791422*y**2) + 3.0*z**2*(9.1975391579759*x**4 - 220.740939791422*x**2*y**2)) + g_13*y*(281.390589861317*x**4*z - 562.781179722634*x**2*z**3 + 56.2781179722634*z**5) + g_14*(-17.5477863187212*x**6 + 263.216794780819*x**4*z**2 - 263.216794780818*x**2*z**4 + 17.5477863187212*z**6) + g_2*(33.1111409687132*x**5*z + x**3*(-441.481879582843*y**2*z + 36.7901566319036*z**3) + x*(441.481879582843*y**2*z**3 - 55.1852349478554*z**5)) + g_3*(y**3*(-147.160626527614*x**3 + 441.481879582843*x*z**2) + y*(44.1481879582843*x**5 - 220.740939791422*x*z**4)) + g_4*(3.32779487799353*x**5*z + x**3*(-133.111795119741*y**2*z + 33.2779487799353*z**3) + 
x*(266.223590239483*y**4*z - 399.335385359224*y**2*z**3 + 29.9501539019418*z**5)) + g_5*(23.5310632462709*x**5*y + x**3*(-125.499003980111*y**3 + 141.186379477625*y*z**2) + x*(75.2994023880668*y**5 - 376.497011940334*y**3*z**2 + 117.655316231354*y*z**4)) + g_6*(-9.60651634308712*x**5*z + x**3*(153.704261489394*y**2*z - 19.2130326861743*z**3) + x*(-153.704261489394*y**4*z + 153.704261489394*y**2*z**3 - 9.60651634308713*z**5)) + g_7*(-81.3326502703557*y**5*z + y**3*(203.331625675889*x**2*z + 203.331625675889*z**3) + y*(-50.8329064189724*x**4*z - 101.665812837945*x**2*z**3 - 50.8329064189723*z**5)) + g_8*(-1.60108605718119*x**6 + 38.4260653723485*x**4*y**2 - 76.852130744697*x**2*y**4 + 20.4939015319192*y**6 - 11.2076024002683*z**6 + 5.0*z**4*(-4.80325817154356*x**2 + 38.4260653723485*y**2) + 3.0*z**2*(-4.80325817154356*x**4 + 76.852130744697*x**2*y**2 - 76.852130744697*y**4)) + g_9*(75.2994023880668*y**5*z - 250.998007960223*y**3*z**3 + y*(-23.5310632462709*x**4*z + 47.0621264925417*x**2*z**3 + 70.5931897388126*z**5))" + } +} \ No newline at end of file diff --git a/notebooks/direct_sph_harm/l_8.json b/notebooks/direct_sph_harm/l_8.json new file mode 100644 index 0000000..43e17a8 --- /dev/null +++ b/notebooks/direct_sph_harm/l_8.json @@ -0,0 +1,26 @@ +{ + "fwd": [ + "-20.6718218536732*x**7*z + 144.702752975712*x**5*z**3 - 144.702752975712*x**3*z**5 + 20.6718218536732*x*z**7", + "y*(-10.3359109268366*x**7 + 217.054129463568*x**5*z**2 - 361.756882439281*x**3*z**4 + 72.3513764878561*x*z**6)", + "-11.3224231339851*x**7*z + x**5*(158.513923875791*y**2*z + 26.4189873126318*z**3) + x**3*(-528.379746252636*y**2*z**3 + 26.4189873126318*z**5) + x*(158.513923875791*y**2*z**5 - 11.3224231339851*z**7)", + "y**3*(48.9184589393411*x**5 - 489.184589393411*x**3*z**2 + 244.592294696706*x*z**4) + y*(-12.2296147348353*x**7 + 110.066532613517*x**5*z**2 + 61.1480736741764*x**3*z**4 - 61.1480736741764*x*z**6)", + "-6.78376969317208*x**7*z + x**5*(162.81047263613*y**2*z - 
6.78376969317208*z**3) + x**3*(-271.350787726883*y**4*z + 6.78376969317208*z**5) + x*(271.350787726883*y**4*z**3 - 162.81047263613*y**2*z**5 + 6.78376969317208*z**7)", + "y**5*(-70.0624721230988*x**3 + 210.187416369296*x*z**2) + y**3*(87.5780901538735*x**5 - 175.156180307747*x**3*z**2 - 262.734270461621*x*z**4) + y*(-13.136713523081*x**7 + 13.136713523081*x**5*z**2 + 65.6835676154051*x**3*z**4 + 39.4101405692431*x*z**6)", + "-3.23403530824881*x**7*z + x**5*(97.0210592474644*y**2*z - 9.70210592474644*z**3) + x**3*(-258.722824659905*y**4*z + 194.042118494929*y**2*z**3 - 9.70210592474644*z**5) + x*(103.489129863962*y**6*z - 258.722824659905*y**4*z**3 + 97.0210592474644*y**2*z**5 - 3.23403530824881*z**7)", + "-13.5289403340579*x**7*y + x**5*(108.231522672464*y**3 - 40.5868210021738*y*z**2) + x**3*(-129.877827206956*y**5 + 216.463045344927*y**3*z**2 - 40.5868210021738*y*z**4) + x*(24.738633753706*y**7 - 129.877827206956*y**5*z**2 + 108.231522672464*y**3*z**4 - 13.5289403340579*y*z**6)", + "1.12741169450483*x**8 + x**6*(-36.0771742241545*y**2 + 4.50964677801932*z**2) + x**4*(108.231522672464*y**4 - 108.231522672464*y**2*z**2 + 6.76447016702898*z**4) + x**2*(-57.7234787586472*y**6 + 216.463045344927*y**4*z**2 - 108.231522672464*y**2*z**4 + 4.50964677801932*z**6) + 4.12310562561766*y**8 - 57.7234787586472*y**6*z**2 + 108.231522672464*y**4*z**4 - 36.0771742241545*y**2*z**6 + 1.12741169450483*z**8", + "-13.5289403340579*y*z**7 + z**5*(-40.5868210021738*x**2*y + 108.231522672464*y**3) + z**3*(-40.5868210021738*x**4*y + 216.463045344927*x**2*y**3 - 129.877827206956*y**5) + z*(-13.5289403340579*x**6*y + 108.231522672464*x**4*y**3 - 129.877827206956*x**2*y**5 + 24.738633753706*y**7)", + "1.61701765412441*x**8 + 3.23403530824881*x**6*z**2 - 3.23403530824881*x**2*z**6 + y**6*(-51.744564931981*x**2 + 51.744564931981*z**2) + y**4*(129.361412329953*x**4 - 129.361412329953*z**4) + y**2*(-48.5105296237322*x**6 - 48.5105296237322*x**4*z**2 + 48.5105296237322*x**2*z**4 + 
48.5105296237322*z**6) - 1.61701765412441*z**8", + "y**5*(-210.187416369296*x**2*z + 70.0624721230988*z**3) + y**3*(262.734270461621*x**4*z + 175.156180307747*x**2*z**3 - 87.5780901538735*z**5) + y*(-39.4101405692431*x**6*z - 65.6835676154052*x**4*z**3 - 13.136713523081*x**2*z**5 + 13.136713523081*z**7)", + "1.69594242329302*x**8 - 6.78376969317208*x**6*z**2 - 16.9594242329302*x**4*z**4 - 6.78376969317208*x**2*z**6 + y**4*(67.8376969317208*x**4 - 407.026181590325*x**2*z**2 + 67.8376969317208*z**4) + y**2*(-40.7026181590325*x**6 + 203.513090795162*x**4*z**2 + 203.513090795162*x**2*z**4 - 40.7026181590325*z**6) + 1.69594242329302*z**8", + "y**3*(244.592294696706*x**4*z - 489.184589393411*x**2*z**3 + 48.9184589393411*z**5) + y*(-61.1480736741764*x**6*z + 61.1480736741764*x**4*z**3 + 110.066532613517*x**2*z**5 - 12.2296147348353*z**7)", + "1.88707052233084*x**8 - 26.4189873126318*x**6*z**2 + 26.4189873126318*x**2*z**6 + y**2*(-26.4189873126318*x**6 + 396.284809689477*x**4*z**2 - 396.284809689477*x**2*z**4 + 26.4189873126318*z**6) - 1.88707052233084*z**8", + "y*(-72.3513764878561*x**6*z + 361.756882439281*x**4*z**3 - 217.054129463568*x**2*z**5 + 10.3359109268366*z**7)", + "2.58397773170915*x**8 - 72.3513764878561*x**6*z**2 + 180.87844121964*x**4*z**4 - 72.3513764878561*x**2*z**6 + 2.58397773170915*z**8" + ], + "bwd": { + "x": "g_0*(-144.702752975712*x**6*z + 723.513764878561*x**4*z**3 - 434.108258927137*x**2*z**5 + 20.6718218536732*z**7) + g_1*y*(-72.3513764878561*x**6 + 1085.27064731784*x**4*z**2 - 1085.27064731784*x**2*z**4 + 72.3513764878561*z**6) + g_10*(12.9361412329953*x**7 + 19.4042118494929*x**5*z**2 + 517.44564931981*x**3*y**4 - 103.489129863962*x*y**6 - 6.46807061649763*x*z**6 + y**2*(-291.063177742393*x**5 - 194.042118494929*x**3*z**2 + 97.0210592474644*x*z**4)) + g_11*(-420.374832738593*x*y**5*z + y**3*(1050.93708184648*x**3*z + 350.312360615494*x*z**3) + y*(-236.460843415458*x**5*z - 262.734270461621*x**3*z**3 - 26.2734270461621*x*z**5)) + 
g_12*(13.5675393863442*x**7 - 40.7026181590325*x**5*z**2 - 67.8376969317208*x**3*z**4 - 13.5675393863442*x*z**6 + y**4*(271.350787726883*x**3 - 814.05236318065*x*z**2) + y**2*(-244.215708954195*x**5 + 814.05236318065*x**3*z**2 + 407.026181590325*x*z**4)) + g_13*(y**3*(978.369178786822*x**3*z - 978.369178786822*x*z**3) + y*(-366.888442045058*x**5*z + 244.592294696705*x**3*z**3 + 220.133065227035*x*z**5)) + g_14*(15.0965641786467*x**7 - 158.513923875791*x**5*z**2 + 52.8379746252636*x*z**6 + y**2*(-158.513923875791*x**5 + 1585.13923875791*x**3*z**2 - 792.569619378954*x*z**4)) + g_15*y*(-434.108258927137*x**5*z + 1447.02752975712*x**3*z**3 - 434.108258927137*x*z**5) + g_16*(20.6718218536732*x**7 - 434.108258927137*x**5*z**2 + 723.513764878561*x**3*z**4 - 144.702752975712*x*z**6) + g_2*(-79.2569619378954*x**6*z + 5.0*x**4*(158.513923875791*y**2*z + 26.4189873126318*z**3) + 3.0*x**2*(-528.379746252636*y**2*z**3 + 26.4189873126318*z**5) + 158.513923875791*y**2*z**5 - 11.3224231339851*z**7) + g_3*(y**3*(244.592294696706*x**4 - 1467.55376818023*x**2*z**2 + 244.592294696706*z**4) + y*(-85.6073031438469*x**6 + 550.332663067587*x**4*z**2 + 183.444221022529*x**2*z**4 - 61.1480736741764*z**6)) + g_4*(-47.4863878522046*x**6*z + 5.0*x**4*(162.81047263613*y**2*z - 6.78376969317208*z**3) + 3.0*x**2*(-271.350787726883*y**4*z + 6.78376969317208*z**5) + 271.350787726883*y**4*z**3 - 162.81047263613*y**2*z**5 + 6.78376969317208*z**7) + g_5*(y**5*(-210.187416369296*x**2 + 210.187416369296*z**2) + y**3*(437.890450769368*x**4 - 525.468540923241*x**2*z**2 - 262.734270461621*z**4) + y*(-91.9569946615672*x**6 + 65.6835676154051*x**4*z**2 + 197.050702846215*x**2*z**4 + 39.4101405692431*z**6)) + g_6*(-22.6382471577417*x**6*z + 5.0*x**4*(97.0210592474644*y**2*z - 9.70210592474644*z**3) + 3.0*x**2*(-258.722824659905*y**4*z + 194.042118494929*y**2*z**3 - 9.70210592474644*z**5) + 103.489129863962*y**6*z - 258.722824659905*y**4*z**3 + 97.0210592474644*y**2*z**5 - 3.23403530824881*z**7) + 
g_7*(-94.7025823384056*x**6*y + 5.0*x**4*(108.231522672464*y**3 - 40.5868210021738*y*z**2) + 3.0*x**2*(-129.877827206956*y**5 + 216.463045344927*y**3*z**2 - 40.5868210021738*y*z**4) + 24.738633753706*y**7 - 129.877827206956*y**5*z**2 + 108.231522672464*y**3*z**4 - 13.5289403340579*y*z**6) + g_8*(9.01929355603863*x**7 + 6.0*x**5*(-36.0771742241545*y**2 + 4.50964677801932*z**2) + 4.0*x**3*(108.231522672464*y**4 - 108.231522672464*y**2*z**2 + 6.76447016702898*z**4) + 2.0*x*(-57.7234787586472*y**6 + 216.463045344927*y**4*z**2 - 108.231522672464*y**2*z**4 + 4.50964677801932*z**6)) + g_9*(-81.1736420043477*x*y*z**5 + z**3*(-162.347284008695*x**3*y + 432.926090689854*x*y**3) + z*(-81.1736420043477*x**5*y + 432.926090689854*x**3*y**3 - 259.755654413913*x*y**5))", + "y": "g_1*(-10.3359109268366*x**7 + 217.054129463568*x**5*z**2 - 361.756882439281*x**3*z**4 + 72.3513764878561*x*z**6) + g_10*(6.0*y**5*(-51.744564931981*x**2 + 51.744564931981*z**2) + 4.0*y**3*(129.361412329953*x**4 - 129.361412329953*z**4) + 2.0*y*(-48.5105296237322*x**6 - 48.5105296237322*x**4*z**2 + 48.5105296237322*x**2*z**4 + 48.5105296237322*z**6)) + g_11*(-39.4101405692431*x**6*z - 65.6835676154052*x**4*z**3 - 13.136713523081*x**2*z**5 + 5.0*y**4*(-210.187416369296*x**2*z + 70.0624721230988*z**3) + 3.0*y**2*(262.734270461621*x**4*z + 175.156180307747*x**2*z**3 - 87.5780901538735*z**5) + 13.136713523081*z**7) + g_12*(4.0*y**3*(67.8376969317208*x**4 - 407.026181590325*x**2*z**2 + 67.8376969317208*z**4) + 2.0*y*(-40.7026181590325*x**6 + 203.513090795162*x**4*z**2 + 203.513090795162*x**2*z**4 - 40.7026181590325*z**6)) + g_13*(-61.1480736741764*x**6*z + 61.1480736741764*x**4*z**3 + 110.066532613517*x**2*z**5 + 3.0*y**2*(244.592294696706*x**4*z - 489.184589393411*x**2*z**3 + 48.9184589393411*z**5) - 12.2296147348353*z**7) + 2.0*g_14*y*(-26.4189873126318*x**6 + 396.284809689477*x**4*z**2 - 396.284809689477*x**2*z**4 + 26.4189873126318*z**6) + g_15*(-72.3513764878561*x**6*z + 361.756882439281*x**4*z**3 - 
217.054129463568*x**2*z**5 + 10.3359109268366*z**7) + g_2*(317.027847751582*x**5*y*z - 1056.75949250527*x**3*y*z**3 + 317.027847751582*x*y*z**5) + g_3*(-12.2296147348353*x**7 + 110.066532613517*x**5*z**2 + 61.1480736741764*x**3*z**4 - 61.1480736741764*x*z**6 + 3.0*y**2*(48.9184589393411*x**5 - 489.184589393411*x**3*z**2 + 244.592294696706*x*z**4)) + g_4*(325.62094527226*x**5*y*z - 1085.40315090753*x**3*y**3*z + x*(1085.40315090753*y**3*z**3 - 325.62094527226*y*z**5)) + g_5*(-13.136713523081*x**7 + 13.136713523081*x**5*z**2 + 65.6835676154051*x**3*z**4 + 39.4101405692431*x*z**6 + 5.0*y**4*(-70.0624721230988*x**3 + 210.187416369296*x*z**2) + 3.0*y**2*(87.5780901538735*x**5 - 175.156180307747*x**3*z**2 - 262.734270461621*x*z**4)) + g_6*(194.042118494929*x**5*y*z + x**3*(-1034.89129863962*y**3*z + 388.084236989858*y*z**3) + x*(620.934779183772*y**5*z - 1034.89129863962*y**3*z**3 + 194.042118494929*y*z**5)) + g_7*(-13.5289403340579*x**7 + x**5*(324.694568017391*y**2 - 40.5868210021738*z**2) + x**3*(-649.389136034782*y**4 + 649.389136034782*y**2*z**2 - 40.5868210021738*z**4) + x*(173.170436275942*y**6 - 649.389136034782*y**4*z**2 + 324.694568017391*y**2*z**4 - 13.5289403340579*z**6)) + g_8*(-72.1543484483091*x**6*y + x**4*(432.926090689854*y**3 - 216.463045344927*y*z**2) + x**2*(-346.340872551883*y**5 + 865.852181379709*y**3*z**2 - 216.463045344927*y*z**4) + 32.9848450049413*y**7 - 346.340872551883*y**5*z**2 + 432.926090689854*y**3*z**4 - 72.1543484483091*y*z**6) + g_9*(-13.5289403340579*z**7 + z**5*(-40.5868210021738*x**2 + 324.694568017391*y**2) + z**3*(-40.5868210021738*x**4 + 649.389136034781*x**2*y**2 - 649.389136034782*y**4) + z*(-13.5289403340579*x**6 + 324.694568017391*x**4*y**2 - 649.389136034782*x**2*y**4 + 173.170436275942*y**6))", + "z": "g_0*(-20.6718218536732*x**7 + 434.108258927137*x**5*z**2 - 723.513764878561*x**3*z**4 + 144.702752975712*x*z**6) + g_1*y*(434.108258927137*x**5*z - 1447.02752975712*x**3*z**3 + 434.108258927137*x*z**5) + 
g_10*(6.46807061649763*x**6*z - 19.4042118494929*x**2*z**5 + 103.489129863962*y**6*z - 517.44564931981*y**4*z**3 + y**2*(-97.0210592474644*x**4*z + 194.042118494929*x**2*z**3 + 291.063177742393*z**5) - 12.9361412329953*z**7) + g_11*(y**5*(-210.187416369296*x**2 + 210.187416369296*z**2) + y**3*(262.734270461621*x**4 + 525.468540923241*x**2*z**2 - 437.890450769368*z**4) + y*(-39.4101405692431*x**6 - 197.050702846215*x**4*z**2 - 65.6835676154052*x**2*z**4 + 91.9569946615672*z**6)) + g_12*(-13.5675393863442*x**6*z - 67.8376969317208*x**4*z**3 - 40.7026181590325*x**2*z**5 + y**4*(-814.05236318065*x**2*z + 271.350787726883*z**3) + y**2*(407.026181590325*x**4*z + 814.05236318065*x**2*z**3 - 244.215708954195*z**5) + 13.5675393863442*z**7) + g_13*(y**3*(244.592294696706*x**4 - 1467.55376818023*x**2*z**2 + 244.592294696706*z**4) + y*(-61.1480736741764*x**6 + 183.444221022529*x**4*z**2 + 550.332663067587*x**2*z**4 - 85.6073031438469*z**6)) + g_14*(-52.8379746252636*x**6*z + 158.513923875791*x**2*z**5 + y**2*(792.569619378954*x**4*z - 1585.13923875791*x**2*z**3 + 158.513923875791*z**5) - 15.0965641786467*z**7) + g_15*y*(-72.3513764878561*x**6 + 1085.27064731784*x**4*z**2 - 1085.27064731784*x**2*z**4 + 72.3513764878561*z**6) + g_16*(-144.702752975712*x**6*z + 723.513764878561*x**4*z**3 - 434.108258927137*x**2*z**5 + 20.6718218536732*z**7) + g_2*(-11.3224231339851*x**7 + x**5*(158.513923875791*y**2 + 79.2569619378954*z**2) + x**3*(-1585.13923875791*y**2*z**2 + 132.094936563159*z**4) + x*(792.569619378954*y**2*z**4 - 79.2569619378954*z**6)) + g_3*(y**3*(-978.369178786822*x**3*z + 978.369178786822*x*z**3) + y*(220.133065227035*x**5*z + 244.592294696706*x**3*z**3 - 366.888442045058*x*z**5)) + g_4*(-6.78376969317208*x**7 + x**5*(162.81047263613*y**2 - 20.3513090795162*z**2) + x**3*(-271.350787726883*y**4 + 33.9188484658604*z**4) + x*(814.05236318065*y**4*z**2 - 814.05236318065*y**2*z**4 + 47.4863878522046*z**6)) + g_5*(420.374832738593*x*y**5*z + y**3*(-350.312360615494*x**3*z - 
1050.93708184648*x*z**3) + y*(26.2734270461621*x**5*z + 262.734270461621*x**3*z**3 + 236.460843415458*x*z**5)) + g_6*(-3.23403530824881*x**7 + x**5*(97.0210592474644*y**2 - 29.1063177742393*z**2) + x**3*(-258.722824659905*y**4 + 582.126355484786*y**2*z**2 - 48.5105296237322*z**4) + x*(103.489129863962*y**6 - 776.168473979715*y**4*z**2 + 485.105296237322*y**2*z**4 - 22.6382471577417*z**6)) + g_7*(-81.1736420043477*x**5*y*z + x**3*(432.926090689854*y**3*z - 162.347284008695*y*z**3) + x*(-259.755654413913*y**5*z + 432.926090689854*y**3*z**3 - 81.1736420043477*y*z**5)) + g_8*(9.01929355603863*x**6*z + x**4*(-216.463045344927*y**2*z + 27.0578806681159*z**3) + x**2*(432.926090689854*y**4*z - 432.926090689854*y**2*z**3 + 27.0578806681159*z**5) - 115.446957517294*y**6*z + 432.926090689854*y**4*z**3 - 216.463045344927*y**2*z**5 + 9.01929355603863*z**7) + g_9*(-13.5289403340579*x**6*y + 108.231522672464*x**4*y**3 - 129.877827206956*x**2*y**5 + 24.738633753706*y**7 - 94.7025823384056*y*z**6 + 5.0*z**4*(-40.5868210021738*x**2*y + 108.231522672464*y**3) + 3.0*z**2*(-40.5868210021738*x**4*y + 216.463045344927*x**2*y**3 - 129.877827206956*y**5))" + } +} \ No newline at end of file diff --git a/notebooks/direct_sph_harm/l_9.json b/notebooks/direct_sph_harm/l_9.json new file mode 100644 index 0000000..a3ca649 --- /dev/null +++ b/notebooks/direct_sph_harm/l_9.json @@ -0,0 +1,28 @@ +{ + "fwd": [ + "2.65478475211798*x**9 - 95.5722510762473*x**7*z**2 + 334.502878766866*x**5*z**4 - 223.00191917791*x**3*z**6 + 23.8930627690618*x*z**8", + "y*(-90.106382439037*x**7*z + 630.744677073259*x**5*z**3 - 630.744677073259*x**3*z**5 + 90.106382439037*x*z**7)", + "1.93163963757558*x**9 + x**7*(-30.9062342012093*y**2 - 38.6327927515116*z**2) + x**5*(649.030918225395*y**2*z**2 + 27.0429549260581*z**4) + x**3*(-1081.71819704233*y**2*z**4 + 54.0859098521163*z**6) + x*(216.343639408465*y**2*z**6 - 13.5214774630291*z**8)", + "y**3*(374.718175349822*x**5*z - 1249.06058449941*x**3*z**3 + 
374.718175349822*x*z**5) + y*(-80.2967518606762*x**7*z + 187.359087674911*x**5*z**3 + 187.359087674911*x**3*z**5 - 80.2967518606762*x*z**7)", + "1.72771101506082*x**9 - 13.8216881204866*x**7*z**2 - 24.1879542108515*x**5*z**4 + 8.63855507530412*x*z**8 + y**4*(96.7518168434061*x**5 - 967.518168434061*x**3*z**2 + 483.759084217031*x*z**4) + y**2*(-48.3759084217031*x**7 + 435.383175795327*x**5*z**2 + 241.879542108515*x**3*z**4 - 241.879542108515*x*z**6)", + "y**5*(-462.562157985281*x**3*z + 462.562157985281*x*z**3) + y**3*(462.562157985281*x**5*z - 462.562157985281*x*z**5) + y*(-57.8202697481601*x**7*z - 57.8202697481601*x**5*z**3 + 57.8202697481601*x**3*z**5 + 57.8202697481601*x*z**7)", + "1.63671408859718*x**9 - 58.9217071894985*x**7*y**2 + x**5*(196.405690631662*y**4 + 58.9217071894985*y**2*z**2 - 9.82028453158308*z**4) + x**3*(-104.74970167022*y**6 - 392.811381263323*y**4*z**2 + 294.608535947493*y**2*z**4 - 13.0937127087774*z**6) + x*(314.249105010659*y**6*z**2 - 589.217071894985*y**4*z**4 + 176.765121568496*y**2*z**6 - 4.91014226579154*z**8)", + "-30.001464807989*x**7*y*z + x**5*(300.01464807989*y**3*z - 90.0043944239669*y*z**3) + x**3*(-480.023436927823*y**5*z + 600.029296159779*y**3*z**3 - 90.0043944239669*y*z**5) + x*(137.14955340795*y**7*z - 480.023436927823*y**5*z**3 + 300.01464807989*y**3*z**5 - 30.001464807989*y*z**7)", + "1.59908344719522*x**9 + x**7*(-63.9633378878088*y**2 + 6.39633378878088*z**2) + x**5*(255.853351551235*y**4 - 191.890013663426*y**2*z**2 + 9.59450068317133*z**4) + x**3*(-204.682681240988*y**6 + 511.706703102471*y**4*z**2 - 191.890013663426*y**2*z**4 + 6.39633378878088*z**6) + x*(29.2403830344269*y**8 - 204.682681240988*y**6*z**2 + 255.853351551235*y**4*z**4 - 63.9633378878088*y**2*z**6 + 1.59908344719522*z**8)", + "4.35889894354067*y**9 + y**7*(-78.4601809837321*x**2 - 78.4601809837321*z**2) + y**5*(205.957975082297*x**4 + 411.915950164594*x**2*z**2 + 205.957975082297*z**4) + y**3*(-114.421097267943*x**6 - 343.263291803828*x**4*z**2 - 
343.263291803828*x**2*z**4 - 114.421097267943*z**6) + y*(10.7269778688696*x**8 + 42.9079114754785*x**6*z**2 + 64.3618672132178*x**4*z**4 + 42.9079114754785*x**2*z**6 + 10.7269778688696*z**8)", + "1.59908344719522*z**9 + z**7*(6.39633378878088*x**2 - 63.9633378878088*y**2) + z**5*(9.59450068317133*x**4 - 191.890013663427*x**2*y**2 + 255.853351551235*y**4) + z**3*(6.39633378878088*x**6 - 191.890013663426*x**4*y**2 + 511.706703102471*x**2*y**4 - 204.682681240988*y**6) + z*(1.59908344719522*x**8 - 63.9633378878088*x**6*y**2 + 255.853351551235*x**4*y**4 - 204.682681240988*x**2*y**6 + 29.2403830344269*y**8)", + "y**7*(-68.5747767039748*x**2 + 68.5747767039748*z**2) + y**5*(240.011718463912*x**4 - 240.011718463912*z**4) + y**3*(-150.007324039945*x**6 - 150.007324039945*x**4*z**2 + 150.007324039945*x**2*z**4 + 150.007324039945*z**6) + y*(15.0007324039945*x**8 + 30.001464807989*x**6*z**2 - 30.001464807989*x**2*z**6 - 15.0007324039945*z**8)", + "58.9217071894985*y**2*z**7 - 1.63671408859718*z**9 + z**5*(9.82028453158308*x**4 - 58.9217071894985*x**2*y**2 - 196.405690631662*y**4) + z**3*(13.0937127087774*x**6 - 294.608535947493*x**4*y**2 + 392.811381263323*x**2*y**4 + 104.74970167022*y**6) + z*(4.91014226579154*x**8 - 176.765121568496*x**6*y**2 + 589.217071894985*x**4*y**4 - 314.249105010659*x**2*y**6)", + "y**5*(115.64053949632*x**4 - 693.843236977922*x**2*z**2 + 115.64053949632*z**4) + y**3*(-115.64053949632*x**6 + 578.202697481601*x**4*z**2 + 578.202697481601*x**2*z**4 - 115.64053949632*z**6) + y*(14.45506743704*x**8 - 57.8202697481601*x**6*z**2 - 144.5506743704*x**4*z**4 - 57.8202697481601*x**2*z**6 + 14.45506743704*z**8)", + "8.63855507530412*x**8*z - 24.1879542108515*x**4*z**5 - 13.8216881204866*x**2*z**7 + y**4*(483.759084217031*x**4*z - 967.518168434061*x**2*z**3 + 96.7518168434061*z**5) + y**2*(-241.879542108515*x**6*z + 241.879542108515*x**4*z**3 + 435.383175795328*x**2*z**5 - 48.375908421703*z**7) + 1.72771101506082*z**9", + "y**3*(-62.4530292249704*x**6 + 
936.795438374555*x**4*z**2 - 936.795438374555*x**2*z**4 + 62.4530292249704*z**6) + y*(13.3827919767794*x**8 - 187.359087674911*x**6*z**2 + 187.359087674911*x**2*z**6 - 13.3827919767794*z**8)", + "13.5214774630291*x**8*z - 54.0859098521163*x**6*z**3 - 27.0429549260581*x**4*z**5 + 38.6327927515116*x**2*z**7 + y**2*(-216.343639408465*x**6*z + 1081.71819704233*x**4*z**3 - 649.030918225395*x**2*z**5 + 30.9062342012093*z**7) - 1.93163963757558*z**9", + "y*(11.2632978048796*x**8 - 315.37233853663*x**6*z**2 + 788.430846341574*x**4*z**4 - 315.37233853663*x**2*z**6 + 11.2632978048796*z**8)", + "23.8930627690618*x**8*z - 223.00191917791*x**6*z**3 + 334.502878766866*x**4*z**5 - 95.5722510762473*x**2*z**7 + 2.65478475211798*z**9" + ], + "bwd": { + "x": "g_0*(23.8930627690618*x**8 - 669.005757533731*x**6*z**2 + 1672.51439383433*x**4*z**4 - 669.005757533731*x**2*z**6 + 23.8930627690618*z**8) + g_1*y*(-630.744677073259*x**6*z + 3153.7233853663*x**4*z**3 - 1892.23403121978*x**2*z**5 + 90.106382439037*z**7) + g_10*(12.7926675775618*x*z**7 + z**5*(38.3780027326853*x**3 - 383.780027326853*x*y**2) + z**3*(38.3780027326853*x**5 - 767.560054653706*x**3*y**2 + 1023.41340620494*x*y**4) + z*(12.7926675775618*x**7 - 383.780027326853*x**5*y**2 + 1023.41340620494*x**3*y**4 - 409.365362481977*x*y**6)) + g_11*(960.046873855647*x**3*y**5 - 137.14955340795*x*y**7 + y**3*(-900.043944239669*x**5 - 600.029296159779*x**3*z**2 + 300.01464807989*x*z**4) + y*(120.005859231956*x**7 + 180.008788847934*x**5*z**2 - 60.0029296159779*x*z**6)) + g_12*(z**5*(39.2811381263323*x**3 - 117.843414378997*x*y**2) + z**3*(78.5622762526647*x**5 - 1178.43414378997*x**3*y**2 + 785.622762526647*x*y**4) + z*(39.2811381263323*x**7 - 1060.59072941097*x**5*y**2 + 2356.86828757994*x**3*y**4 - 628.498210021318*x*y**6)) + g_13*(y**5*(462.562157985281*x**3 - 1387.68647395584*x*z**2) + y**3*(-693.843236977922*x**5 + 2312.81078992641*x**3*z**2 + 1156.4053949632*x*z**4) + y*(115.64053949632*x**7 - 346.921618488961*x**5*z**2 - 
578.202697481601*x**3*z**4 - 115.64053949632*x*z**6)) + g_14*(69.1084406024329*x**7*z - 96.7518168434061*x**3*z**5 - 27.6433762409732*x*z**7 + y**4*(1935.03633686812*x**3*z - 1935.03633686812*x*z**3) + y**2*(-1451.27725265109*x**5*z + 967.518168434061*x**3*z**3 + 870.766351590655*x*z**5)) + g_15*(y**3*(-374.718175349822*x**5 + 3747.18175349822*x**3*z**2 - 1873.59087674911*x*z**4) + y*(107.062335814235*x**7 - 1124.15452604947*x**5*z**2 + 374.718175349822*x*z**6)) + g_16*(108.171819704233*x**7*z - 324.515459112698*x**5*z**3 - 108.171819704233*x**3*z**5 + 77.2655855030233*x*z**7 + y**2*(-1298.06183645079*x**5*z + 4326.8727881693*x**3*z**3 - 1298.06183645079*x*z**5)) + g_17*y*(90.106382439037*x**7 - 1892.23403121978*x**5*z**2 + 3153.7233853663*x**3*z**4 - 630.744677073259*x*z**6) + g_18*(191.144502152495*x**7*z - 1338.01151506746*x**5*z**3 + 1338.01151506746*x**3*z**5 - 191.144502152495*x*z**7) + g_2*(17.3847567381802*x**8 + 7.0*x**6*(-30.9062342012093*y**2 - 38.6327927515116*z**2) + 5.0*x**4*(649.030918225395*y**2*z**2 + 27.0429549260581*z**4) + 3.0*x**2*(-1081.71819704233*y**2*z**4 + 54.0859098521163*z**6) + 216.343639408465*y**2*z**6 - 13.5214774630291*z**8) + g_3*(y**3*(1873.59087674911*x**4*z - 3747.18175349822*x**2*z**3 + 374.718175349822*z**5) + y*(-562.077263024733*x**6*z + 936.795438374555*x**4*z**3 + 562.077263024733*x**2*z**5 - 80.2967518606762*z**7)) + g_4*(15.5493991355474*x**8 - 96.7518168434061*x**6*z**2 - 120.939771054258*x**4*z**4 + y**4*(483.759084217031*x**4 - 2902.55450530218*x**2*z**2 + 483.759084217031*z**4) + y**2*(-338.631358951921*x**6 + 2176.91587897664*x**4*z**2 + 725.638626325546*x**2*z**4 - 241.879542108515*z**6) + 8.63855507530412*z**8) + g_5*(y**5*(-1387.68647395584*x**2*z + 462.562157985281*z**3) + y**3*(2312.81078992641*x**4*z - 462.562157985281*z**5) + y*(-404.741888237121*x**6*z - 289.101348740801*x**4*z**3 + 173.46080924448*x**2*z**5 + 57.8202697481601*z**7)) + g_6*(14.7304267973746*x**8 - 412.45195032649*x**6*y**2 + 
5.0*x**4*(196.405690631662*y**4 + 58.9217071894985*y**2*z**2 - 9.82028453158308*z**4) + 3.0*x**2*(-104.74970167022*y**6 - 392.811381263323*y**4*z**2 + 294.608535947493*y**2*z**4 - 13.0937127087774*z**6) + 314.249105010659*y**6*z**2 - 589.217071894985*y**4*z**4 + 176.765121568496*y**2*z**6 - 4.91014226579154*z**8) + g_7*(-210.010253655923*x**6*y*z + 5.0*x**4*(300.01464807989*y**3*z - 90.0043944239669*y*z**3) + 3.0*x**2*(-480.023436927823*y**5*z + 600.029296159779*y**3*z**3 - 90.0043944239669*y*z**5) + 137.14955340795*y**7*z - 480.023436927823*y**5*z**3 + 300.01464807989*y**3*z**5 - 30.001464807989*y*z**7) + g_8*(14.391751024757*x**8 + 7.0*x**6*(-63.9633378878088*y**2 + 6.39633378878088*z**2) + 5.0*x**4*(255.853351551235*y**4 - 191.890013663426*y**2*z**2 + 9.59450068317133*z**4) + 3.0*x**2*(-204.682681240988*y**6 + 511.706703102471*y**4*z**2 - 191.890013663426*y**2*z**4 + 6.39633378878088*z**6) + 29.2403830344269*y**8 - 204.682681240988*y**6*z**2 + 255.853351551235*y**4*z**4 - 63.9633378878088*y**2*z**6 + 1.59908344719522*z**8) + g_9*(-156.920361967464*x*y**7 + y**5*(823.831900329187*x**3 + 823.831900329187*x*z**2) + y**3*(-686.526583607656*x**5 - 1373.05316721531*x**3*z**2 - 686.526583607656*x*z**4) + y*(85.815822950957*x**7 + 257.447468852871*x**5*z**2 + 257.447468852871*x**3*z**4 + 85.815822950957*x*z**6))", + "y": "g_1*(-90.106382439037*x**7*z + 630.744677073259*x**5*z**3 - 630.744677073259*x**3*z**5 + 90.106382439037*x*z**7) + g_10*(-127.926675775618*y*z**7 + z**5*(-383.780027326853*x**2*y + 1023.41340620494*y**3) + z**3*(-383.780027326853*x**4*y + 2046.82681240988*x**2*y**3 - 1228.09608744593*y**5) + z*(-127.926675775618*x**6*y + 1023.41340620494*x**4*y**3 - 1228.09608744593*x**2*y**5 + 233.923064275415*y**7)) + g_11*(15.0007324039945*x**8 + 30.001464807989*x**6*z**2 - 30.001464807989*x**2*z**6 + 7.0*y**6*(-68.5747767039748*x**2 + 68.5747767039748*z**2) + 5.0*y**4*(240.011718463912*x**4 - 240.011718463912*z**4) + 3.0*y**2*(-150.007324039945*x**6 - 
150.007324039945*x**4*z**2 + 150.007324039945*x**2*z**4 + 150.007324039945*z**6) - 15.0007324039945*z**8) + g_12*(117.843414378997*y*z**7 + z**5*(-117.843414378997*x**2*y - 785.622762526647*y**3) + z**3*(-589.217071894985*x**4*y + 1571.24552505329*x**2*y**3 + 628.498210021317*y**5) + z*(-353.530243136991*x**6*y + 2356.86828757994*x**4*y**3 - 1885.49463006395*x**2*y**5)) + g_13*(14.45506743704*x**8 - 57.8202697481601*x**6*z**2 - 144.5506743704*x**4*z**4 - 57.8202697481601*x**2*z**6 + 5.0*y**4*(115.64053949632*x**4 - 693.843236977922*x**2*z**2 + 115.64053949632*z**4) + 3.0*y**2*(-115.64053949632*x**6 + 578.202697481601*x**4*z**2 + 578.202697481601*x**2*z**4 - 115.64053949632*z**6) + 14.45506743704*z**8) + g_14*(4.0*y**3*(483.759084217031*x**4*z - 967.518168434061*x**2*z**3 + 96.7518168434061*z**5) + 2.0*y*(-241.879542108515*x**6*z + 241.879542108515*x**4*z**3 + 435.383175795328*x**2*z**5 - 48.375908421703*z**7)) + g_15*(13.3827919767794*x**8 - 187.359087674911*x**6*z**2 + 187.359087674911*x**2*z**6 + 3.0*y**2*(-62.4530292249704*x**6 + 936.795438374555*x**4*z**2 - 936.795438374555*x**2*z**4 + 62.4530292249704*z**6) - 13.3827919767794*z**8) + 2.0*g_16*y*(-216.343639408465*x**6*z + 1081.71819704233*x**4*z**3 - 649.030918225395*x**2*z**5 + 30.9062342012093*z**7) + g_17*(11.2632978048796*x**8 - 315.37233853663*x**6*z**2 + 788.430846341574*x**4*z**4 - 315.37233853663*x**2*z**6 + 11.2632978048796*z**8) + g_2*(-61.8124684024186*x**7*y + 1298.06183645079*x**5*y*z**2 - 2163.43639408465*x**3*y*z**4 + 432.68727881693*x*y*z**6) + g_3*(-80.2967518606762*x**7*z + 187.359087674911*x**5*z**3 + 187.359087674911*x**3*z**5 - 80.2967518606762*x*z**7 + 3.0*y**2*(374.718175349822*x**5*z - 1249.06058449941*x**3*z**3 + 374.718175349822*x*z**5)) + g_4*(4.0*y**3*(96.7518168434061*x**5 - 967.518168434061*x**3*z**2 + 483.759084217031*x*z**4) + 2.0*y*(-48.3759084217031*x**7 + 435.383175795327*x**5*z**2 + 241.879542108515*x**3*z**4 - 241.879542108515*x*z**6)) + g_5*(-57.8202697481601*x**7*z - 
57.8202697481601*x**5*z**3 + 57.8202697481601*x**3*z**5 + 57.8202697481601*x*z**7 + 5.0*y**4*(-462.562157985281*x**3*z + 462.562157985281*x*z**3) + 3.0*y**2*(462.562157985281*x**5*z - 462.562157985281*x*z**5)) + g_6*(-117.843414378997*x**7*y + x**5*(785.622762526647*y**3 + 117.843414378997*y*z**2) + x**3*(-628.498210021317*y**5 - 1571.24552505329*y**3*z**2 + 589.217071894985*y*z**4) + x*(1885.49463006395*y**5*z**2 - 2356.86828757994*y**3*z**4 + 353.530243136991*y*z**6)) + g_7*(-30.001464807989*x**7*z + x**5*(900.043944239669*y**2*z - 90.0043944239669*z**3) + x**3*(-2400.11718463912*y**4*z + 1800.08788847934*y**2*z**3 - 90.0043944239669*z**5) + x*(960.046873855647*y**6*z - 2400.11718463912*y**4*z**3 + 900.043944239669*y**2*z**5 - 30.001464807989*z**7)) + g_8*(-127.926675775618*x**7*y + x**5*(1023.41340620494*y**3 - 383.780027326853*y*z**2) + x**3*(-1228.09608744593*y**5 + 2046.82681240988*y**3*z**2 - 383.780027326853*y*z**4) + x*(233.923064275415*y**7 - 1228.09608744593*y**5*z**2 + 1023.41340620494*y**3*z**4 - 127.926675775618*y*z**6)) + g_9*(10.7269778688696*x**8 + 42.9079114754785*x**6*z**2 + 64.3618672132178*x**4*z**4 + 42.9079114754785*x**2*z**6 + 39.2300904918661*y**8 + 7.0*y**6*(-78.4601809837321*x**2 - 78.4601809837321*z**2) + 5.0*y**4*(205.957975082297*x**4 + 411.915950164594*x**2*z**2 + 205.957975082297*z**4) + 3.0*y**2*(-114.421097267943*x**6 - 343.263291803828*x**4*z**2 - 343.263291803828*x**2*z**4 - 114.421097267943*z**6) + 10.7269778688696*z**8)", + "z": "g_0*(-191.144502152495*x**7*z + 1338.01151506746*x**5*z**3 - 1338.01151506746*x**3*z**5 + 191.144502152495*x*z**7) + g_1*y*(-90.106382439037*x**7 + 1892.23403121978*x**5*z**2 - 3153.7233853663*x**3*z**4 + 630.744677073259*x*z**6) + g_10*(1.59908344719522*x**8 - 63.9633378878088*x**6*y**2 + 255.853351551235*x**4*y**4 - 204.682681240988*x**2*y**6 + 29.2403830344269*y**8 + 14.391751024757*z**8 + 7.0*z**6*(6.39633378878088*x**2 - 63.9633378878088*y**2) + 5.0*z**4*(9.59450068317133*x**4 - 
191.890013663427*x**2*y**2 + 255.853351551235*y**4) + 3.0*z**2*(6.39633378878088*x**6 - 191.890013663426*x**4*y**2 + 511.706703102471*x**2*y**4 - 204.682681240988*y**6)) + g_11*(137.14955340795*y**7*z - 960.046873855647*y**5*z**3 + y**3*(-300.01464807989*x**4*z + 600.029296159779*x**2*z**3 + 900.043944239669*z**5) + y*(60.0029296159779*x**6*z - 180.008788847934*x**2*z**5 - 120.005859231956*z**7)) + g_12*(4.91014226579154*x**8 - 176.765121568496*x**6*y**2 + 589.217071894985*x**4*y**4 - 314.249105010659*x**2*y**6 + 412.45195032649*y**2*z**6 - 14.7304267973746*z**8 + 5.0*z**4*(9.82028453158308*x**4 - 58.9217071894985*x**2*y**2 - 196.405690631662*y**4) + 3.0*z**2*(13.0937127087774*x**6 - 294.608535947493*x**4*y**2 + 392.811381263323*x**2*y**4 + 104.74970167022*y**6)) + g_13*(y**5*(-1387.68647395584*x**2*z + 462.562157985281*z**3) + y**3*(1156.4053949632*x**4*z + 2312.81078992641*x**2*z**3 - 693.843236977921*z**5) + y*(-115.64053949632*x**6*z - 578.202697481601*x**4*z**3 - 346.921618488961*x**2*z**5 + 115.64053949632*z**7)) + g_14*(8.63855507530412*x**8 - 120.939771054258*x**4*z**4 - 96.7518168434061*x**2*z**6 + y**4*(483.759084217031*x**4 - 2902.55450530218*x**2*z**2 + 483.759084217031*z**4) + y**2*(-241.879542108515*x**6 + 725.638626325546*x**4*z**2 + 2176.91587897664*x**2*z**4 - 338.631358951921*z**6) + 15.5493991355474*z**8) + g_15*(y**3*(1873.59087674911*x**4*z - 3747.18175349822*x**2*z**3 + 374.718175349822*z**5) + y*(-374.718175349822*x**6*z + 1124.15452604947*x**2*z**5 - 107.062335814235*z**7)) + g_16*(13.5214774630291*x**8 - 162.257729556349*x**6*z**2 - 135.214774630291*x**4*z**4 + 270.429549260581*x**2*z**6 + y**2*(-216.343639408465*x**6 + 3245.15459112698*x**4*z**2 - 3245.15459112698*x**2*z**4 + 216.343639408465*z**6) - 17.3847567381802*z**8) + g_17*y*(-630.744677073259*x**6*z + 3153.7233853663*x**4*z**3 - 1892.23403121978*x**2*z**5 + 90.106382439037*z**7) + g_18*(23.8930627690618*x**8 - 669.005757533731*x**6*z**2 + 1672.51439383433*x**4*z**4 - 
669.005757533731*x**2*z**6 + 23.8930627690618*z**8) + g_2*(-77.2655855030233*x**7*z + x**5*(1298.06183645079*y**2*z + 108.171819704233*z**3) + x**3*(-4326.8727881693*y**2*z**3 + 324.515459112698*z**5) + x*(1298.06183645079*y**2*z**5 - 108.171819704233*z**7)) + g_3*(y**3*(374.718175349822*x**5 - 3747.18175349822*x**3*z**2 + 1873.59087674911*x*z**4) + y*(-80.2967518606762*x**7 + 562.077263024733*x**5*z**2 + 936.795438374555*x**3*z**4 - 562.077263024733*x*z**6)) + g_4*(-27.6433762409732*x**7*z - 96.7518168434061*x**5*z**3 + 69.1084406024329*x*z**7 + y**4*(-1935.03633686812*x**3*z + 1935.03633686812*x*z**3) + y**2*(870.766351590655*x**5*z + 967.518168434061*x**3*z**3 - 1451.27725265109*x*z**5)) + g_5*(y**5*(-462.562157985281*x**3 + 1387.68647395584*x*z**2) + y**3*(462.562157985281*x**5 - 2312.81078992641*x*z**4) + y*(-57.8202697481601*x**7 - 173.46080924448*x**5*z**2 + 289.101348740801*x**3*z**4 + 404.741888237121*x*z**6)) + g_6*(x**5*(117.843414378997*y**2*z - 39.2811381263323*z**3) + x**3*(-785.622762526647*y**4*z + 1178.43414378997*y**2*z**3 - 78.5622762526647*z**5) + x*(628.498210021318*y**6*z - 2356.86828757994*y**4*z**3 + 1060.59072941097*y**2*z**5 - 39.2811381263323*z**7)) + g_7*(-30.001464807989*x**7*y + x**5*(300.01464807989*y**3 - 270.013183271901*y*z**2) + x**3*(-480.023436927823*y**5 + 1800.08788847934*y**3*z**2 - 450.021972119834*y*z**4) + x*(137.14955340795*y**7 - 1440.07031078347*y**5*z**2 + 1500.07324039945*y**3*z**4 - 210.010253655923*y*z**6)) + g_8*(12.7926675775618*x**7*z + x**5*(-383.780027326853*y**2*z + 38.3780027326853*z**3) + x**3*(1023.41340620494*y**4*z - 767.560054653706*y**2*z**3 + 38.3780027326853*z**5) + x*(-409.365362481976*y**6*z + 1023.41340620494*y**4*z**3 - 383.780027326853*y**2*z**5 + 12.7926675775618*z**7)) + g_9*(-156.920361967464*y**7*z + y**5*(823.831900329187*x**2*z + 823.831900329187*z**3) + y**3*(-686.526583607656*x**4*z - 1373.05316721531*x**2*z**3 - 686.526583607656*z**5) + y*(85.815822950957*x**6*z + 
257.447468852871*x**4*z**3 + 257.447468852871*x**2*z**5 + 85.815822950957*z**7))" + } +} \ No newline at end of file diff --git a/notebooks/fwd_implementations/fwd_10.py b/notebooks/fwd_implementations/fwd_10.py new file mode 100644 index 0000000..fc2cb74 --- /dev/null +++ b/notebooks/fwd_implementations/fwd_10.py @@ -0,0 +1,235 @@ +# -------------------- variable and constant definitions +CONST000 = 1.60956935264578 +CONST001 = 1.75869118663323 +CONST002 = -1021.92317475320 +CONST003 = 3.21913870529156 +CONST004 = 4.58257569495584 +CONST005 = 6.63243980843400 +CONST006 = 4.82870805793735 +CONST007 = 4.97432985632550 +CONST008 = 1545.18657853995 +CONST009 = 10.5521471197994 +CONST010 = 12.1657520803952 +CONST011 = 13.2648796168680 +CONST012 = 14.5025390506634 +CONST013 = 15.7883647328499 +CONST014 = 15.7302121789667 +CONST015 = 16.4144510752435 +CONST016 = 12.8765548211663 +CONST017 = 19.3148322317494 +CONST018 = 16.7271353825295 +CONST019 = 22.8629854262320 +CONST020 = 535.268332240943 +CONST021 = 23.2135393295190 +CONST022 = 24.6216766128653 +CONST023 = 27.2034486491732 +CONST024 = 541.428124558099 +CONST025 = -994.666978169547 +CONST026 = 33.9852909359329 +CONST027 = 33.9852909359329 +CONST028 = 35.5238206489124 +CONST029 = -984.867064514610 +CONST030 = -4.82870805793735 +CONST031 = 1070.53666448189 +CONST032 = -463.555973561985 +CONST033 = 49.2433532257305 +CONST034 = 53.2857309733686 +CONST035 = 53.2857309733686 +CONST036 = 56.3871618715269 +CONST037 = 56.3871618715269 +CONST038 = 56.2781179722634 +CONST039 = -1989.33395633909 +CONST040 = 571.272421632637 +CONST041 = -450.224943778107 +CONST042 = 66.9085415301178 +CONST043 = 69.6406179885570 +CONST044 = 69.6406179885570 +CONST045 = -437.967074894228 +CONST046 = 77.2593289269976 +CONST047 = 78.6510608948335 +CONST048 = 590.920238708766 +CONST049 = -1969.73412902922 +CONST050 = 77.3468749368712 +CONST051 = 1624.28437367430 +CONST052 = 88.2963759165686 +CONST053 = 1114.24988781691 +CONST054 = 94.7301883970997 
+CONST055 = 98.4867064514610 +CONST056 = 100.362812295177 +CONST057 = -412.049754277320 +CONST058 = 101.517773354644 +CONST059 = -5.63871618715269 +CONST060 = -406.071093418574 +CONST061 = 109.491768723557 +CONST062 = -393.946825805844 +CONST063 = -902.194589944431 +CONST064 = 122.415518921279 +CONST065 = -386.296644634988 +CONST066 = -386.296644634988 +CONST067 = 131.315608601948 +CONST068 = 131.315608601948 +CONST069 = 2707.14062279049 +CONST070 = 4.97432985632550 +CONST071 = 150.074981259369 +CONST072 = 154.518657853995 +CONST073 = 1181.84047741753 +CONST074 = 685.526905959165 +CONST075 = -337.668707833581 +CONST076 = -337.668707833581 +CONST077 = 176.178376404427 +CONST078 = 176.592751833137 +CONST079 = 185.708314636152 +CONST080 = -326.441383790078 +CONST081 = -1.60956935264578 +CONST082 = -1.97354559160624 +CONST083 = 196.973412902922 +CONST084 = 196.973412902922 +CONST085 = -824.099508554641 +CONST086 = 203.035546709287 +CONST087 = -1.97354559160624 +CONST088 = -305.867618423396 +CONST089 = -305.867618423396 +CONST090 = 721.755671955545 +CONST091 = -305.867618423396 +CONST092 = -300.731529981477 +CONST093 = -300.731529981477 +CONST094 = -1.75869118663323 +CONST095 = -290.050781013267 +CONST096 = 734.076568351780 +CONST097 = 225.548647486108 +CONST098 = 225.548647486108 +CONST099 = -284.190565191299 +CONST100 = 742.833258544608 +CONST101 = -278.562471954228 +CONST102 = -278.562471954228 +CONST103 = -787.893651611688 +CONST104 = -787.893651611688 +CONST105 = 772.593289269975 +CONST106 = 787.893651611688 +CONST107 = 787.893651611688 +CONST108 = 278.562471954228 +CONST109 = -742.833258544608 +CONST110 = -1.65810995210850 +CONST111 = 284.190565191299 +CONST112 = -1761.78376404427 +CONST113 = -223.028471767059 +CONST114 = -734.076568351780 +CONST115 = 290.050781013267 +CONST116 = -220.222970505534 +CONST117 = 1321.33782303320 +CONST118 = 1321.33782303320 +CONST119 = -203.035546709287 +CONST120 = -1.65810995210850 +CONST121 = -196.973412902922 +CONST122 = 
-196.973412902922 +CONST123 = -696.406179885570 +CONST124 = 2.72034486491732 +CONST125 = 338.322971229162 +CONST126 = -1181.84047741753 +CONST127 = -669.085415301178 +CONST128 = -669.085415301178 +CONST129 = -154.518657853995 +CONST130 = -154.518657853995 +CONST131 = 360.877835977772 +CONST132 = -150.074981259369 +CONST133 = -2707.14062279049 +CONST134 = -146.815313670356 +CONST135 = 880.891882022136 +CONST136 = 1392.81235977114 +CONST137 = 1392.81235977114 +CONST138 = -131.315608601948 +CONST139 = -131.315608601948 +CONST140 = 386.296644634988 +CONST141 = -125.841697431734 +CONST142 = -125.841697431734 +CONST143 = -122.415518921279 +CONST144 = 393.946825805844 +CONST145 = 406.071093418574 +CONST146 = -103.107953136506 +CONST147 = -103.107953136506 +CONST148 = -101.517773354644 +CONST149 = -98.4867064514610 +CONST150 = 412.049754277320 +CONST151 = -94.7301883970997 +CONST152 = -1114.24988781691 +CONST153 = -88.2963759165686 +CONST154 = -1624.28437367430 +CONST155 = -82.8889148474622 +CONST156 = -82.8889148474622 +CONST157 = 1969.73412902922 +CONST158 = -590.920238708766 +CONST159 = -77.3468749368713 +CONST160 = -77.2593289269975 +CONST161 = 2486.66744542387 +CONST162 = -2626.31217203896 +CONST163 = 450.224943778107 +CONST164 = 1989.33395633909 +CONST165 = -571.272421632637 +CONST166 = -56.2781179722634 +CONST167 = -49.2433532257305 +CONST168 = -49.2433532257305 +CONST169 = 984.867064514610 +CONST170 = -541.428124558099 +CONST171 = -24.6216766128653 +CONST172 = -22.8629854262320 +CONST173 = -16.4144510752435 +CONST174 = -15.7883647328499 +CONST175 = -14.0695294930659 +CONST176 = -13.2648796168680 +CONST177 = -11.2774323743054 +CONST178 = -14.5025390506634 +CONST179 = -6.63243980843400 +CONST180 = -5.63871618715269 +CONST181 = 1532.88476212980 +CONST182 = -3.21913870529156 +CONST183 = -2.72034486491732 +CONST184 = -1.12774323743054 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = 
y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +Y00 = CONST023*VAR01*z + CONST023*VAR19*x + CONST074*VAR05*VAR23 + CONST080*VAR03*VAR25 + CONST080*VAR07*VAR21 +Y01 = y*(CONST002*VAR07*VAR22 + CONST010*VAR01 + CONST045*VAR03*VAR26 + CONST061*VAR20*x + CONST181*VAR05*VAR24) +Y02 = CONST013*VAR01*z + CONST054*VAR07*VAR21 + CONST151*VAR03*VAR25 + CONST174*VAR19*x + VAR17*(-CONST039*VAR05*VAR25 + CONST039*VAR07*VAR23 + CONST099*VAR03*z - CONST099*VAR21*x) +Y03 = VAR16*(CONST024*VAR22*x + CONST051*VAR05*VAR26 + CONST133*VAR07*VAR24 + CONST159*VAR03) + y*(CONST095*VAR03*VAR26 - CONST119*VAR05*VAR24 + CONST145*VAR07*VAR22 + CONST148*VAR20*x - CONST178*VAR01) +Y04 = CONST009*VAR01*z + VAR03*(CONST076*VAR17*z + CONST175*VAR25) + VAR05*(CONST106*VAR15*z + CONST107*VAR17*VAR25 + CONST167*VAR23) + VAR07*(CONST106*VAR17*VAR23 + CONST162*VAR15*VAR25 + CONST175*VAR21) + x*(CONST009*VAR19 + CONST075*VAR17*VAR21 + CONST106*VAR15*VAR23) +Y05 = VAR14*(CONST077*VAR05 + CONST112*VAR07*VAR26 + CONST135*VAR24*x) + VAR16*(-CONST114*VAR07*VAR24 + CONST114*VAR22*x + CONST117*VAR05*VAR26 + CONST134*VAR03) + y*(CONST014*VAR01 + CONST047*VAR20*x + CONST116*VAR05*VAR24 + CONST141*VAR03*VAR26) +Y06 = CONST005*VAR01*z + VAR03*(CONST011*VAR25 + CONST102*VAR17*z) + VAR05*(CONST101*VAR17*VAR25 - CONST152*VAR15*z) + VAR07*(CONST108*VAR17*VAR23 + CONST109*VAR13*z + CONST176*VAR21) + x*(CONST108*VAR17*VAR21 - CONST109*VAR13*VAR25 + CONST152*VAR15*VAR23 + CONST179*VAR19) +Y07 = VAR12*(-CONST041*VAR26*x + CONST132*VAR07) + VAR14*(-CONST062*VAR05 + CONST103*VAR07*VAR26 + CONST126*VAR24*x) + VAR16*(CONST083*VAR05*VAR26 + CONST121*VAR03 - CONST158*VAR22*x + CONST169*VAR07*VAR24) + y*(CONST015*VAR01 + CONST138*VAR07*VAR22 + CONST149*VAR05*VAR24 + CONST168*VAR20*x) +Y08 
= -CONST182*VAR01*z + VAR03*(CONST016*VAR25 + CONST129*VAR17*z) + VAR05*(CONST017*VAR23 + CONST032*VAR17*VAR25 + CONST105*VAR15*z) + VAR07*(CONST008*VAR15*VAR25 + CONST016*VAR21 + CONST032*VAR17*VAR23 + CONST085*VAR13*z) + x*(CONST078*VAR11*z + CONST085*VAR13*VAR25 + CONST105*VAR15*VAR23 + CONST129*VAR17*VAR21 - CONST182*VAR19) +Y09 = CONST018*VAR01*y + VAR03*(CONST042*VAR26*y + CONST113*VAR16) + VAR05*(CONST020*VAR14 + CONST056*VAR24*y + CONST128*VAR16*VAR26) + VAR07*(CONST031*VAR14*VAR26 + CONST042*VAR22*y + CONST088*VAR12 + CONST127*VAR16*VAR24) + x*(CONST018*VAR20*y + CONST020*VAR14*VAR24 + CONST026*VAR10 + CONST088*VAR12*VAR26 + CONST113*VAR16*VAR22) +Y10 = CONST004*VAR09 + CONST037*VAR17*VAR20 + CONST093*VAR15*VAR22 + CONST131*VAR13*VAR24 + CONST147*VAR11*VAR26 + CONST184*VAR00 + CONST184*VAR18 + VAR02*(CONST036*VAR17 + CONST059*VAR26) + VAR04*(CONST092*VAR15 + CONST098*VAR17*VAR26 + CONST177*VAR24) + VAR06*(CONST063*VAR15*VAR26 + CONST125*VAR17*VAR24 + CONST131*VAR13 + CONST177*VAR22) + VAR08*(CONST063*VAR15*VAR24 + CONST090*VAR13*VAR26 + CONST097*VAR17*VAR22 + CONST146*VAR11 + CONST180*VAR20) +Y11 = CONST018*VAR19*y + VAR21*(CONST042*VAR08*y + CONST113*VAR16) + VAR23*(CONST020*VAR14 + CONST056*VAR06*y + CONST128*VAR08*VAR16) + VAR25*(CONST031*VAR08*VAR14 + CONST042*VAR04*y + CONST091*VAR12 + CONST127*VAR06*VAR16) + z*(CONST018*VAR02*y + CONST020*VAR06*VAR14 + CONST027*VAR10 + CONST089*VAR08*VAR12 + CONST113*VAR04*VAR16) +Y12 = CONST057*VAR13*VAR24 - CONST066*VAR15*VAR22 + CONST081*VAR00 - CONST081*VAR18 - CONST153*VAR11*VAR26 + CONST160*VAR17*VAR20 + VAR02*(CONST030*VAR26 + CONST046*VAR17) + VAR04*(CONST066*VAR15 - CONST129*VAR17*VAR26 + CONST182*VAR24) + VAR06*(CONST065*VAR15*VAR26 + CONST150*VAR13 - CONST182*VAR22) + VAR08*(CONST006*VAR20 - CONST066*VAR15*VAR24 + CONST130*VAR17*VAR22 + CONST153*VAR11) +Y13 = VAR12*(CONST041*VAR08*z + CONST071*VAR25) + VAR14*(CONST062*VAR23 + CONST107*VAR08*VAR25 - CONST126*VAR06*z) + VAR16*(CONST029*VAR06*VAR25 - 
CONST121*VAR21 + CONST122*VAR08*VAR23 + CONST158*VAR04*z) + y*(-CONST138*VAR04*VAR25 - CONST149*VAR06*VAR23 - CONST168*VAR02*z + CONST173*VAR19) +Y14 = CONST044*VAR17*VAR20 + CONST079*VAR13*VAR24 + CONST101*VAR15*VAR22 + CONST110*VAR00 + CONST120*VAR18 + VAR02*(CONST043*VAR17 + CONST070*VAR26) + VAR04*(CONST021*VAR24 + CONST101*VAR15 + CONST101*VAR17*VAR26) + VAR06*(CONST021*VAR22 + CONST079*VAR13 + CONST123*VAR17*VAR24 + CONST137*VAR15*VAR26) + VAR08*(CONST007*VAR20 + CONST101*VAR17*VAR22 + CONST136*VAR15*VAR24 + CONST152*VAR13*VAR26) +Y15 = VAR14*(CONST077*VAR23 + CONST112*VAR08*VAR25 + CONST135*VAR06*z) + VAR16*(CONST114*VAR04*z - CONST114*VAR06*VAR25 + CONST118*VAR08*VAR23 + CONST134*VAR21) + y*(CONST014*VAR19 + CONST047*VAR02*z + CONST116*VAR06*VAR23 + CONST142*VAR08*VAR21) +Y16 = CONST001*VAR18 + CONST094*VAR00 - CONST139*VAR15*VAR22 + CONST166*VAR17*VAR20 + VAR02*(CONST019*VAR26 - CONST166*VAR17) + VAR04*(CONST022*VAR24 + CONST104*VAR17*VAR26 + CONST139*VAR15) + VAR06*(-CONST049*VAR15*VAR26 + CONST171*VAR22) + VAR08*(CONST049*VAR15*VAR24 + CONST106*VAR17*VAR22 + CONST172*VAR20) +Y17 = VAR16*(CONST050*VAR21 - CONST133*VAR06*VAR25 + CONST154*VAR08*VAR23 + CONST170*VAR04*z) + y*(CONST058*VAR02*z + CONST060*VAR04*VAR25 - CONST095*VAR08*VAR21 + CONST119*VAR06*VAR23 + CONST178*VAR19) +Y18 = CONST034*VAR02*VAR26 + CONST035*VAR08*VAR20 + CONST082*VAR00 + CONST087*VAR18 + CONST155*VAR04*VAR24 + CONST156*VAR06*VAR22 + VAR17*(CONST025*VAR04*VAR26 + CONST025*VAR08*VAR22 + CONST028*VAR02 + CONST028*VAR20 + CONST161*VAR06*VAR24) +Y19 = y*(CONST002*VAR04*VAR25 + CONST010*VAR19 + CONST045*VAR08*VAR21 + CONST061*VAR02*z + CONST181*VAR06*VAR23) +Y20 = -CONST143*VAR02*VAR26 + CONST143*VAR08*VAR20 + CONST165*VAR04*VAR24 - CONST165*VAR06*VAR22 + CONST183*VAR00 - CONST183*VAR18 diff --git a/notebooks/fwd_implementations/fwd_2.py b/notebooks/fwd_implementations/fwd_2.py new file mode 100644 index 0000000..faa1e85 --- /dev/null +++ b/notebooks/fwd_implementations/fwd_2.py @@ -0,0 
+1,39 @@ +# -------------------- variable and constant definitions +CONST000 = 1.93649167310371 +CONST001 = 2.23606797749979 +CONST002 = 3.87298334620742 +CONST003 = -1.93649167310371 +CONST004 = -1.11803398874989 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +Y00 = CONST002*x*z +Y01 = CONST002*x*y +Y02 = CONST001*VAR17 + CONST004*VAR08 + CONST004*VAR26 +Y03 = CONST002*y*z +Y04 = CONST003*VAR08 - CONST003*VAR26 diff --git a/notebooks/fwd_implementations/fwd_3.py b/notebooks/fwd_implementations/fwd_3.py new file mode 100644 index 0000000..0630c5c --- /dev/null +++ b/notebooks/fwd_implementations/fwd_3.py @@ -0,0 +1,47 @@ +# -------------------- variable and constant definitions +CONST000 = 2.64575131106459 +CONST001 = 2.09165006633519 +CONST002 = 5.12347538297980 +CONST003 = 6.27495019900557 +CONST004 = 6.48074069840786 +CONST005 = 10.2469507659596 +CONST006 = -2.09165006633519 +CONST007 = -1 +CONST008 = -6.27495019900557 +CONST009 = -3.96862696659689 +CONST010 = -1.62018517460197 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +Y00 = CONST006*VAR07 - CONST008*VAR26*x +Y01 = CONST005*x*y*z +Y02 = CONST010*VAR07 + x*(CONST004*VAR17 + CONST010*VAR26) +Y03 = CONST000*VAR16 + CONST009*VAR08*y + CONST009*VAR26*y +Y04 = 
CONST010*VAR25 + z*(CONST004*VAR17 + CONST010*VAR08) +Y05 = CONST002*y*(CONST007*VAR08 + VAR26) +Y06 = -CONST006*VAR25 + CONST008*VAR08*z diff --git a/notebooks/fwd_implementations/fwd_4.py b/notebooks/fwd_implementations/fwd_4.py new file mode 100644 index 0000000..9a3d2d8 --- /dev/null +++ b/notebooks/fwd_implementations/fwd_4.py @@ -0,0 +1,58 @@ +# -------------------- variable and constant definitions +CONST000 = 1.12500000000000 +CONST001 = 2.25000000000000 +CONST002 = 3.00000000000000 +CONST003 = 1.67705098312484 +CONST004 = 6.27495019900557 +CONST005 = 2.21852991866236 +CONST006 = 8.87411967464942 +CONST007 = 9.48683298050514 +CONST008 = 10.0623058987491 +CONST009 = 18.8248505970167 +CONST010 = 20.1246117974981 +CONST011 = -18.8248505970167 +CONST012 = -13.3111795119741 +CONST013 = -10.0623058987491 +CONST014 = -9.00000000000000 +CONST015 = -8.87411967464942 +CONST016 = -7.11512473537885 +CONST017 = -6.27495019900557 +CONST018 = -3.35410196624968 +CONST019 = -1.67705098312484 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +Y00 = CONST015*VAR07*z - CONST015*VAR25*x +Y01 = y*(-CONST011*VAR26*x + CONST017*VAR07) +Y02 = CONST018*VAR07*z + x*(CONST010*VAR17*z + CONST018*VAR25) +Y03 = CONST016*VAR07*y + x*(CONST007*VAR16 + CONST016*VAR26*y) +Y04 = CONST000*VAR06 + CONST000*VAR24 + CONST002*VAR15 + CONST014*VAR17*VAR26 + VAR08*(CONST001*VAR26 + CONST014*VAR17) +Y05 = CONST016*VAR25*y + z*(CONST007*VAR16 + CONST016*VAR08*y) +Y06 = -CONST019*VAR06 + CONST019*VAR24 + VAR17*(CONST013*VAR08 - CONST013*VAR26) +Y07 = y*(CONST011*VAR08*z - CONST017*VAR25) +Y08 = CONST005*VAR06 + CONST005*VAR24 + 
CONST012*VAR08*VAR26 diff --git a/notebooks/fwd_implementations/fwd_5.py b/notebooks/fwd_implementations/fwd_5.py new file mode 100644 index 0000000..df1e2c9 --- /dev/null +++ b/notebooks/fwd_implementations/fwd_5.py @@ -0,0 +1,75 @@ +# -------------------- variable and constant definitions +CONST000 = 1.73430461568895 +CONST001 = 2.32681380862329 +CONST002 = 1.60565407233314 +CONST003 = 3.21130814466628 +CONST004 = 3.31662479035540 +CONST005 = 6.21867148191637 +CONST006 = 6.21867148191637 +CONST007 = 1.60565407233314 +CONST008 = 8.49632273398321 +CONST009 = 11.6340690431164 +CONST010 = 12.8452325786651 +CONST011 = 12.4373429638327 +CONST012 = 12.8452325786651 +CONST013 = 13.8744369255116 +CONST014 = 16.9926454679664 +CONST015 = 5.20291384706685 +CONST016 = 29.4321253055229 +CONST017 = 33.9852909359329 +CONST018 = 7.35803132638072 +CONST019 = 41.6233107765348 +CONST020 = -44.1481879582843 +CONST021 = -41.6233107765348 +CONST022 = -29.4321253055229 +CONST023 = -23.2681380862329 +CONST024 = -19.2678488679977 +CONST025 = -19.2678488679977 +CONST026 = -16.9926454679664 +CONST027 = -16.9926454679664 +CONST028 = -13.8744369255116 +CONST029 = -16.5831239517770 +CONST030 = 3.46860923137790 +CONST031 = -8.49632273398321 +CONST032 = -5.20291384706685 +CONST033 = -3.46860923137790 +CONST034 = -1.73430461568895 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +Y00 = CONST001*VAR05 + CONST009*VAR24*x + CONST023*VAR07*VAR26 +Y01 = y*(CONST022*VAR07*z - CONST022*VAR25*x) +Y02 = CONST000*VAR05 + VAR07*(CONST028*VAR17 + CONST033*VAR26) + x*(-CONST021*VAR17*VAR26 + CONST032*VAR24) +Y03 = CONST027*VAR07*y*z + 
x*(CONST017*VAR16*z + CONST026*VAR25*y) +Y04 = CONST002*VAR05 + VAR07*(CONST003*VAR26 + CONST025*VAR17) + x*(CONST002*VAR24 + CONST010*VAR15 + CONST024*VAR17*VAR26) +Y05 = CONST004*VAR14 + VAR16*(CONST029*VAR08 + CONST029*VAR26) + y*(CONST005*VAR06 + CONST006*VAR24 + CONST011*VAR08*VAR26) +Y06 = CONST002*VAR23 + VAR25*(CONST003*VAR08 + CONST024*VAR17) + z*(CONST007*VAR06 + CONST012*VAR15 + CONST024*VAR08*VAR17) +Y07 = VAR16*(CONST026*VAR08 - CONST026*VAR26) + y*(-CONST031*VAR06 + CONST031*VAR24) +Y08 = CONST034*VAR23 + VAR25*(CONST013*VAR17 + CONST030*VAR08) + z*(CONST021*VAR08*VAR17 - CONST032*VAR06) +Y09 = y*(CONST018*VAR06 + CONST018*VAR24 + CONST020*VAR08*VAR26) +Y10 = CONST001*VAR23 + CONST009*VAR06*z + CONST023*VAR08*VAR25 diff --git a/notebooks/fwd_implementations/fwd_6.py b/notebooks/fwd_implementations/fwd_6.py new file mode 100644 index 0000000..4aa8b02 --- /dev/null +++ b/notebooks/fwd_implementations/fwd_6.py @@ -0,0 +1,90 @@ +# -------------------- variable and constant definitions +CONST000 = 1.63279380970164 +CONST001 = 2.42182459624970 +CONST002 = 3.26558761940328 +CONST003 = 3.26558761940328 +CONST004 = 6.53117523880657 +CONST005 = 7.15454401062709 +CONST006 = 8.38944649544891 +CONST007 = 9.79676285820985 +CONST008 = 10.3266947761614 +CONST009 = 3.60555127546399 +CONST010 = -1.78863600265677 +CONST011 = 14.5309475774982 +CONST012 = 8.94318001328386 +CONST013 = 16.5227116418583 +CONST014 = 16.5227116418583 +CONST015 = 17.8863600265677 +CONST016 = 19.5935257164197 +CONST017 = 20.6533895523229 +CONST018 = 20.2812259244849 +CONST019 = -107.318160159406 +CONST020 = 17.8863600265677 +CONST021 = 26.1247009552263 +CONST022 = 29.3902885746295 +CONST023 = 36.3273689437454 +CONST024 = 40.5624518489699 +CONST025 = 41.9472324772445 +CONST026 = -1.63279380970164 +CONST027 = -83.8944649544891 +CONST028 = -78.3741028656788 +CONST029 = 52.2494019104525 +CONST030 = -71.5454401062709 +CONST031 = 71.5454401062709 +CONST032 = -52.2494019104525 +CONST033 = 
-52.2494019104525 +CONST034 = 78.3741028656788 +CONST035 = -48.4364919249939 +CONST036 = -41.3067791046458 +CONST037 = -36.3273689437454 +CONST038 = -29.3902885746295 +CONST039 = -27.0416345659799 +CONST040 = -26.1247009552263 +CONST041 = -26.1247009552263 +CONST042 = -19.5935257164197 +CONST043 = -2.42182459624970 +CONST044 = -9.79676285820985 +CONST045 = -7.15454401062709 +CONST046 = -3.38020432074749 +CONST047 = -1.12673477358250 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +Y00 = CONST011*VAR05*z + CONST011*VAR23*x + CONST035*VAR07*VAR25 +Y01 = y*(CONST006*VAR05 + CONST025*VAR24*x + CONST027*VAR07*VAR26) +Y02 = -CONST045*VAR05*z + CONST045*VAR23*x + VAR17*(CONST030*VAR07*z - CONST030*VAR25*x) +Y03 = VAR16*(-CONST028*VAR26*x + CONST040*VAR07) + y*(CONST007*VAR05 + CONST038*VAR24*x + CONST042*VAR07*VAR26) +Y04 = CONST003*VAR05*z + VAR07*(CONST004*VAR25 + CONST033*VAR17*z) + x*(CONST002*VAR23 - CONST032*VAR15*z + CONST032*VAR17*VAR25) +Y05 = CONST008*VAR05*y + VAR07*(CONST017*VAR26*y + CONST036*VAR16) + x*(CONST008*VAR24*y + CONST013*VAR14 + CONST036*VAR16*VAR26) +Y06 = CONST009*VAR13 + CONST018*VAR17*VAR24 + CONST039*VAR15*VAR26 + CONST047*VAR04 + CONST047*VAR22 + VAR06*(CONST018*VAR17 + CONST046*VAR26) + VAR08*(CONST024*VAR17*VAR26 + CONST039*VAR15 + CONST046*VAR24) +Y07 = CONST008*VAR23*y + VAR25*(CONST017*VAR08*y + CONST036*VAR16) + z*(CONST008*VAR06*y + CONST014*VAR14 + CONST036*VAR08*VAR16) +Y08 = CONST026*VAR04 - CONST026*VAR22 + CONST040*VAR17*VAR24 - CONST041*VAR15*VAR26 + VAR06*(CONST026*VAR26 - CONST041*VAR17) + VAR08*(-CONST026*VAR24 + CONST041*VAR15) +Y09 = 
VAR16*(CONST028*VAR08*z - CONST041*VAR25) + y*(CONST022*VAR06*z - CONST042*VAR08*VAR25 + CONST044*VAR23) +Y10 = CONST010*VAR04 + CONST010*VAR22 + CONST020*VAR17*VAR24 + VAR06*(CONST012*VAR26 + CONST015*VAR17) + VAR08*(CONST012*VAR24 + CONST019*VAR17*VAR26) +Y11 = y*(CONST006*VAR23 + CONST025*VAR06*z + CONST027*VAR08*VAR25) +Y12 = -CONST037*VAR06*VAR26 + CONST037*VAR08*VAR24 + CONST043*VAR04 - CONST043*VAR22 diff --git a/notebooks/fwd_implementations/fwd_7.py b/notebooks/fwd_implementations/fwd_7.py new file mode 100644 index 0000000..257071b --- /dev/null +++ b/notebooks/fwd_implementations/fwd_7.py @@ -0,0 +1,125 @@ +# -------------------- variable and constant definitions +CONST000 = 1.66389743899677 +CONST001 = 2.50682661696018 +CONST002 = 3.87298334620742 +CONST003 = 4.99169231699030 +CONST004 = 8.31948719498384 +CONST005 = 9.19753915797590 +CONST006 = 9.19753915797590 +CONST007 = 11.7655316231354 +CONST008 = 11.7655316231354 +CONST009 = 9.37968632871057 +CONST010 = 16.5555704843566 +CONST011 = 17.5477863187212 +CONST012 = 20.4939015319192 +CONST013 = 20.4939015319192 +CONST014 = 22.0740939791422 +CONST015 = 23.5310632462709 +CONST016 = 33.2779487799353 +CONST017 = 36.7901566319036 +CONST018 = 37.6497011940334 +CONST019 = 38.4260653723485 +CONST020 = 38.4260653723485 +CONST021 = 38.4260653723485 +CONST022 = 44.1481879582843 +CONST023 = -4.99169231699030 +CONST024 = 44.3705983732471 +CONST025 = 47.0621264925418 +CONST026 = 50.8329064189723 +CONST027 = 52.6433589561637 +CONST028 = 55.1852349478554 +CONST029 = 56.2781179722634 +CONST030 = 56.2781179722634 +CONST031 = 62.7495019900557 +CONST032 = 66.5558975598707 +CONST033 = 75.2994023880668 +CONST034 = 76.8521307446970 +CONST035 = 87.7389315936062 +CONST036 = 99.8338463398060 +CONST037 = 101.665812837945 +CONST038 = 110.370469895711 +CONST039 = 133.111795119741 +CONST040 = 140.695294930659 +CONST041 = 147.160626527614 +CONST042 = -1.66389743899677 +CONST043 = -9.37968632871057 +CONST044 = -1.66389743899677 
+CONST045 = -220.740939791422 +CONST046 = -220.740939791422 +CONST047 = -1.60108605718119 +CONST048 = -187.593726574211 +CONST049 = -9.19753915797590 +CONST050 = -1.83950783159518 +CONST051 = -1.83950783159518 +CONST052 = -4.80325817154356 +CONST053 = -147.160626527614 +CONST054 = -140.695294930659 +CONST055 = -133.111795119741 +CONST056 = -125.499003980111 +CONST057 = -125.499003980111 +CONST058 = -99.8338463398060 +CONST059 = -87.7389315936062 +CONST060 = -76.8521307446970 +CONST061 = -66.5558975598707 +CONST062 = -62.7495019900557 +CONST063 = -52.6433589561637 +CONST064 = -44.1481879582843 +CONST065 = -44.3705983732471 +CONST066 = -40.6663251351779 +CONST067 = -40.6663251351779 +CONST068 = -8.31948719498384 +CONST069 = -37.6497011940334 +CONST070 = -33.2779487799353 +CONST071 = -25.4164532094862 +CONST072 = -25.4164532094862 +CONST073 = -17.5477863187212 +CONST074 = -11.7655316231354 +CONST075 = -11.0370469895711 +CONST076 = -9.19753915797590 +CONST077 = -8.47215106982872 +CONST078 = -4.80325817154356 +CONST079 = -2.50682661696018 +CONST080 = -1.60108605718119 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +Y00 = CONST059*VAR07*VAR24 - CONST063*VAR05*VAR26 - CONST073*VAR22*x + CONST079*VAR03 +Y01 = y*(CONST029*VAR23*x + CONST030*VAR05*z + CONST048*VAR07*VAR25) +Y02 = CONST050*VAR03 + VAR05*(CONST010*VAR26 + CONST014*VAR17) + VAR07*(CONST045*VAR17*VAR26 - CONST076*VAR24) + x*(CONST038*VAR17*VAR24 + CONST076*VAR22) +Y03 = VAR16*(CONST041*VAR25*x + CONST053*VAR07*z) + y*(-CONST064*VAR05*z + CONST064*VAR23*x) +Y04 = CONST042*VAR03 + VAR05*(-CONST042*VAR26 - CONST070*VAR17) + 
VAR07*(CONST061*VAR17*VAR26 + CONST065*VAR15 - CONST068*VAR24) + x*(-CONST023*VAR22 - CONST055*VAR15*VAR26 + CONST058*VAR17*VAR24) +Y05 = CONST015*VAR05*y*z + VAR07*(CONST025*VAR25*y + CONST057*VAR16*z) + x*(CONST015*VAR23*y + CONST033*VAR14*z + CONST056*VAR16*VAR25) +Y06 = CONST047*VAR03 + VAR05*(CONST020*VAR17 + CONST078*VAR26) + VAR07*(CONST052*VAR24 + CONST060*VAR15 - CONST060*VAR17*VAR26) + x*(CONST012*VAR13 + CONST019*VAR17*VAR24 + CONST060*VAR15*VAR26 + CONST080*VAR22) +Y07 = CONST002*VAR12 + VAR14*(CONST066*VAR08 + CONST067*VAR26) + VAR16*(CONST026*VAR06 + CONST026*VAR24 + CONST037*VAR08*VAR26) + y*(CONST071*VAR06*VAR26 + CONST072*VAR08*VAR24 + CONST077*VAR04 + CONST077*VAR22) +Y08 = CONST047*VAR21 + VAR23*(CONST020*VAR17 + CONST052*VAR08) + VAR25*(CONST052*VAR06 - CONST060*VAR08*VAR17 + CONST060*VAR15) + z*(CONST013*VAR13 + CONST021*VAR06*VAR17 + CONST047*VAR04 + CONST060*VAR08*VAR15) +Y09 = VAR14*(CONST069*VAR08 - CONST069*VAR26) + VAR16*(-CONST062*VAR06 + CONST062*VAR24) + y*(CONST008*VAR08*VAR24 + CONST074*VAR04 + CONST074*VAR06*VAR26 - CONST074*VAR22) +Y10 = -CONST042*VAR21 + VAR23*(CONST044*VAR08 + CONST070*VAR17) + VAR25*(CONST032*VAR08*VAR17 - CONST065*VAR15 + CONST068*VAR06) + z*(CONST023*VAR04 + CONST055*VAR08*VAR15 - CONST058*VAR06*VAR17) +Y11 = VAR16*(CONST017*VAR06 + CONST017*VAR24 + CONST046*VAR08*VAR26) + y*(CONST028*VAR06*VAR26 + CONST028*VAR08*VAR24 + CONST075*VAR04 + CONST075*VAR22) +Y12 = CONST051*VAR21 + VAR23*(CONST010*VAR08 + CONST014*VAR17) + VAR25*(CONST045*VAR08*VAR17 - CONST049*VAR06) + z*(CONST038*VAR06*VAR17 + CONST049*VAR04) +Y13 = y*(CONST043*VAR04 - CONST043*VAR22 - CONST054*VAR06*VAR26 + CONST054*VAR08*VAR24) +Y14 = -CONST059*VAR06*VAR25 + CONST063*VAR08*VAR23 + CONST073*VAR04*z - CONST079*VAR21 diff --git a/notebooks/fwd_implementations/fwd_8.py b/notebooks/fwd_implementations/fwd_8.py new file mode 100644 index 0000000..ff11a70 --- /dev/null +++ b/notebooks/fwd_implementations/fwd_8.py @@ -0,0 +1,150 @@ +# 
-------------------- variable and constant definitions +CONST000 = 1.12741169450483 +CONST001 = 1.61701765412441 +CONST002 = 3.23403530824881 +CONST003 = 4.12310562561766 +CONST004 = 4.50964677801932 +CONST005 = 6.78376969317208 +CONST006 = 6.76447016702898 +CONST007 = 1.69594242329302 +CONST008 = 1.88707052233084 +CONST009 = 10.3359109268366 +CONST010 = 2.58397773170915 +CONST011 = 13.1367135230810 +CONST012 = 13.1367135230810 +CONST013 = 20.6718218536732 +CONST014 = -489.184589393411 +CONST015 = 24.7386337537060 +CONST016 = 26.4189873126318 +CONST017 = 24.7386337537060 +CONST018 = 39.4101405692431 +CONST019 = 48.9184589393411 +CONST020 = 48.5105296237322 +CONST021 = 51.7445649319810 +CONST022 = 61.1480736741764 +CONST023 = 61.1480736741764 +CONST024 = 65.6835676154051 +CONST025 = 67.8376969317208 +CONST026 = 70.0624721230988 +CONST027 = 72.3513764878561 +CONST028 = 87.5780901538735 +CONST029 = 97.0210592474644 +CONST030 = -6.78376969317208 +CONST031 = 103.489129863962 +CONST032 = -407.026181590325 +CONST033 = 108.231522672464 +CONST034 = 108.231522672464 +CONST035 = 110.066532613517 +CONST036 = 110.066532613517 +CONST037 = -396.284809689477 +CONST038 = 129.361412329953 +CONST039 = 144.702752975712 +CONST040 = -361.756882439281 +CONST041 = -1.88707052233084 +CONST042 = 158.513923875791 +CONST043 = 162.810472636130 +CONST044 = 175.156180307747 +CONST045 = 180.878441219640 +CONST046 = 194.042118494929 +CONST047 = -12.2296147348353 +CONST048 = 203.513090795162 +CONST049 = 210.187416369296 +CONST050 = 216.463045344927 +CONST051 = 217.054129463568 +CONST052 = 216.463045344927 +CONST053 = -6.78376969317208 +CONST054 = -271.350787726883 +CONST055 = 244.592294696706 +CONST056 = 244.592294696706 +CONST057 = -262.734270461621 +CONST058 = -258.722824659905 +CONST059 = 262.734270461621 +CONST060 = 271.350787726883 +CONST061 = -217.054129463568 +CONST062 = -210.187416369296 +CONST063 = -175.156180307747 +CONST064 = -162.810472636130 +CONST065 = 361.756882439281 +CONST066 = 
-144.702752975712 +CONST067 = -129.877827206956 +CONST068 = -129.361412329953 +CONST069 = 396.284809689477 +CONST070 = -108.231522672464 +CONST071 = -108.231522672464 +CONST072 = -87.5780901538735 +CONST073 = -3.23403530824881 +CONST074 = -72.3513764878561 +CONST075 = -70.0624721230988 +CONST076 = -65.6835676154052 +CONST077 = -61.1480736741764 +CONST078 = -61.1480736741764 +CONST079 = -57.7234787586472 +CONST080 = -57.7234787586472 +CONST081 = -51.7445649319810 +CONST082 = -48.5105296237322 +CONST083 = -40.5868210021738 +CONST084 = -39.4101405692431 +CONST085 = -40.7026181590325 +CONST086 = -36.0771742241545 +CONST087 = -36.0771742241545 +CONST088 = -26.4189873126318 +CONST089 = -20.6718218536732 +CONST090 = -528.379746252636 +CONST091 = -16.9594242329302 +CONST092 = -13.1367135230810 +CONST093 = -12.2296147348353 +CONST094 = -11.3224231339851 +CONST095 = -10.3359109268366 +CONST096 = -9.70210592474644 +CONST097 = -11.3224231339851 +CONST098 = -13.5289403340579 +CONST099 = -6.78376969317208 +CONST100 = -13.5289403340579 +CONST101 = -13.1367135230810 +CONST102 = -3.23403530824881 +CONST103 = -1.61701765412441 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +Y00 = -CONST066*VAR05*VAR25 + CONST066*VAR07*VAR23 + CONST089*VAR03*z - CONST089*VAR21*x +Y01 = y*(CONST040*VAR07*VAR24 + CONST051*VAR05*VAR26 - CONST074*VAR22*x + CONST095*VAR03) +Y02 = CONST097*VAR03*z + VAR05*(CONST042*VAR17*z - CONST088*VAR25) + VAR07*(-CONST088*VAR23 + CONST090*VAR17*VAR25) + x*(CONST042*VAR17*VAR23 + CONST094*VAR21) +Y03 = VAR16*(CONST014*VAR07*VAR26 + CONST019*VAR05 + CONST055*VAR24*x) + y*(CONST035*VAR05*VAR26 + 
CONST077*VAR22*x - CONST078*VAR07*VAR24 + CONST093*VAR03) +Y04 = CONST099*VAR03*z + VAR05*(-CONST064*VAR17*z + CONST099*VAR25) + VAR07*(-CONST053*VAR23 + CONST054*VAR15*z) + x*(-CONST053*VAR21 - CONST054*VAR15*VAR25 + CONST064*VAR17*VAR23) +Y05 = VAR14*(-CONST062*VAR26*x + CONST075*VAR07) + VAR16*(CONST057*VAR24*x + CONST063*VAR07*VAR26 - CONST072*VAR05) + y*(CONST011*VAR05*VAR26 + CONST024*VAR07*VAR24 - CONST084*VAR22*x + CONST092*VAR03) +Y06 = CONST102*VAR03*z + VAR05*(CONST029*VAR17*z + CONST096*VAR25) + VAR07*(CONST046*VAR17*VAR25 + CONST058*VAR15*z + CONST096*VAR23) + x*(CONST029*VAR17*VAR23 + CONST031*VAR13*z + CONST058*VAR15*VAR25 + CONST102*VAR21) +Y07 = CONST098*VAR03*y + VAR05*(CONST033*VAR16 + CONST083*VAR26*y) + VAR07*(CONST050*VAR16*VAR26 + CONST067*VAR14 + CONST083*VAR24*y) + x*(CONST015*VAR12 + CONST067*VAR14*VAR26 - CONST070*VAR16*VAR24 + CONST098*VAR22*y) +Y08 = CONST000*VAR02 + CONST000*VAR20 + CONST003*VAR11 - CONST070*VAR15*VAR24 + CONST080*VAR13*VAR26 + CONST087*VAR17*VAR22 + VAR04*(CONST004*VAR26 + CONST086*VAR17) + VAR06*(CONST006*VAR24 - CONST070*VAR15 + CONST071*VAR17*VAR26) + VAR08*(CONST004*VAR22 + CONST050*VAR15*VAR26 + CONST070*VAR17*VAR24 + CONST079*VAR13) +Y09 = CONST098*VAR21*y + VAR23*(CONST033*VAR16 + CONST083*VAR08*y) + VAR25*(CONST052*VAR08*VAR16 + CONST067*VAR14 + CONST083*VAR06*y) + z*(CONST017*VAR12 + CONST033*VAR06*VAR16 + CONST067*VAR08*VAR14 + CONST100*VAR04*y) +Y10 = CONST073*VAR08*VAR22 - CONST102*VAR04*VAR26 - CONST103*VAR02 + CONST103*VAR20 + VAR13*(CONST021*VAR26 + CONST081*VAR08) + VAR15*(-CONST068*VAR06 + CONST068*VAR24) + VAR17*(CONST020*VAR08*VAR24 + CONST020*VAR22 + CONST082*VAR04 + CONST082*VAR06*VAR26) +Y11 = VAR14*(CONST062*VAR08*z - CONST075*VAR25) + VAR16*(-CONST057*VAR06*z - CONST063*VAR08*VAR25 + CONST072*VAR23) + y*(CONST012*VAR21 + CONST076*VAR06*VAR25 + CONST084*VAR04*z + CONST101*VAR08*VAR23) +Y12 = CONST007*VAR02 + CONST007*VAR20 + CONST030*VAR04*VAR26 + CONST053*VAR08*VAR22 + CONST091*VAR06*VAR24 + 
VAR15*(CONST025*VAR06 + CONST025*VAR24 + CONST032*VAR08*VAR26) + VAR17*(CONST048*VAR06*VAR26 + CONST048*VAR08*VAR24 + CONST085*VAR04 + CONST085*VAR22) +Y13 = VAR16*(CONST014*VAR08*VAR25 + CONST019*VAR23 + CONST056*VAR06*z) + y*(CONST036*VAR08*VAR23 + CONST047*VAR21 - CONST077*VAR06*VAR25 + CONST078*VAR04*z) +Y14 = CONST008*VAR02 + CONST041*VAR20 + CONST088*VAR04*VAR26 - CONST088*VAR08*VAR22 + VAR17*(-CONST037*VAR06*VAR26 + CONST037*VAR08*VAR24 + CONST088*VAR04 - CONST088*VAR22) +Y15 = y*(-CONST040*VAR06*VAR25 + CONST061*VAR08*VAR23 + CONST074*VAR04*z - CONST095*VAR21) +Y16 = CONST010*VAR02 + CONST010*VAR20 + CONST045*VAR06*VAR24 + CONST074*VAR04*VAR26 + CONST074*VAR08*VAR22 diff --git a/notebooks/fwd_implementations/fwd_9.py b/notebooks/fwd_implementations/fwd_9.py new file mode 100644 index 0000000..3d2bc05 --- /dev/null +++ b/notebooks/fwd_implementations/fwd_9.py @@ -0,0 +1,194 @@ +# -------------------- variable and constant definitions +CONST000 = 1.93163963757558 +CONST001 = 2.65478475211798 +CONST002 = 1.72771101506082 +CONST003 = 1.63671408859718 +CONST004 = 1.59908344719522 +CONST005 = 6.39633378878088 +CONST006 = 6.39633378878088 +CONST007 = 8.63855507530412 +CONST008 = 9.59450068317133 +CONST009 = 4.35889894354067 +CONST010 = 10.7269778688696 +CONST011 = 10.7269778688696 +CONST012 = 6.39633378878088 +CONST013 = 15.0007324039945 +CONST014 = 13.0937127087774 +CONST015 = 9.82028453158308 +CONST016 = 14.4550674370400 +CONST017 = 14.4550674370400 +CONST018 = 13.3827919767794 +CONST019 = 13.5214774630291 +CONST020 = 23.8930627690618 +CONST021 = 27.0429549260581 +CONST022 = 29.2403830344269 +CONST023 = 29.2403830344269 +CONST024 = 30.0014648079890 +CONST025 = -480.023436927823 +CONST026 = -480.023436927823 +CONST027 = 30.9062342012093 +CONST028 = 38.6327927515116 +CONST029 = 42.9079114754785 +CONST030 = -462.562157985281 +CONST031 = 54.0859098521163 +CONST032 = -967.518168434061 +CONST033 = 57.8202697481601 +CONST034 = 57.8202697481601 +CONST035 = 
58.9217071894985 +CONST036 = 58.9217071894985 +CONST037 = 62.4530292249704 +CONST038 = 1081.71819704233 +CONST039 = 64.3618672132178 +CONST040 = 578.202697481601 +CONST041 = 68.5747767039748 +CONST042 = 589.217071894985 +CONST043 = 4.91014226579154 +CONST044 = 600.029296159779 +CONST045 = -936.795438374555 +CONST046 = 90.1063824390370 +CONST047 = 96.7518168434061 +CONST048 = 104.749701670220 +CONST049 = 115.640539496320 +CONST050 = 630.744677073259 +CONST051 = -392.811381263323 +CONST052 = 649.030918225395 +CONST053 = 137.149553407950 +CONST054 = 150.007324039945 +CONST055 = 150.007324039945 +CONST056 = -343.263291803828 +CONST057 = 176.765121568496 +CONST058 = 11.2632978048796 +CONST059 = 187.359087674911 +CONST060 = 196.405690631662 +CONST061 = -315.372338536630 +CONST062 = -314.249105010659 +CONST063 = 205.957975082297 +CONST064 = 216.343639408465 +CONST065 = -294.608535947493 +CONST066 = 240.011718463912 +CONST067 = 241.879542108515 +CONST068 = 241.879542108515 +CONST069 = 255.853351551235 +CONST070 = 255.853351551235 +CONST071 = -241.879542108515 +CONST072 = -240.011718463912 +CONST073 = -241.879542108515 +CONST074 = 788.430846341574 +CONST075 = 1.72771101506082 +CONST076 = -1.93163963757558 +CONST077 = -1249.06058449941 +CONST078 = -223.001919177910 +CONST079 = 294.608535947493 +CONST080 = -216.343639408465 +CONST081 = 300.014648079890 +CONST082 = -204.682681240988 +CONST083 = -204.682681240988 +CONST084 = -204.682681240988 +CONST085 = 314.249105010659 +CONST086 = -196.405690631662 +CONST087 = -191.890013663426 +CONST088 = -191.890013663427 +CONST089 = -187.359087674911 +CONST090 = -693.843236977922 +CONST091 = 334.502878766866 +CONST092 = -176.765121568496 +CONST093 = -150.007324039945 +CONST094 = -144.550674370400 +CONST095 = 374.718175349822 +CONST096 = 374.718175349822 +CONST097 = -649.030918225395 +CONST098 = 392.811381263323 +CONST099 = -630.744677073259 +CONST100 = -115.640539496320 +CONST101 = -114.421097267943 +CONST102 = -115.640539496320 +CONST103 
= -104.749701670220 +CONST104 = 411.915950164594 +CONST105 = -95.5722510762473 +CONST106 = -90.1063824390370 +CONST107 = -90.0043944239669 +CONST108 = 936.795438374555 +CONST109 = -80.2967518606762 +CONST110 = -78.4601809837321 +CONST111 = 435.383175795327 +CONST112 = -589.217071894985 +CONST113 = -78.4601809837321 +CONST114 = 435.383175795328 +CONST115 = -68.5747767039748 +CONST116 = -63.9633378878088 +CONST117 = -63.9633378878088 +CONST118 = -62.4530292249704 +CONST119 = -58.9217071894985 +CONST120 = -1081.71819704233 +CONST121 = -57.8202697481601 +CONST122 = -57.8202697481601 +CONST123 = -58.9217071894985 +CONST124 = -54.0859098521163 +CONST125 = 462.562157985281 +CONST126 = 462.562157985281 +CONST127 = -48.3759084217031 +CONST128 = -48.3759084217030 +CONST129 = -38.6327927515116 +CONST130 = -30.9062342012093 +CONST131 = 483.759084217031 +CONST132 = -30.0014648079890 +CONST133 = -30.0014648079890 +CONST134 = -27.0429549260581 +CONST135 = -24.1879542108515 +CONST136 = -24.1879542108515 +CONST137 = -1.63671408859718 +CONST138 = -15.0007324039945 +CONST139 = -13.5214774630291 +CONST140 = -13.8216881204866 +CONST141 = -13.0937127087774 +CONST142 = -13.3827919767794 +CONST143 = -9.82028453158308 +CONST144 = -4.91014226579154 +CONST145 = 511.706703102471 +VAR00 = x**10 +VAR01 = x**9 +VAR02 = x**8 +VAR03 = x**7 +VAR04 = x**6 +VAR05 = x**5 +VAR06 = x**4 +VAR07 = x**3 +VAR08 = x**2 +VAR09 = y**10 +VAR10 = y**9 +VAR11 = y**8 +VAR12 = y**7 +VAR13 = y**6 +VAR14 = y**5 +VAR15 = y**4 +VAR16 = y**3 +VAR17 = y**2 +VAR18 = z**10 +VAR19 = z**9 +VAR20 = z**8 +VAR21 = z**7 +VAR22 = z**6 +VAR23 = z**5 +VAR24 = z**4 +VAR25 = z**3 +VAR26 = z**2 +# -------------------- kernel implementations +Y00 = CONST001*VAR01 + CONST020*VAR20*x + CONST078*VAR07*VAR22 + CONST091*VAR05*VAR24 + CONST105*VAR03*VAR26 +Y01 = y*(-CONST099*VAR05*VAR25 + CONST099*VAR07*VAR23 + CONST106*VAR03*z - CONST106*VAR21*x) +Y02 = CONST000*VAR01 + VAR03*(CONST129*VAR26 + CONST130*VAR17) + VAR05*(CONST021*VAR24 - 
CONST097*VAR17*VAR26) + VAR07*(CONST120*VAR17*VAR24 - CONST124*VAR22) + x*(-CONST080*VAR17*VAR22 + CONST139*VAR20) +Y03 = VAR16*(CONST077*VAR07*VAR25 + CONST095*VAR05*z + CONST096*VAR23*x) + y*(-CONST089*VAR05*VAR25 - CONST089*VAR07*VAR23 + CONST109*VAR03*z + CONST109*VAR21*x) +Y04 = CONST002*VAR01 + CONST007*VAR20*x + CONST135*VAR05*VAR24 + CONST140*VAR03*VAR26 + VAR15*(CONST032*VAR07*VAR26 + CONST047*VAR05 + CONST131*VAR24*x) + VAR17*(-CONST071*VAR07*VAR24 + CONST071*VAR22*x + CONST111*VAR05*VAR26 + CONST127*VAR03) +Y05 = VAR14*(CONST030*VAR07*z - CONST030*VAR25*x) + VAR16*(CONST030*VAR23*x + CONST125*VAR05*z) + y*(CONST034*VAR07*VAR23 + CONST121*VAR05*VAR25 - CONST121*VAR21*x + CONST122*VAR03*z) +Y06 = CONST119*VAR03*VAR17 - CONST137*VAR01 + VAR05*(CONST035*VAR17*VAR26 - CONST086*VAR15 + CONST143*VAR24) + VAR07*(CONST051*VAR15*VAR26 - CONST065*VAR17*VAR24 + CONST103*VAR13 + CONST141*VAR22) + x*(-CONST062*VAR13*VAR26 - CONST092*VAR17*VAR22 + CONST112*VAR15*VAR24 + CONST144*VAR20) +Y07 = CONST132*VAR03*y*z + VAR05*(CONST081*VAR16*z + CONST107*VAR25*y) + VAR07*(CONST026*VAR14*z + CONST044*VAR16*VAR25 + CONST107*VAR23*y) + x*(CONST025*VAR14*VAR25 + CONST053*VAR12*z + CONST081*VAR16*VAR23 + CONST132*VAR21*y) +Y08 = CONST004*VAR01 + VAR03*(CONST006*VAR26 + CONST116*VAR17) + VAR05*(CONST008*VAR24 + CONST069*VAR15 + CONST087*VAR17*VAR26) + VAR07*(CONST005*VAR22 + CONST083*VAR13 + CONST087*VAR17*VAR24 + CONST145*VAR15*VAR26) + x*(CONST004*VAR20 + CONST022*VAR11 + CONST069*VAR15*VAR24 + CONST082*VAR13*VAR26 + CONST116*VAR17*VAR22) +Y09 = CONST009*VAR10 + VAR12*(CONST110*VAR26 + CONST113*VAR08) + VAR14*(CONST063*VAR06 + CONST063*VAR24 + CONST104*VAR08*VAR26) + VAR16*(CONST056*VAR06*VAR26 + CONST056*VAR08*VAR24 + CONST101*VAR04 + CONST101*VAR22) + y*(CONST010*VAR20 + CONST011*VAR02 + CONST029*VAR04*VAR26 + CONST029*VAR08*VAR22 + CONST039*VAR06*VAR24) +Y10 = CONST004*VAR19 + VAR21*(CONST005*VAR08 + CONST117*VAR17) + VAR23*(CONST008*VAR06 + CONST070*VAR15 + 
CONST088*VAR08*VAR17) + VAR25*(CONST012*VAR04 + CONST082*VAR13 + CONST087*VAR06*VAR17 + CONST145*VAR08*VAR15) + z*(CONST004*VAR02 + CONST023*VAR11 + CONST070*VAR06*VAR15 + CONST084*VAR08*VAR13 + CONST117*VAR04*VAR17) +Y11 = VAR12*(CONST115*VAR08 - CONST115*VAR26) + VAR14*(CONST066*VAR06 + CONST072*VAR24) + VAR16*(CONST055*VAR08*VAR24 + CONST093*VAR04 + CONST093*VAR06*VAR26 - CONST093*VAR22) + y*(CONST013*VAR02 + CONST024*VAR04*VAR26 + CONST133*VAR08*VAR22 + CONST138*VAR20) +Y12 = CONST036*VAR17*VAR21 + CONST137*VAR19 + VAR23*(CONST086*VAR15 + CONST123*VAR08*VAR17 - CONST143*VAR06) + VAR25*(CONST014*VAR04 - CONST051*VAR08*VAR15 + CONST065*VAR06*VAR17 - CONST103*VAR13) + z*(CONST062*VAR08*VAR13 + CONST092*VAR04*VAR17 - CONST112*VAR06*VAR15 - CONST144*VAR02) +Y13 = VAR14*(CONST049*VAR06 + CONST049*VAR24 + CONST090*VAR08*VAR26) + VAR16*(CONST040*VAR06*VAR26 + CONST040*VAR08*VAR24 + CONST100*VAR22 + CONST102*VAR04) + y*(CONST016*VAR20 + CONST017*VAR02 + CONST094*VAR06*VAR24 + CONST121*VAR04*VAR26 + CONST122*VAR08*VAR22) +Y14 = CONST007*VAR02*z + CONST075*VAR19 + CONST136*VAR06*VAR23 + CONST140*VAR08*VAR21 + VAR15*(CONST032*VAR08*VAR25 + CONST047*VAR23 + CONST131*VAR06*z) + VAR17*(CONST068*VAR06*VAR25 + CONST073*VAR04*z + CONST114*VAR08*VAR23 + CONST128*VAR21) +Y15 = VAR16*(CONST037*VAR22 - CONST045*VAR06*VAR26 + CONST045*VAR08*VAR24 + CONST118*VAR04) + y*(CONST018*VAR02 + CONST089*VAR04*VAR26 - CONST089*VAR08*VAR22 + CONST142*VAR20) +Y16 = CONST019*VAR02*z + CONST076*VAR19 + CONST124*VAR04*VAR25 - CONST129*VAR08*VAR21 + CONST134*VAR06*VAR23 + VAR17*(CONST038*VAR06*VAR25 + CONST080*VAR04*z + CONST097*VAR08*VAR23 - CONST130*VAR21) +Y17 = y*(CONST058*VAR02 + CONST058*VAR20 + CONST061*VAR04*VAR26 + CONST061*VAR08*VAR22 + CONST074*VAR06*VAR24) +Y18 = CONST001*VAR19 + CONST020*VAR02*z + CONST078*VAR04*VAR25 + CONST091*VAR06*VAR23 + CONST105*VAR08*VAR21 diff --git a/pyproject.toml b/pyproject.toml index e2e5acf..4cbc7fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 
+17,8 @@ dependencies = [ "triton", "torch", "e3nn", - "tqdm" + "tqdm", + "matplotlib" ] description = "Triton-lang implementations of kernels for equivariant neural networks." dynamic = ["version", "readme"] @@ -33,6 +34,14 @@ dev = [ "pytest-pretty", "jupyter" ] +train = [ + "torch-geometric", + "torch-scatter", + "torch-sparse", + "torch-cluster", + "pytorch-lightning==2.2.4", + "jsonargparse[signatures]" +] [tool.setuptools.dynamic] readme = {file = ["README.md"]} diff --git a/scripts/.gitignore b/scripts/.gitignore new file mode 100644 index 0000000..22b1587 --- /dev/null +++ b/scripts/.gitignore @@ -0,0 +1,4 @@ +qm9_data/ +wandb/ +lightning_logs/ +artifacts/ diff --git a/scripts/experiment-requirements.txt b/scripts/experiment-requirements.txt new file mode 100644 index 0000000..ecff73f --- /dev/null +++ b/scripts/experiment-requirements.txt @@ -0,0 +1,3 @@ +phate==1.0.11 +rdkit==2023.9.5 +wandb==0.17.7 diff --git a/scripts/generate_phate_embeddings.py b/scripts/generate_phate_embeddings.py new file mode 100644 index 0000000..8641c9c --- /dev/null +++ b/scripts/generate_phate_embeddings.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +from argparse import ArgumentParser +from pathlib import Path + +import numpy as np +import pytorch_lightning as pl +import torch +import wandb +from e3nn import o3 +from phate import PHATE +from rdkit import Chem +from torch_geometric.data import Data +from tqdm import tqdm + +from equitriton.model.lightning import EquiTritonLitModule, LightningQM9 +from equitriton.utils import separate_embedding_irreps + + +def graph_to_rdkit(batched_graph: Data) -> list[Chem.Mol]: + """ + Simple list comprehension to unpack a batch into SMILES, then ``Mol`` + objects. + + Parameters + ---------- + batched_graph : Data + Batched graph structure from QM9 that contains multiple + individual molecules. Individual SMILES are extracted from + the ``smiles`` attribute. 
+ """ + mols = [Chem.MolFromSmiles(smi, sanitize=False) for smi in batched_graph.smiles] + return mols + + +def score_molecule(molecule: Chem.Mol) -> dict[str, int]: + """ + Given an RDKit molecule, compute the NSPS related metrics. + + Parameters + ---------- + molecule : Chem.Mol + Molecule representation in RDKit. + + Returns + ------- + dict[str, int] + Dictionary mapping of property and value. + """ + enum = {"SP": 1, "SP2": 2, "SP3": 3} + scores = {"stereo": 0, "hybrid": 0, "aromatic": 0, "heavy_atoms": 0} + for atom in tqdm( + molecule.GetAtoms(), desc="Atoms in a molecule", leave=False, position=3 + ): + hybrid = enum.get(str(atom.GetHybridization()), 0) + # loop over bonds on the atom to check if it has stereoisomers + has_stereo = any( + [ + True if b.GetStereo() == Chem.BondStereo.STEREOE else False + for b in atom.GetBonds() + ] + ) + s = 2 if has_stereo else 1 + r = int(atom.GetIsAromatic()) + heavy_atoms = sum( + [neighbor.GetAtomicNum() > 1 for neighbor in atom.GetNeighbors()] + ) + scores["stereo"] += s + scores["hybrid"] += hybrid + scores["aromatic"] += r + scores["heavy_atoms"] += heavy_atoms + return scores + + +def calculate_scores_for_batch(molecules) -> list[dict[str, int]]: + """ + Calculates scores for every graph in a batch. + """ + scores = [ + score_molecule(mol) + for mol in tqdm(molecules, desc="Scoring molecules", leave=False, position=2) + ] + return scores + + +def run_phate_projection( + results: list[dict], irreps: o3.Irreps, **phate_kwargs +) -> dict[str, np.ndarray]: + """ + Wrapper function that applies the PHATE method to individual + irreducible representations, given a set of ``o3.Irreps``. + + We apply the same PHATE hyperparameters to each decomposed + irreducible representation, followed by the same method + applied to the joint embeddings (i.e. a vector comprising + all representations). 
+ + Parameters + ---------- + results : list[dict] + List of dictionaries containing the results from inference, + which contains graph embeddings under the ``embeddings`` key. + irreps : o3.Irreps + Object corresponding to the group of irreducible representations + contained in the set of embeddings. This is passed into the + ``separate_embedding_irreps`` function that splits them into + a order/vector mapping. + + Returns + ------- + dict[str, np.ndarray] + Dictionary mapping of keys (irrep. order) and PHATE projections. + Each order key is represented by the number, with the exception + of ``full`` which contains PHATE projections for the joint graph + embeddings. + """ + phate_kwargs.setdefault("knn", 10) + phate_kwargs.setdefault("random_state", 21516) + embeddings = torch.vstack([r["embeddings"][1] for r in results]).numpy() + # separate embeddings into individual chunks + chunk_dict = separate_embedding_irreps(embeddings, irreps, return_numpy=True) + embeddings_dict = {} + for order, chunk in chunk_dict.items(): + print(f"Running PHATE on order {order}") + # collect up all the embeddings + phate = PHATE(**phate_kwargs) + phate_embeddings = phate.fit_transform(chunk) + embeddings_dict[order] = phate_embeddings + # run once more on the full embedding set + phate = PHATE(**phate_kwargs) + phate_embeddings = phate.fit_transform(embeddings) + embeddings_dict["full"] = phate_embeddings + return embeddings_dict + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "artifact_path", type=str, help="wandb path to a model artifact." 
+ ) + pl.seed_everything(215162) + + args = parser.parse_args() + + inference_run = wandb.init( + job_type="eval", + entity="laserkelvin", + project="equitriton-qm9", + tags=["inference", "embeddings", "qm9"], + ) + + artifact = inference_run.use_artifact(args.artifact_path, type="model") + artifact_dir = artifact.download() + ckpt_path = Path(artifact_dir).joinpath("model.ckpt") + + datamodule = LightningQM9("./qm9_data", num_workers=0) + model = EquiTritonLitModule.load_from_checkpoint(str(ckpt_path)).eval() + + datamodule.setup("test") + test_loader = datamodule.test_dataloader() + + results = [] + all_smi = [] + all_error = [] + score_dict = {} + for _, batch in tqdm( + enumerate(test_loader), + desc="Batches to process", + leave=False, + position=1, + total=len(test_loader), + ): + embeddings = model.model.embed(batch.to("cuda")) + with torch.no_grad(): + g_z, z = model.model(batch) + pred_energies = model.output_head(g_z) + # un-normalize energy + pred_energies = (model.hparams["e_std"] * pred_energies) + model.hparams[ + "e_mean" + ] + # retrieve targets + target_energies = batch.y[:, 12].unsqueeze(-1) + error = (pred_energies - target_energies).pow(2.0).cpu().tolist() + mols = graph_to_rdkit(batch) + scores = calculate_scores_for_batch(mols) + package = { + "embeddings": embeddings["graph_z"], + "scores": scores, + "smi": batch.smiles, + "error": error, + } + all_smi.extend(batch.smiles) + all_error.extend(error) + # reformat scores into a flat dictionary + for score in scores: + for key, value in score.items(): + if key not in score_dict: + score_dict[key] = [] + score_dict[key].append(value) + results.append(package) + print("Running PHATE on each Irreps") + phate_embeddings = run_phate_projection( + results, model.model.initial_layer.output_irreps + ) + to_save = {"phate": phate_embeddings, "data": results} + # save a local version of the results + torch.save(to_save, Path(artifact_dir).joinpath("results.pt")) + # formatting stuff to log to wandb + 
embedding_table = wandb.Table( + columns=["F1", "F2"], data=phate_embeddings["full"].tolist() + ) + for key, array in phate_embeddings.items(): + if key != "full": + for axis in [0, 1]: + embedding_table.add_column(name=f"O{key}_{axis}", data=array[:, axis]) + # this initializes the table + joint_table = wandb.Table(columns=["smiles"]) + # I'm not sure why, but the table kept fussing about not being + # able to add the list of smiles directly, which is why it's written as a loop + for smi in all_smi: + joint_table.add_data(smi) + joint_table.add_column(name="squared_error_eV", data=all_error) + # now add the descriptors as well + for key, value in score_dict.items(): + joint_table.add_column(name=key, data=value) + # package stuff up and log to wandb + inference_artifact = wandb.Artifact("qm9_inference", type="eval") + inference_artifact.add(embedding_table, "phate") + inference_artifact.add(joint_table, "descriptors") + inference_run.log_artifact(inference_artifact) + wandb.finish() + + +if __name__ == "__main__": + main() diff --git a/scripts/model_configs/10-lonely.yaml b/scripts/model_configs/10-lonely.yaml new file mode 100644 index 0000000..d29e04e --- /dev/null +++ b/scripts/model_configs/10-lonely.yaml @@ -0,0 +1,46 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 10] # just including 10 + edge_dim: 20 + hidden_dim: 16 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree 
sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/4-lonely.yaml b/scripts/model_configs/4-lonely.yaml new file mode 100644 index 0000000..8c705b8 --- /dev/null +++ b/scripts/model_configs/4-lonely.yaml @@ -0,0 +1,46 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 4] # is a higher order odd/even pair + edge_dim: 20 + hidden_dim: 32 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/6-lonely.yaml b/scripts/model_configs/6-lonely.yaml new file mode 100644 index 0000000..8eecae8 --- /dev/null +++ b/scripts/model_configs/6-lonely.yaml @@ -0,0 +1,46 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + #callbacks: + # - class_path: 
pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 6] # is a higher order odd/even pair + edge_dim: 20 + hidden_dim: 32 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/8-lonely.yaml b/scripts/model_configs/8-lonely.yaml new file mode 100644 index 0000000..00338c3 --- /dev/null +++ b/scripts/model_configs/8-lonely.yaml @@ -0,0 +1,46 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 8] # just including 8 + edge_dim: 20 + hidden_dim: 32 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/baseline.yaml.ignore b/scripts/model_configs/baseline.yaml.ignore new file mode 100644 index 
0000000..851dc8c --- /dev/null +++ b/scripts/model_configs/baseline.yaml.ignore @@ -0,0 +1,47 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + tags: ["baseline"] + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0,] + edge_dim: 20 + hidden_dim: 32 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/baseline_big.yaml b/scripts/model_configs/baseline_big.yaml new file mode 100644 index 0000000..2dad118 --- /dev/null +++ b/scripts/model_configs/baseline_big.yaml @@ -0,0 +1,47 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + tags: ["baseline"] + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0,] + edge_dim: 20 + hidden_dim: 128 + radius_cutoff: 6.0 + 
degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/e3nn.yaml b/scripts/model_configs/e3nn.yaml new file mode 100644 index 0000000..6784b53 --- /dev/null +++ b/scripts/model_configs/e3nn.yaml @@ -0,0 +1,47 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + tags: ["baseline", "e3nn"] + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0,1,2] + edge_dim: 20 + hidden_dim: 32 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: true + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/equivariant.yaml b/scripts/model_configs/equivariant.yaml new file mode 100644 index 0000000..513f795 --- /dev/null +++ b/scripts/model_configs/equivariant.yaml @@ -0,0 +1,46 @@ +#pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + #callbacks: + # - class_path: 
pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2] # this is the canonical set + edge_dim: 20 + hidden_dim: 32 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/even.yaml b/scripts/model_configs/even.yaml new file mode 100644 index 0000000..2d155b1 --- /dev/null +++ b/scripts/model_configs/even.yaml @@ -0,0 +1,46 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 4, 6, 8, 10] # this is just even parity + edge_dim: 20 + hidden_dim: 32 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 16 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/full.yaml b/scripts/model_configs/full.yaml new file mode 100644 index 0000000..3ccef52 --- /dev/null +++ 
b/scripts/model_configs/full.yaml @@ -0,0 +1,46 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # this up to l=10 + edge_dim: 20 + hidden_dim: 16 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 16 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/micro-full.yaml b/scripts/model_configs/micro-full.yaml new file mode 100644 index 0000000..c837821 --- /dev/null +++ b/scripts/model_configs/micro-full.yaml @@ -0,0 +1,45 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: "all" + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 30 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 3, 4, 5, 6, 7, 8] # not quite up to ten + edge_dim: 20 + hidden_dim: 16 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # 
sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/mini-full.yaml b/scripts/model_configs/mini-full.yaml new file mode 100644 index 0000000..95d0665 --- /dev/null +++ b/scripts/model_configs/mini-full.yaml @@ -0,0 +1,45 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: "all" + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 30 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 3, 4, 5, 6] # not quite up to ten + edge_dim: 20 + hidden_dim: 16 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/mini-long.yaml b/scripts/model_configs/mini-long.yaml new file mode 100644 index 0000000..fe944b0 --- /dev/null +++ b/scripts/model_configs/mini-long.yaml @@ -0,0 +1,39 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: "all" + max_epochs: 100 +model: + model_class: 
equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 3, 4, 5, 6] # not quite up to ten + edge_dim: 20 + hidden_dim: 16 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/nano-full.yaml b/scripts/model_configs/nano-full.yaml new file mode 100644 index 0000000..c1cd5f9 --- /dev/null +++ b/scripts/model_configs/nano-full.yaml @@ -0,0 +1,45 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: "all" + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 30 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 3, 4] # not quite up to ten + edge_dim: 20 + hidden_dim: 16 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/nano-long.yaml b/scripts/model_configs/nano-long.yaml new file mode 100644 index 0000000..de855b5 --- /dev/null +++ b/scripts/model_configs/nano-long.yaml @@ -0,0 +1,45 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + 
num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: "all" + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 3, 4] # not quite up to ten + edge_dim: 20 + hidden_dim: 16 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/skipped.yaml b/scripts/model_configs/skipped.yaml new file mode 100644 index 0000000..13efb81 --- /dev/null +++ b/scripts/model_configs/skipped.yaml @@ -0,0 +1,46 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: true + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 + min_epochs: 15 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 1, 2, 5, 6] # is a higher order odd/even pair + edge_dim: 20 + hidden_dim: 32 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + 
batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/unconventional.yaml b/scripts/model_configs/unconventional.yaml new file mode 100644 index 0000000..bfea9da --- /dev/null +++ b/scripts/model_configs/unconventional.yaml @@ -0,0 +1,45 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: "all" + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 30 +model: + model_class: equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 3, 4, 5, 6] # not quite up to ten + edge_dim: 20 + hidden_dim: 16 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/model_configs/unconventional_long.yaml b/scripts/model_configs/unconventional_long.yaml new file mode 100644 index 0000000..28ce5a1 --- /dev/null +++ b/scripts/model_configs/unconventional_long.yaml @@ -0,0 +1,45 @@ +# pytorch_lightning==2.2.4 +seed_everything: 21616 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: null + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: "equitriton-qm9" + entity: "laserkelvin" + log_model: "all" + #callbacks: + # - class_path: pytorch_lightning.callbacks.EarlyStopping + # init_args: + # monitor: "val_loss_epoch" + # patience: 5 + # mode: "min" + max_epochs: 100 +model: + model_class: 
equitriton.model.EquiTritonModel + model_kwargs: + initial_atom_dim: 64 + num_layers: 3 + output_dim: 1 + l_values: [0, 3, 4, 5, 6] # not quite up to ten + edge_dim: 20 + hidden_dim: 16 + radius_cutoff: 6.0 + degree_norm: 6.08275253 # sqrt(37), avg degree + sph_harm_kwargs: + use_e3nn: false + e_mean: -76.1160 + e_std: 10.3238 + lr: 0.001 + weight_decay: 0.0 + atom_weighted_loss: false +data: + root_path: ./qm9_data + batch_size: 32 + train_frac: 0.8 + val_frac: 0.1 + num_workers: 4 diff --git a/scripts/train_model_qm9.py b/scripts/train_model_qm9.py new file mode 100644 index 0000000..1c5ef1d --- /dev/null +++ b/scripts/train_model_qm9.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from pytorch_lightning.cli import LightningCLI +import torch + +from equitriton.model.lightning import EquiTritonLitModule, LightningQM9 + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + # use LightningCLI for easy configuration + cli = LightningCLI( + EquiTritonLitModule, LightningQM9, save_config_kwargs={"overwrite": True} + ) diff --git a/src/equitriton/__init__.py b/src/equitriton/__init__.py index de4cdee..4e69aba 100644 --- a/src/equitriton/__init__.py +++ b/src/equitriton/__init__.py @@ -26,4 +26,4 @@ if _will_patch: from equitriton import patch # noqa: F401 -__version__ = "0.1.0" +__version__ = "0.2.0" diff --git a/src/equitriton/model/__init__.py b/src/equitriton/model/__init__.py new file mode 100644 index 0000000..e849bfe --- /dev/null +++ b/src/equitriton/model/__init__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from equitriton.model.blocks import EquiTritonModel + + +__all__ = ["EquiTritonModel"] diff --git a/src/equitriton/model/blocks.py b/src/equitriton/model/blocks.py new file mode 100644 index 0000000..26e43ed --- /dev/null +++ b/src/equitriton/model/blocks.py @@ -0,0 +1,482 @@ +from __future__ import annotations + +from typing import Literal, Any, Callable +from collections import Counter + +import torch 
+from torch import nn +import e3nn +from e3nn import o3 +from e3nn import nn as e3nn_nn +from torch_scatter import scatter +from matplotlib import pyplot as plt +from torch_geometric.data import Data as PyGGraph + +from equitriton.utils import spherical_harmonics_irreps +from equitriton.sph_harm.direct import TritonSphericalHarmonic + + +__all__ = [ + "AtomEmbedding", + "EdgeEmbedding", + "SphericalHarmonicEmbedding", + "InteractionBlock", + "EquiTritonModel", + "ScalarReadoutLayer", +] + + +class AtomEmbedding(nn.Module): + def __init__(self, num_atoms: int, atom_dim: int): + super().__init__() + self.embedding = nn.Embedding(num_atoms, atom_dim, padding_idx=0) + + def forward(self, atomic_numbers: torch.LongTensor) -> torch.Tensor: + return self.embedding(atomic_numbers) + + +class EdgeEmbedding(nn.Module): + def __init__(self, num_basis: int, radius_cutoff: float = 6.0, **kwargs): + """ + This module embeds the interatomic distances along graph edges in a radial basis. + + Parameters + ---------- + num_basis : int + The number of basis functions. + radius_cutoff : float, optional + The maximum radius up to which basis functions are defined. Defaults to 6.0. + + Optional kwargs + --------------- + basis : str, optional + The type of basis function to use. Defaults to 'bessel'. + start : float, optional + The starting point in the distance grid used in the radial basis. Defaults to 0.0. + cutoff : bool, optional + Whether or not to apply a cutoff to the basis functions. Defaults to True. + + Returns + ------- + torch.Tensor + A tensor representing the embedding of edges with shape (num_edges, num_basis). + + Examples + -------- + >>> # Define an instance of EdgeEmbedding with 4 basis functions and a radius cutoff of 10. 
+ >>> embedder = EdgeEmbedding(num_basis=4, radius_cutoff=10.0) + """ + super().__init__() + kwargs.setdefault("basis", "bessel") + kwargs.setdefault("start", 0.0) + kwargs.setdefault("cutoff", True) + self.num_basis = num_basis + self.radius_cutoff = radius_cutoff + self.basis_kwargs = kwargs + + def forward(self, distances: torch.Tensor) -> torch.Tensor: + basis_funcs = e3nn.math.soft_one_hot_linspace( + distances, + number=self.num_basis, + end=self.radius_cutoff, + **self.basis_kwargs, + ) + return basis_funcs * self.num_basis**0.5 + + +class SphericalHarmonicEmbedding(nn.Module): + def __init__( + self, + l_values: list[int], + normalize: bool = True, + normalization: Literal["norm", "integral", "component"] = "integral", + use_e3nn: bool = False, + ): + """ + Projects cartesian coordinates onto a specified spherical harmonic basis. + + Arguments mainly implement an equivalent interface to ``e3nn``, + up to just directly using the ``e3nn`` spherical harmonics + implementation. + + Parameters + ---------- + l_values : list[int] + List of l values of spherical harmonics to use as a basis. + normalize : bool, optional + Whether to normalize coordinates before passing into the + embedding step. + normalization : Literal["norm", "integral", "component"], optional + Normalization scheme to use for the embeddings. By default + uses ``integral``, which is the only method implemented for + the Triton kernels. + use_e3nn : bool, optional + Whether to directly use ``e3nn`` spherical harmonics, + by default False. 
+ """ + super().__init__() + self.l_values = list(sorted(l_values)) + self.irreps = spherical_harmonics_irreps(self.l_values, num_feat=1) + self.normalize = normalize + self.normalization = normalization + self.use_e3nn = use_e3nn + + def forward(self, coords: torch.Tensor) -> torch.Tensor: + if not self.use_e3nn: + if self.normalize: + coords = torch.nn.functional.normalize(coords, dim=-1) + outputs = TritonSphericalHarmonic.apply(self.l_values, coords) + if self.normalization == "integral": + outputs /= (4.0 * torch.pi) ** 0.5 + return outputs + else: + return o3.spherical_harmonics( + self.irreps, coords, self.normalize, self.normalization + ) + + +class InteractionBlock(nn.Module): + def __init__( + self, + atomic_dim: int | o3.Irreps, + l_values: list[int], + edge_dim: int, + hidden_dim: int, + radius_cutoff: float, + degree_norm: float, + edge_kwargs: dict[str, Any] = {}, + sph_harm_kwargs: dict[str, Any] = {}, + activation: Callable = nn.functional.silu, + ): + """ + A module that combines radial basis with spherical harmonics to + describe molecular interactions. + + Parameters + ---------- + atomic_dim : int | o3.Irreps + Dimension of the atomic features. If int, it is treated as a + single irreducible representation. + l_values : list[int] + Values of the spherical harmonic order. If the Triton harmonics + are being used, this does not need to be contiguous. + edge_dim : int + Dimension of the edge features. + hidden_dim : int + Hidden dimension for the fully connected network. + radius_cutoff : float + Cutoff radius for the radial basis. + degree_norm : float + Normalization factor for the degree of the graph. + edge_kwargs : dict[str, Any], optional + Keyword arguments for the EdgeEmbedding module. Defaults to {}. + sph_harm_kwargs : dict[str, Any], optional + Keyword arguments for the SphericalHarmonicEmbedding module. + Defaults to {}. + activation : Callable, optional + Activation function for the fully connected network. 
Defaults to + nn.functional.silu. + + Notes + ----- + The `degree_norm` value is stored as an attribute and effectively + represents the average number of neighbors in other models. + + Examples + -------- + >>> block = InteractionBlock(atomic_dim=8, l_values=[0, 1], + ... edge_dim=16, hidden_dim=32, + ... radius_cutoff=6.0, degree_norm=1.0) + >>> block.num_projections + 4 + """ + sph_harm_kwargs.setdefault("use_e3nn", False) + + super().__init__() + # this is effectively the average number of neighbors in other models + self.degree_norm = degree_norm + # treat atom features as invariant + if isinstance(atomic_dim, int): + atomic_irreps = f"{atomic_dim}x0e" + else: + atomic_irreps = atomic_dim + self.atomic_irreps = atomic_irreps + self.l_values = list(sorted(l_values)) + # these two attributes are similar but different: the former is used for describing + # the basis itself, and the latter is for actually specifying the weights + self.sph_irreps = spherical_harmonics_irreps(self.l_values, num_feat=1) + self.output_irreps = spherical_harmonics_irreps( + self.l_values, num_feat=hidden_dim + ) + # tensor product is the final bit that combines the radial basis with the spherical + # harmonics + self.tensor_product = o3.FullyConnectedTensorProduct( + self.atomic_irreps, + self.sph_irreps, + self.output_irreps, + shared_weights=False, + ) + self.edge_basis = EdgeEmbedding(edge_dim, radius_cutoff, **edge_kwargs) + self.spherical_harmonics = SphericalHarmonicEmbedding( + l_values, **sph_harm_kwargs + ) + self.fc = e3nn_nn.FullyConnectedNet( + [edge_dim, hidden_dim, self.tensor_product.weight_numel], activation + ) + + @property + def num_projections(self) -> int: + """Returns the expected number of projections.""" + return sum([2 * l + 1 for l in self.l_values]) + + @property + def output_dim(self) -> int: + """Returns the dimensionality of the output.""" + return self.output_irreps.dim + + def forward( + self, + atomic_features: torch.Tensor, + coords: torch.Tensor, + edge_index: torch.LongTensor, + ) -> 
torch.Tensor: + """ + High-level description: + + 1. Project cartesian coordinates onto spherical harmonic basis + 2. Project interatomic distances onto radial (bessel) basis + 3. Transform radial basis functions with learnable weights + 4. Compute tensor product between scalar atom features and spherical harmonic basis + 5. Update node features + """ + edge_dist = coords[edge_index[0]] - coords[edge_index[1]] + sph_harm = self.spherical_harmonics(edge_dist) + # calculate atomic distances, embed, and transform them + edge_basis = self.edge_basis(edge_dist.norm(dim=-1)) + edge_z = self.fc(edge_basis) + # compute tensor product + messages = self.tensor_product(atomic_features[edge_index[0]], sph_harm, edge_z) + # update node features + hidden_feats = ( + scatter(messages, edge_index[1], dim=0, dim_size=atomic_features.size(0)) + / self.degree_norm + ) + return hidden_feats + + +class ScalarReadoutLayer(nn.Module): + def __init__(self, hidden_irreps: o3.Irreps, output_dim: int): + super().__init__() + self.hidden_irreps = hidden_irreps + self.output_irreps = o3.Irreps(f"{output_dim}x0e") + self.output_layer = o3.Linear( + irreps_in=hidden_irreps, irreps_out=self.output_irreps + ) + + def forward(self, node_feats: torch.Tensor) -> torch.Tensor: + return self.output_layer(node_feats) + + +class EquiTritonModel(nn.Module): + def __init__( + self, + initial_atom_dim: int, + num_layers: int, + output_dim: int, + l_values: list[int], + edge_dim: int, + hidden_dim: int, + radius_cutoff: float, + degree_norm: float, + edge_kwargs: dict[str, Any] = {}, + sph_harm_kwargs: dict[str, Any] = {}, + activation: Callable = nn.functional.silu, + num_atoms: int = 100, + skip_connections: bool = True, + ): + """ + A neural network model designed for processing molecular graphs. + + This class implements a hierarchical architecture with multiple interaction blocks, + allowing for efficient and scalable processing of large molecular datasets. 
+ + Parameters + ---------- + initial_atom_dim : int + The dimensionality of the atomic embeddings. + num_layers : int + The number of convolutional layers to use. + output_dim : int + The dimensionality of the final scalar features. + l_values : list[int] + A list of spherical harmonic orders to consider. If using the Triton kernels, + the list does not need to be contiguous. + edge_dim : int + The dimensionality of the edge features. + hidden_dim : int + The dimensionality of the hidden state in each interaction block. + radius_cutoff : float + The cutoff distance for radial basis functions. + degree_norm : float + The normalization constant applied to aggregated messages. Typically the square root of the average degree. + edge_kwargs : dict[str, Any], optional + Keyword arguments passed through each InteractionBlock to its EdgeEmbedding module. + sph_harm_kwargs : dict[str, Any], optional + Keyword arguments passed through each InteractionBlock to its SphericalHarmonicEmbedding module. By default, + the ``use_e3nn`` kwarg is set to False, which uses the Triton kernels instead. + activation : Callable, optional + The activation function to use in each interaction block. Defaults to nn.functional.silu. + num_atoms : int, optional + The number of atoms in the embedding table (i.e. unique elements). Defaults to 100. + skip_connections : bool, optional + Whether to enable residual connections between layers. Defaults to True. + + Examples + -------- + >>> model = EquiTritonModel(...) 
+ >>> graph = PyGGraph(...).to(device="cuda") + >>> graph_z, z = model(graph) + """ + sph_harm_kwargs.setdefault("use_e3nn", False) + + super().__init__() + self.atomic_embedding = AtomEmbedding(num_atoms, initial_atom_dim) + self.initial_layer = InteractionBlock( + initial_atom_dim, + l_values, + edge_dim, + hidden_dim, + radius_cutoff, + degree_norm, + edge_kwargs, + sph_harm_kwargs, + activation, + ) + self.conv_layers = nn.ModuleDict() + for layer_index in range(num_layers + 1): + self.conv_layers[f"conv_{layer_index}"] = InteractionBlock( + self.initial_layer.output_irreps, # subsequent layers use irreps instead + l_values, + edge_dim, + hidden_dim, + radius_cutoff, + degree_norm, + edge_kwargs, + sph_harm_kwargs, + activation, + ) + self.scalar_readout = ScalarReadoutLayer( + self.initial_layer.output_irreps, output_dim + ) + self.skip_connections = skip_connections + self.output_dim = output_dim + + @property + def output_irrep_shapes(self) -> dict[str, int]: + # this returns a dictionary for each l-order, the number + # of expected elements for that particular order in the output + return dict(Counter(self.initial_layer.output_irreps.ls)) + + def visualize(self, **kwargs): + """ + Visualize the sequence of tensor products within the model. + + Essentially, all this does is wrap around the ``tensor_product.visualize()`` + functionality, but also tacks on titles for each axis. 
+ """ + num_plots = len(self.conv_layers) + 1 + fig, axarray = plt.subplots(num_plots, 1, figsize=(3, 12)) + # make indexing easier + axarray = axarray.flatten() + + self.initial_layer.tensor_product.visualize(ax=axarray[0], **kwargs) + axarray[0].set_title("Input layer", loc="right") + index = 1 + for layer_name, layer in self.conv_layers.items(): + ax = axarray[index] + layer.tensor_product.visualize(ax=ax, **kwargs) + ax.set_title(layer_name, loc="right") + index += 1 + fig.tight_layout() + return fig, axarray + + def embed(self, graph: PyGGraph) -> dict[str, tuple[str, torch.Tensor]]: + """ + Generate embeddings for a given graph, either batched or + unbatched. + + This proceeds more or less the same way as the ``forward`` + pass, but instead emits a dictionary that maps each layer name + with both the embedding at that layer as well as the irreducible + representations. + """ + # determine if the graph is batched or not + is_batched = hasattr(graph, "ptr") + for key in ["pos", "edge_index", "z"]: + assert hasattr(graph, key) + # get atom embeddings + atom_z = self.atomic_embedding(graph.z) # [nodes, initial_atom_dim] + # first message passing step + z = self.initial_layer(atom_z, graph.pos, graph.edge_index) + outputs = { + "initial": ( + str(self.initial_layer.output_irreps), + z.detach().clone().cpu(), + ), + } + for layer_name, layer in self.conv_layers.items(): + new_z = layer(z, graph.pos, graph.edge_index) + # add residual connections + if self.skip_connections and new_z.shape == z.shape: + new_z += z + z = new_z + outputs[layer_name] = (str(layer.output_irreps), z.detach().clone().cpu()) + if is_batched: + graph_z = scatter(z, graph.batch, dim=0, dim_size=graph.batch_size) + else: + # for a single graph, just sum up the node features + graph_z = z.sum(dim=0, keepdims=True) + outputs["graph_z"] = (str(layer.output_irreps), graph_z.detach().clone().cpu()) + return outputs + + def forward(self, graph: PyGGraph) -> tuple[torch.Tensor, torch.Tensor]: + """ + 
Forward pass for a generic equivariant convolution model. + + Parameters + ---------- + graph : PyGGraph + PyG graph structure, which may be batched or a single graph. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + 2-tuple of outputs; first element are graph level + outputs (from summing over nodes), second element + is the node-level outputs. + """ + # determine if the graph is batched or not + is_batched = hasattr(graph, "ptr") + for key in ["pos", "edge_index", "z"]: + assert hasattr(graph, key) + # get atom embeddings + atom_z = self.atomic_embedding(graph.z) # [nodes, initial_atom_dim] + # first message passing step + z = self.initial_layer(atom_z, graph.pos, graph.edge_index) + outputs = {} + for layer_name, layer in self.conv_layers.items(): + new_z = layer(z, graph.pos, graph.edge_index) + # add residual connections + if self.skip_connections and new_z.shape == z.shape: + new_z += z + z = new_z + outputs[layer_name] = z + # map final output as scalars + z = self.scalar_readout(z) + # latest node features are in z; we generate graph-level scalar features + # by doing a scatter add + if is_batched: + graph_z = scatter(z, graph.batch, dim=0, dim_size=graph.batch_size) + else: + # for a single graph, just sum up the node features + graph_z = z.sum(dim=0, keepdims=True) + return graph_z, z diff --git a/src/equitriton/model/lightning.py b/src/equitriton/model/lightning.py new file mode 100644 index 0000000..81cd5c0 --- /dev/null +++ b/src/equitriton/model/lightning.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +from math import ceil +from typing import Literal + +import pytorch_lightning as pl +import torch +from torch.optim.adamw import AdamW +from torch import nn +from torch.utils.data import random_split +from torch_geometric.datasets import QM9 +from torch_geometric.data import Data as PyGGraph +from torch_geometric.loader import DataLoader + + +class LightningQM9(pl.LightningDataModule): + def __init__( + self, + root_path: str = 
"./qm9_data", + batch_size: int = 16, + train_frac: float = 0.8, + val_frac: float = 0.1, + num_workers: int = 0, + ): + """ + Custom data module for QM9 dataset. + + Parameters + ---------- + root_path : str, optional (default: "./qm9_data") + Path to the QM9 dataset. + batch_size : int, optional (default: 16) + Number of samples in each mini-batch. + train_frac : float, optional (default: 0.8) + Fraction of data used for training. + val_frac : float, optional (default: 0.1) + Fraction of data used for validation. + num_workers : int, optional (default: 0) + Number of worker processes to use for loading data. + + Examples + -------- + >>> dm = LightningQM9(root_path="/path/to/qm9_data", batch_size=32) + + Attributes + ---------- + dataset : QM9 + Loaded QM9 dataset. + hparams : dict + Hyperparameters of the data module. + + Methods + ------- + setup(stage: str) + Setup data splits for training, validation and testing. + train_dataloader() + Returns a DataLoader instance for training data. + val_dataloader() + Returns a DataLoader instance for validation data. + test_dataloader() + Returns a DataLoader instance for testing data. 
+ """ + super().__init__() + self.dataset = QM9(root_path) + self.save_hyperparameters() + + def setup(self, stage: str): + hparams = self.hparams + num_samples = len(self.dataset) + num_train = int(num_samples * hparams["train_frac"]) + num_val = int(num_samples * hparams["val_frac"]) + num_test = ceil( + num_samples * (1 - (hparams["train_frac"] + hparams["val_frac"])) + ) + # generate random splits + train_split, val_split, test_split = random_split( + self.dataset, lengths=[num_train, num_val, num_test] + ) + self.splits = {"train": train_split, "val": val_split, "test": test_split} + + def train_dataloader(self): + num_workers = self.hparams["num_workers"] + return DataLoader( + self.splits["train"], + batch_size=self.hparams["batch_size"], + shuffle=True, + num_workers=num_workers, + persistent_workers=True if num_workers > 0 else False, + ) + + def val_dataloader(self): + num_workers = self.hparams["num_workers"] + return DataLoader( + self.splits["val"], + batch_size=self.hparams["batch_size"], + shuffle=False, + num_workers=num_workers, + persistent_workers=True if num_workers > 0 else False, + ) + + def test_dataloader(self): + num_workers = self.hparams["num_workers"] + return DataLoader( + self.splits["test"], + batch_size=self.hparams["batch_size"], + shuffle=False, + num_workers=num_workers, + persistent_workers=True if num_workers > 0 else False, + ) + + +class AtomWeightedMSE(nn.Module): + """ + Calculates the mean-squared-error between predicted and targets, + weighted by the number of atoms within each graph. + + From matsciml + """ + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + atoms_per_graph: torch.Tensor, + ) -> torch.Tensor: + if atoms_per_graph.size(0) != target.size(0): + raise RuntimeError( + "Dimensions for atom-weighted loss do not match:" + f" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}." + "This loss is intended to be applied to scalar targets only." 
+ ) + # check to make sure we are broadcasting correctly + if (input.ndim != target.ndim) and target.size(-1) == 1: + input.unsqueeze_(-1) + # for N-d targets, we might want to keep unsqueezing + while atoms_per_graph.ndim < target.ndim: + atoms_per_graph.unsqueeze_(-1) + # ensures that atoms_per_graph is type cast correctly + squared_error = ((input - target) / atoms_per_graph.to(input.dtype)) ** 2.0 + return squared_error.mean() + + +class EquiTritonLitModule(pl.LightningModule): + def __init__( + self, + model_class: type, + model_kwargs, + e_mean: float, + e_std: float, + lr: float = 1e-3, + weight_decay: float = 0.0, + atom_weighted_loss: bool = True, + ): + """ + Initializes the EquiTritonLitModule class. + + Parameters + ---------- + model_class : type + The class of the model to be used. + model_kwargs : dict + Keyword arguments for the model initialization. + e_mean : float + The mean of the energy values. + e_std : float + The standard deviation of the energy values. + lr : float, optional + The learning rate (default is 1e-3) for AdamW. + weight_decay : float, optional + Weight decay value (default is 0.0). + atom_weighted_loss : bool, optional + Whether to use atom-weighted loss or not (default is True). + """ + super().__init__() + self.model = model_class(**model_kwargs) + if atom_weighted_loss: + self.loss = AtomWeightedMSE() + else: + self.loss = nn.MSELoss() + self.output_head = nn.Linear(self.model.output_dim, 1) + self.save_hyperparameters() + + def configure_optimizers(self): + return AdamW( + self.parameters(), + lr=self.hparams["lr"], + weight_decay=self.hparams["weight_decay"], + ) + + def step(self, graph: PyGGraph, stage: Literal["train", "test", "val"]): + """ + Performs a single step of the training, validation or testing + process. + + Parameters + ---------- + graph : PyGGraph + The input graph. + stage : Literal["train", "test", "val"] + The current stage (training, testing or validation). 
+ + Returns + ------- + loss : float + The calculated loss value. + """ + g_z, z = self.model(graph) + pred_energy = self.output_head(g_z) + target_energy = graph.y[:, 12].unsqueeze(-1) + norm_energy = (target_energy - self.hparams["e_mean"]) / self.hparams["e_std"] + if self.hparams["atom_weighted_loss"]: + loss = self.loss(pred_energy, norm_energy, torch.diff(graph.ptr)) + else: + loss = self.loss(pred_energy, norm_energy) + batch_size = getattr(graph, "batch_size", 1) + self.log( + f"{stage}_loss", loss, prog_bar=True, on_step=True, batch_size=batch_size + ) + return loss + + def training_step(self, batch): + loss = self.step(batch, "train") + return loss + + def validation_step(self, batch): + loss = self.step(batch, "val") + return loss + + def test_step(self, batch): + loss = self.step(batch, "test") + return loss diff --git a/src/equitriton/sph_harm/direct/README.md b/src/equitriton/sph_harm/direct/README.md new file mode 100644 index 0000000..e5e1a27 --- /dev/null +++ b/src/equitriton/sph_harm/direct/README.md @@ -0,0 +1,15 @@ +# Direct spherical harmonics + +This module implements spherical harmonics of up to $l=10$ _directly_ in terms +of $x,y,z$. Each submodule implements a particular order, comprising four objects: +a PyTorch `autograd.Function` wrapper, forward and backward Triton kernels, +and a PyTorch implementation of the forward kernel. The PyTorch implementation +is not intended for performance, but rather for double-checking that +the Triton versions behave as intended. + +Currently, the kernels are heavily computer-assisted, and may not be optimal, +particularly on the register front: there are a lot of redundant constants, +and we are relying heavily on the LLVM compiler to realize this and group +them at run time. Similarly, the variable names are not very human-friendly. +This is unlikely to change: the kernels might carry a high maintenance burden, +but we're unlikely to touch them very much. 
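As a rough illustration of the kind of double-checking described above, the sketch below re-implements the fused $l=0,1,2$ terms in plain Python and compares a hand-derived gradient against central finite differences. The constants and formulas are copied from the fused second-order kernels in this change; the function names (`sph_harm_upto_2`, `analytic_grad`, `finite_difference_grad`) and the constant names are hypothetical and exist only for this example, which is not part of the package.

```python
import math

# Constants as they appear in the fused second-order kernels;
# the names here are illustrative only.
SQRT3 = math.sqrt(3.0)
C_CROSS = 3.87298334620742  # sqrt(15)
C_YY = 2.23606797749979  # sqrt(5)
C_XX_ZZ = -1.11803398874989  # -sqrt(5) / 2
C_DIFF = 1.93649167310371  # sqrt(15) / 2


def sph_harm_upto_2(x, y, z):
    """Return the nine l=0,1,2 terms in the order the fused kernel writes them."""
    return [
        1.0,
        SQRT3 * x,
        SQRT3 * y,
        SQRT3 * z,
        C_CROSS * x * z,
        C_CROSS * x * y,
        C_XX_ZZ * x * x + C_YY * y * y + C_XX_ZZ * z * z,
        C_CROSS * y * z,
        -C_DIFF * x * x + C_DIFF * z * z,
    ]


def analytic_grad(x, y, z, g):
    """Hand-derived vector-Jacobian product; g holds one upstream gradient per term."""
    g_x = (
        C_CROSS * g[4] * z
        + C_CROSS * g[5] * y
        - C_YY * g[6] * x
        - C_CROSS * g[8] * x
        + SQRT3 * g[1]
    )
    g_y = C_CROSS * g[5] * x + 2.0 * C_YY * g[6] * y + C_CROSS * g[7] * z + SQRT3 * g[2]
    g_z = (
        C_CROSS * g[4] * x
        - C_YY * g[6] * z
        + C_CROSS * g[7] * y
        + C_CROSS * g[8] * z
        + SQRT3 * g[3]
    )
    return [g_x, g_y, g_z]


def finite_difference_grad(x, y, z, g, eps=1e-6):
    """Central-difference estimate of the same vector-Jacobian product."""
    point = [x, y, z]
    grads = []
    for axis in range(3):
        plus, minus = list(point), list(point)
        plus[axis] += eps
        minus[axis] -= eps
        f_plus = sph_harm_upto_2(*plus)
        f_minus = sph_harm_upto_2(*minus)
        grads.append(
            sum(gi * (p - m) / (2.0 * eps) for gi, p, m in zip(g, f_plus, f_minus))
        )
    return grads
```

Calling `analytic_grad(0.3, -0.7, 0.5, [1.0] * 9)` and `finite_difference_grad` with the same arguments should agree to within floating-point noise, since every term is at most quadratic in the coordinates. The package's own tests make the equivalent comparison between the Triton kernels and PyTorch autograd on real devices.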
diff --git a/src/equitriton/sph_harm/direct/__init__.py b/src/equitriton/sph_harm/direct/__init__.py new file mode 100644 index 0000000..88250f5 --- /dev/null +++ b/src/equitriton/sph_harm/direct/__init__.py @@ -0,0 +1,11 @@ +from equitriton.sph_harm.direct.special import FusedSecondOrderSphericalHarmonic +from equitriton.sph_harm.direct.utils import ( + triton_spherical_harmonic, + TritonSphericalHarmonic, +) + +__all__ = [ + "FusedSecondOrderSphericalHarmonic", + "triton_spherical_harmonic", + "TritonSphericalHarmonic", +] diff --git a/src/equitriton/sph_harm/direct/special.py b/src/equitriton/sph_harm/direct/special.py new file mode 100644 index 0000000..f6fe596 --- /dev/null +++ b/src/equitriton/sph_harm/direct/special.py @@ -0,0 +1,267 @@ +import triton +import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["FusedSecondOrderSphericalHarmonic"] + + +class FusedSecondOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + mask: torch.Tensor | None = None, + block_size: int = 64, + ): + output_tensor = torch.empty( + (*coords.shape[:-1], 9), dtype=coords.dtype, device=coords.device + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + joint_second_order_fwd[num_blocks,]( + coords, output_tensor, block_size, coord_numel, output_numel + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, sph_grad_tensor: torch.Tensor, block_size: int = 64 + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + joint_second_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + ) + return 
coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + This function is generically named to make it easier for + it to be called programmatically: it is _not_ intended + to be used manually. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. + + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the second order spherical harmonic. + """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + CONST_00 = 3.87298334620742 + CONST_01 = 2.23606797749979 + CONST_02 = -1.11803398874989 + CONST_03 = 1.93649167310371 + CONST_04 = 3**0.5 + Y00 = torch.ones_like(x) + Y10 = x * CONST_04 + Y11 = y * CONST_04 + Y12 = z * CONST_04 + Y20 = CONST_00 * x * z + Y21 = CONST_00 * x * y + Y23 = CONST_00 * y * z # looks jarring but just helping the compiler ;) + Y22 = CONST_02 * x * x + CONST_01 * y * y + CONST_02 * z * z + Y24 = -CONST_03 * x * x + CONST_03 * z * z + return torch.cat([Y00, Y10, Y11, Y12, Y20, Y21, Y22, Y23, Y24], dim=-1) + + +@triton.jit +def joint_second_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, +): + """ + This Triton implementation includes l=0, 1, 2 within the + same kernel, as it would be a common operation. 
+ """ + # these are hardcoded because they are predetermined + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + CONST_00 = 3.87298334620742 + CONST_01 = 2.23606797749979 + CONST_02 = -1.11803398874989 + CONST_03 = 1.93649167310371 + CONST_04 = tl.sqrt(3.0) + Y10 = CONST_04 * x + Y11 = CONST_04 * y + Y12 = CONST_04 * z + Y20 = CONST_00 * x * z + Y21 = CONST_00 * x * y + Y23 = CONST_00 * y * z # looks jarring but just helping the compiler ;) + Y22 = CONST_02 * x * x + CONST_01 * y * y + CONST_02 * z * z + Y24 = -CONST_03 * x * x + CONST_03 * z * z + output_stride = 9 # sum of [2l + 1] over l=0, 1, 2 + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = output_striding + (block_size * output_stride * block_id) + # first column is all ones, per the zeroth order term + tl.store(output_ptr + output_row_offset, 1.0, mask=output_row_offset < output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y10, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y11, + mask=output_row_offset + 2 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 3, + Y12, + mask=output_row_offset + 3 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 4, + Y20, + mask=output_row_offset + 4 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 5, + Y21, + mask=output_row_offset + 5 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 6, + Y22, + mask=output_row_offset + 6 < 
+ output_numel, + ) + tl.store( + output_ptr + output_row_offset + 7, + Y23, + mask=output_row_offset + 7 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 8, + Y24, + mask=output_row_offset + 8 < output_numel, + ) + + +@triton.jit +def joint_second_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # these are hardcoded because they are predetermined + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + output_stride = 9 # sum of [2l + 1] over l=0, 1, 2 + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = output_striding + (block_size * output_stride * block_id) + CONST_00 = 3.87298334620742 + CONST_01 = 2.23606797749979 + CONST_02 = 4.47213595499958 + CONST_03 = tl.sqrt(3.0) + # load in gradients w.r.t. spherical harmonic projections. 
+ # gradient of l = 0 goes to zero + g_Y10 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_Y11 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + g_Y12 = tl.load( + sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel + ) + g_Y20 = tl.load( + sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel + ) + g_Y21 = tl.load( + sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel + ) + g_Y22 = tl.load( + sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel + ) + g_Y23 = tl.load( + sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel + ) + g_Y24 = tl.load( + sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel + ) + g_x = ( + CONST_00 * g_Y20 * z + + CONST_00 * g_Y21 * y + - CONST_01 * g_Y22 * x + - CONST_00 * g_Y24 * x + + CONST_03 * g_Y10 + ) + g_y = ( + CONST_00 * g_Y21 * x + + CONST_02 * g_Y22 * y + + CONST_00 * g_Y23 * z + + CONST_03 * g_Y11 + ) + g_z = ( + CONST_00 * g_Y20 * x + - CONST_01 * g_Y22 * z + + CONST_00 * g_Y23 * y + + CONST_00 * g_Y24 * z + + CONST_03 * g_Y12 + ) + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + ) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/sph_harm/direct/tests/test_direct_sph_harm.py b/src/equitriton/sph_harm/direct/tests/test_direct_sph_harm.py new file mode 100644 index 0000000..5b39b0a --- /dev/null +++ b/src/equitriton/sph_harm/direct/tests/test_direct_sph_harm.py @@ -0,0 +1,105 @@ +import pytest +import torch + +from equitriton import __HAS_XPU__, __HAS_CUDA__ +from equitriton.sph_harm.direct.utils import ( + 
torch_spherical_harmonic, + triton_spherical_harmonic, +) + +torch.manual_seed(316165) + + +@pytest.mark.parametrize("order", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +@pytest.mark.parametrize( + "device", + [ + pytest.param( + "xpu", + marks=pytest.mark.skipif(not __HAS_XPU__, reason="No XPUs available."), + ), + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not __HAS_CUDA__, reason="No CUDA GPUs available." + ), + ), + ], +) +@pytest.mark.parametrize("tensor_shape", [(512, 3), (128, 16, 3), (256, 8, 8, 3)]) +@pytest.mark.parametrize( + "dtype", + [ + pytest.param( + torch.float16, + marks=pytest.mark.xfail(raises=AssertionError, reason="low precision"), + ), + pytest.param( + torch.bfloat16, + marks=pytest.mark.xfail(raises=AssertionError, reason="low precision"), + ), + torch.float32, + torch.float64, + ], +) +def test_forward_equivalence(order, device, tensor_shape, dtype): + """ + Tests the numerical equivalence of the PyTorch versus + the Triton implementations. This is mostly to ensure that + writing outputs back out is being done correctly. + """ + coords = torch.rand(tensor_shape, device=device, dtype=dtype) + triton_out = triton_spherical_harmonic(order, coords) + torch_out = torch_spherical_harmonic(order, coords) + assert torch.allclose(triton_out, torch_out, atol=1e-5, rtol=1e-3) + + +@pytest.mark.parametrize("order", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +@pytest.mark.parametrize( + "device", + [ + pytest.param( + "xpu", + marks=pytest.mark.skipif(not __HAS_XPU__, reason="No XPUs available."), + ), + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not __HAS_CUDA__, reason="No CUDA GPUs available." 
+ ), + ), + ], +) +@pytest.mark.parametrize("tensor_shape", [(512, 3), (128, 16, 3), (256, 8, 8, 3)]) +@pytest.mark.parametrize( + "dtype", + [ + pytest.param( + torch.float16, + marks=pytest.mark.xfail(raises=AssertionError, reason="low precision"), + ), + pytest.param( + torch.bfloat16, + marks=pytest.mark.xfail(raises=AssertionError, reason="low precision"), + ), + torch.float32, + torch.float64, + ], +) +def test_backward_equivalence(order, device, tensor_shape, dtype): + """ + Tests the numerical equivalence of the PyTorch versus + the Triton implementation of the backward pass. This is mostly to ensure that + writing outputs back out is being done correctly. + """ + coords = torch.rand(tensor_shape, device=device, dtype=dtype, requires_grad=True) + # run with autograd first + torch_out = torch_spherical_harmonic(order, coords) + torch_out.backward(gradient=torch.ones_like(torch_out)) + torch_grad = coords.grad.clone().detach() + coords.grad.zero_() + # now run the triton result + triton_out = triton_spherical_harmonic(order, coords) + triton_out.backward(gradient=torch.ones_like(triton_out)) + triton_grad = coords.grad.clone().detach() + assert torch.allclose(triton_grad, torch_grad, atol=1e-5, rtol=1e-3) diff --git a/src/equitriton/sph_harm/direct/utils.py b/src/equitriton/sph_harm/direct/utils.py new file mode 100644 index 0000000..8c66642 --- /dev/null +++ b/src/equitriton/sph_harm/direct/utils.py @@ -0,0 +1,323 @@ +from __future__ import annotations + +from importlib import import_module +from typing import Callable + +import torch +import numpy as np + +from equitriton.utils import num_irreps_projections, calculate_lastdim_num_blocks + +__all__ = ["torch_spherical_harmonic", "triton_spherical_harmonic"] + +BLOCK_SIZE = 64 + + +def _get_autograd_func(l: int) -> type[torch.autograd.Function]: + """ + Function that will grab the autograd.Function for a specified + l order. + + Parameters + ---------- + l : int + Order of spherical harmonic to compute. 
+ + Returns + ------- + type[torch.autograd.Function] + Class reference to the autograd Function. + + Raises + ------ + ModuleNotFoundError: + If the order of spherical harmonic is not implemented, + the module will not exist. + RuntimeError: + If the autograd.Function can't be found. + """ + try: + target_module = import_module(f"equitriton.sph_harm.direct.y_{l}") + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Spherical harmonic order l={l} requested, but not found!" + ) from e + defined_objs = dir(target_module) + for key in defined_objs: + if "SphericalHarmonic" in key: + sph_harm_func = getattr(target_module, key) + return sph_harm_func + raise RuntimeError(f"Namespace for module l={l} is broken!") + + +def _get_fwd_kernel(l: int) -> Callable: + """ + Reach into the module of a specified l value and grab + the corresponding forward Triton kernel function. + + Parameters + ---------- + l : int + Spherical harmonic l value to search for. + + Returns + ------- + Callable + Triton forward kernel + + Raises + ------ + ModuleNotFoundError: + If the l value is not implemented, the module will + not exist and raises a ``ModuleNotFoundError``. + RuntimeError: + If the module exists but we aren't able to find + a forward kernel defined, it's broken. + """ + try: + target_module = import_module(f"equitriton.sph_harm.direct.y_{l}") + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Spherical harmonic order l={l} requested, but not found!" + ) from e + defined_objs = dir(target_module) + for key in defined_objs: + if "order_fwd" in key: + sph_harm_func = getattr(target_module, key) + return sph_harm_func + raise RuntimeError(f"Namespace for module l={l} is broken!") + + +def _get_bwd_kernel(l: int) -> Callable: + """ + Reach into the module of a specified l value and grab + the corresponding backward Triton kernel function. + + Parameters + ---------- + l : int + Spherical harmonic l value to search for. 
+ + Returns + ------- + Callable + Triton backward kernel + + Raises + ------ + ModuleNotFoundError: + If the l value is not implemented, the module will + not exist and raises a ``ModuleNotFoundError``. + RuntimeError: + If the module exists but we aren't able to find + a backward kernel defined, it's broken. + """ + try: + target_module = import_module(f"equitriton.sph_harm.direct.y_{l}") + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Spherical harmonic order l={l} requested, but not found!" + ) from e + defined_objs = dir(target_module) + for key in defined_objs: + if "order_bwd" in key: + sph_harm_func = getattr(target_module, key) + return sph_harm_func + raise RuntimeError(f"Namespace for module l={l} is broken!") + + +def torch_spherical_harmonic(l: int, coords: torch.Tensor) -> torch.Tensor: + """ + Utility function that will call the PyTorch implementation + of a spherical harmonic order. + + This is not intended for production use, but mainly for + sanity checking and convenience. + + Parameters + ---------- + l : int + Order of spherical harmonic requested. + coords : torch.Tensor + N-d tensor, where the last dimension should correspond + to xyz vectors. + + Returns + ------- + torch.Tensor + N-d tensor of the same dimensionality as the input coordinates, + but the size of the last dimension equal to [2 * l + 1]. + + Raises + ------ + ModuleNotFoundError + If order of spherical harmonic requested is not found, it is + likely not yet implemented. + RuntimeError + If the PyTorch implementation of the spherical harmonic is + not found within the module. + RuntimeError + If the shape of the last dimension of the ``coords`` tensor + is not equal to three. + """ + try: + target_module = import_module(f"equitriton.sph_harm.direct.y_{l}") + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Spherical harmonic order l={l} requested, but not found!" 
+ ) from e + torch_func = getattr(target_module, "_torch_fwd", None) + if not torch_func: + raise RuntimeError(f"PyTorch implementation of l={l} not found.") + if coords.size(-1) != 3: + raise RuntimeError("Expects last dimension of coordinate tensor to be 3!") + return torch_func(coords) + + +def triton_spherical_harmonic( + l_values: int | list[int], coords: torch.Tensor, mask: torch.Tensor | None = None +) -> torch.Tensor: + """ + Utility function that will call the Triton implementation + of a spherical harmonic order. + + This is not intended for production use, but mainly for + sanity checking and convenience. + + Parameters + ---------- + l : int + Order of spherical harmonic requested. + coords : torch.Tensor + N-d tensor, where the last dimension should correspond + to xyz vectors. + + Returns + ------- + torch.Tensor + N-d tensor of the same dimensionality as the input coordinates, + but the size of the last dimension equal to [2 * l + 1]. + + Raises + ------ + ModuleNotFoundError + If order of spherical harmonic requested is not found, it is + likely not yet implemented. + RuntimeError + If the Triton implementation of the spherical harmonic is + not found within the module. + RuntimeError + If the shape of the last dimension of the ``coords`` tensor + is not equal to three. 
+ """ + if coords.size(-1) != 3: + raise RuntimeError("Expects last dimension of coordinate tensor to be 3!") + if isinstance(l_values, int): + l_values = [ + l_values, + ] + # ensure we are in ascending order + l_values = list(sorted(l_values)) + dims = [num_irreps_projections(l) for l in l_values] + offsets = np.zeros_like(dims) + # prepend zero, since we start with zero offset + offsets[1:] = np.cumsum(dims[:-1]) + + # convert into a list, since np.int64 is not desired + offsets = offsets.tolist() + # preallocate a tensor that holds all of the spherical harmonic terms + output_tensor = torch.empty( + (*coords.shape[:-1], sum(dims)), + device=coords.device, + dtype=coords.dtype, + requires_grad=True, + ) + for l, offset in zip(l_values, offsets): + sph_harm_func = _get_autograd_func(l) + sph_harm_func.apply(coords, output_tensor, mask, BLOCK_SIZE, offset) + return output_tensor + + +class TritonSphericalHarmonic(torch.autograd.Function): + __l_values__: list + __offsets__: list + + @staticmethod + def forward( + ctx, + l_values: int | list[int], + coords: torch.Tensor, + mask: torch.Tensor | None = None, + ): + if coords.size(-1) != 3: + raise RuntimeError("Expects last dimension of coordinate tensor to be 3!") + if isinstance(l_values, int): + l_values = [ + l_values, + ] + # ensure we are in ascending order + l_values = list(sorted(l_values)) + dims = [num_irreps_projections(l) for l in l_values] + offsets = np.zeros_like(dims) + # prepend zero, since we start with zero offset + offsets[1:] = np.cumsum(dims[:-1]) + # convert into a list, since np.int64 is not desired + offsets = offsets.tolist() + # preallocate a tensor that holds all of the spherical harmonic terms + output_tensor = torch.empty( + (*coords.shape[:-1], sum(dims)), + device=coords.device, + dtype=coords.dtype, + requires_grad=True, + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + # this corresponds to the number of projections + output_stride = output_tensor.stride(-2) 
+ num_blocks = calculate_lastdim_num_blocks(coords, BLOCK_SIZE) + for l, offset in zip(l_values, offsets): + sph_harm_func = _get_fwd_kernel(l) + sph_harm_func[num_blocks,]( + coords, + output_tensor, + BLOCK_SIZE, + coord_numel, + output_numel, + offset, + output_stride, + ) + ctx.save_for_backward(coords) + # stash values as class attributes, as they are the same + # and ctx can only hold tensors + TritonSphericalHarmonic.__l_values__ = l_values + TritonSphericalHarmonic.__offsets__ = offsets + return output_tensor + + @staticmethod + def backward(ctx, sph_harm_grads: torch.Tensor): + (coords,) = ctx.saved_tensors + # grab from private class variables + l_values = TritonSphericalHarmonic.__l_values__ + offsets = TritonSphericalHarmonic.__offsets__ + coord_grad_output = torch.zeros_like(coords) + # combine start and end together to slice the gradient tensor + coord_numel = coords.numel() + grads_numel = sph_harm_grads.numel() + # this corresponds to the number of projections + output_stride = sph_harm_grads.stride(-2) + num_blocks = calculate_lastdim_num_blocks(coords, BLOCK_SIZE) + for l, offset in zip(l_values, offsets): + sph_harm_bwd = _get_bwd_kernel(l) + sph_harm_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_harm_grads, + BLOCK_SIZE, + coord_numel, + grads_numel, + offset, + output_stride, + ) + # first element is None because it corresponds to l_values, which + # can't have gradients + return None, coord_grad_output diff --git a/src/equitriton/sph_harm/direct/y_0.py b/src/equitriton/sph_harm/direct/y_0.py new file mode 100644 index 0000000..ddf54ec --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_0.py @@ -0,0 +1,119 @@ +import triton +import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["ZerothOrderSphericalHarmonic"] + + +class ZerothOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + output_tensor: torch.Tensor | None = None, 
+ mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + if not isinstance(output_tensor, torch.Tensor): + output_tensor = torch.ones( + (*coords.shape[:-1], 1), dtype=coords.dtype, device=coords.device + ) + ctx.save_for_backward(coords) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + zeroth_order_fwd[num_blocks,]( + coords, + output_tensor, + block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + return output_tensor + + @staticmethod + def backward( + ctx, sph_grad_tensor: torch.Tensor, block_size: int = 64, col_offset: int = 0 + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + zeroth_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + This function is generically named to make it easier for + it to be called programmatically: it is _not_ intended + to be used manually. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. + + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the zeroth order spherical harmonic. 
+ """ + x = coords[..., 0].contiguous().unsqueeze(-1) + output = torch.ones_like(x) + return output + + +@triton.jit +def zeroth_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, 1.0, mask=output_row_offset < output_numel) + + +@triton.jit +def zeroth_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) # noqa: F841 + # do nothing in this function because no gradient contributions! 
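The addressing used by `zeroth_order_fwd` above (and reused by the higher-order kernels) can be sketched in plain Python, without Triton or a GPU. This is only an illustration under stated assumptions: the helper `flat_indices` and the toy sizes are hypothetical and not part of this diff, and the ceiling division stands in for what `calculate_lastdim_num_blocks` is assumed to compute.

```python
def flat_indices(block_id, block_size, stride, col_offset, numel):
    """Flat output positions that one program (block) would write to."""
    idx = [
        # per-row stride + start of this block + column of this order
        r * stride + block_size * stride * block_id + col_offset
        for r in range(block_size)
    ]
    # plays the role of the `mask=` argument of `tl.store`:
    # rows past the end of the tensor are skipped
    return [i for i in idx if i < numel]


# toy problem (hypothetical sizes): 5 points, 4 output columns per point,
# and this order's projections starting at column 1
num_rows, stride, col_offset, block_size = 5, 4, 1, 2
numel = num_rows * stride
num_blocks = -(-num_rows // block_size)  # ceiling division (assumed launch grid)

written = [
    i
    for b in range(num_blocks)
    for i in flat_indices(b, block_size, stride, col_offset, numel)
]
print(written)
```

Each row is visited exactly once, at `row * stride + col_offset`, which is why independent orders can write into one preallocated output tensor without overlapping as long as their column offsets do not collide.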
diff --git a/src/equitriton/sph_harm/direct/y_1.py b/src/equitriton/sph_harm/direct/y_1.py new file mode 100644 index 0000000..17af709 --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_1.py @@ -0,0 +1,195 @@ +import triton +import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["FirstOrderSphericalHarmonic"] + + +class FirstOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + output_tensor = torch.empty( + (*coords.shape[:-1], 3), dtype=coords.dtype, device=coords.device + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + first_order_fwd[num_blocks,]( + coords, + output_tensor, + block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, sph_grad_tensor: torch.Tensor, block_size: int = 64, col_offset: int = 0 + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + first_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + This function is generically named to make it easier for + it to be called programmatically: it is _not_ intended + to be used manually. 
+ + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. 
+ + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the first order spherical harmonic. + """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + CONST_00 = 3**0.5 + Y10 = x * CONST_00 + Y11 = y * CONST_00 + Y12 = z * CONST_00 + return torch.cat([Y10, Y11, Y12], dim=-1) + + +@triton.jit +def first_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # these are hardcoded because they are predetermined; + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + CONST_00 = tl.sqrt(3.0) + Y10 = CONST_00 * x + Y11 = CONST_00 * y + Y12 = CONST_00 * z + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, Y10, mask=output_row_offset < output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y11, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y12, + mask=output_row_offset + 2 < output_numel, + ) + + +@triton.jit +def first_order_bwd( + coord_ptr: tl.tensor, # noqa: F403 + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: 
tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # these are hardcoded because they are predetermined; + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + # load in gradients w.r.t. spherical harmonic projections + g_Y10 = tl.load( + sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel + ) + g_Y11 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_Y12 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + # read in current gradients + g_x = tl.load( + coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel + ) + g_y = tl.load( + coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + g_z = tl.load( + coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + CONST_00 = tl.sqrt(3.0) + g_x += CONST_00 * g_Y10 + g_y += CONST_00 * g_Y11 + g_z += CONST_00 * g_Y12 + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + ) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/sph_harm/direct/y_10.py b/src/equitriton/sph_harm/direct/y_10.py new file mode 100644 index 0000000..af925fb --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_10.py @@ -0,0 +1,2607 @@ +import triton +import torch +from triton import 
language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["TenthOrderSphericalHarmonic"] + + +class TenthOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + output_tensor: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + # allocate a tensor if one isn't given + if not isinstance(output_tensor, torch.Tensor): + output_tensor = torch.empty( + (*coords.shape[:-1], 21), dtype=coords.dtype, device=coords.device + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + tenth_order_fwd[num_blocks,]( + coords, + output_tensor, + block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, + sph_grad_tensor: torch.Tensor, + block_size: int = 64, + col_offset: int = 0, + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + tenth_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. + + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the tenth order spherical harmonic. 
+ """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + # -------------------- variable and constant definitions + CONST001 = 1.75869118663323 + CONST002 = -1021.92317475320 + CONST004 = 4.58257569495584 + CONST005 = 6.63243980843400 + CONST006 = 4.82870805793735 + CONST007 = 4.97432985632550 + CONST008 = 1545.18657853995 + CONST009 = 10.5521471197994 + CONST010 = 12.1657520803952 + CONST011 = 13.2648796168680 + CONST013 = 15.7883647328499 + CONST014 = 15.7302121789667 + CONST015 = 16.4144510752435 + CONST016 = 12.8765548211663 + CONST017 = 19.3148322317494 + CONST018 = 16.7271353825295 + CONST019 = 22.8629854262320 + CONST020 = 535.268332240943 + CONST021 = 23.2135393295190 + CONST022 = 24.6216766128653 + CONST023 = 27.2034486491732 + CONST024 = 541.428124558099 + CONST025 = -994.666978169547 + CONST026 = 33.9852909359329 + CONST027 = 33.9852909359329 + CONST028 = 35.5238206489124 + CONST029 = -984.867064514610 + CONST030 = -4.82870805793735 + CONST031 = 1070.53666448189 + CONST032 = -463.555973561985 + CONST034 = 53.2857309733686 + CONST035 = 53.2857309733686 + CONST036 = 56.3871618715269 + CONST037 = 56.3871618715269 + CONST039 = -1989.33395633909 + CONST041 = -450.224943778107 + CONST042 = 66.9085415301178 + CONST043 = 69.6406179885570 + CONST044 = 69.6406179885570 + CONST045 = -437.967074894228 + CONST046 = 77.2593289269976 + CONST047 = 78.6510608948335 + CONST049 = -1969.73412902922 + CONST050 = 77.3468749368712 + CONST051 = 1624.28437367430 + CONST054 = 94.7301883970997 + CONST056 = 100.362812295177 + CONST057 = -412.049754277320 + CONST058 = 101.517773354644 + CONST059 = -5.63871618715269 + CONST060 = -406.071093418574 + CONST061 = 109.491768723557 + CONST062 = -393.946825805844 + CONST063 = -902.194589944431 + CONST065 = -386.296644634988 + CONST066 = -386.296644634988 + CONST070 = 4.97432985632550 + CONST071 = 150.074981259369 + CONST074 = 685.526905959165 + 
CONST075 = -337.668707833581 + CONST076 = -337.668707833581 + CONST077 = 176.178376404427 + CONST078 = 176.592751833137 + CONST079 = 185.708314636152 + CONST080 = -326.441383790078 + CONST081 = -1.60956935264578 + CONST082 = -1.97354559160624 + CONST083 = 196.973412902922 + CONST085 = -824.099508554641 + CONST087 = -1.97354559160624 + CONST088 = -305.867618423396 + CONST089 = -305.867618423396 + CONST090 = 721.755671955545 + CONST091 = -305.867618423396 + CONST092 = -300.731529981477 + CONST093 = -300.731529981477 + CONST094 = -1.75869118663323 + CONST095 = -290.050781013267 + CONST097 = 225.548647486108 + CONST098 = 225.548647486108 + CONST099 = -284.190565191299 + CONST101 = -278.562471954228 + CONST102 = -278.562471954228 + CONST103 = -787.893651611688 + CONST104 = -787.893651611688 + CONST105 = 772.593289269975 + CONST106 = 787.893651611688 + CONST107 = 787.893651611688 + CONST108 = 278.562471954228 + CONST109 = -742.833258544608 + CONST110 = -1.65810995210850 + CONST112 = -1761.78376404427 + CONST113 = -223.028471767059 + CONST114 = -734.076568351780 + CONST116 = -220.222970505534 + CONST117 = 1321.33782303320 + CONST118 = 1321.33782303320 + CONST119 = -203.035546709287 + CONST120 = -1.65810995210850 + CONST121 = -196.973412902922 + CONST122 = -196.973412902922 + CONST123 = -696.406179885570 + CONST125 = 338.322971229162 + CONST126 = -1181.84047741753 + CONST127 = -669.085415301178 + CONST128 = -669.085415301178 + CONST129 = -154.518657853995 + CONST130 = -154.518657853995 + CONST131 = 360.877835977772 + CONST132 = -150.074981259369 + CONST133 = -2707.14062279049 + CONST134 = -146.815313670356 + CONST135 = 880.891882022136 + CONST136 = 1392.81235977114 + CONST137 = 1392.81235977114 + CONST138 = -131.315608601948 + CONST139 = -131.315608601948 + CONST141 = -125.841697431734 + CONST142 = -125.841697431734 + CONST143 = -122.415518921279 + CONST145 = 406.071093418574 + CONST146 = -103.107953136506 + CONST147 = -103.107953136506 + CONST148 = -101.517773354644 + 
CONST149 = -98.4867064514610 + CONST150 = 412.049754277320 + CONST151 = -94.7301883970997 + CONST152 = -1114.24988781691 + CONST153 = -88.2963759165686 + CONST154 = -1624.28437367430 + CONST155 = -82.8889148474622 + CONST156 = -82.8889148474622 + CONST158 = -590.920238708766 + CONST159 = -77.3468749368713 + CONST160 = -77.2593289269975 + CONST161 = 2486.66744542387 + CONST162 = -2626.31217203896 + CONST165 = -571.272421632637 + CONST166 = -56.2781179722634 + CONST167 = -49.2433532257305 + CONST168 = -49.2433532257305 + CONST169 = 984.867064514610 + CONST170 = -541.428124558099 + CONST171 = -24.6216766128653 + CONST172 = -22.8629854262320 + CONST173 = -16.4144510752435 + CONST174 = -15.7883647328499 + CONST175 = -14.0695294930659 + CONST176 = -13.2648796168680 + CONST177 = -11.2774323743054 + CONST178 = -14.5025390506634 + CONST179 = -6.63243980843400 + CONST180 = -5.63871618715269 + CONST181 = 1532.88476212980 + CONST182 = -3.21913870529156 + CONST183 = -2.72034486491732 + CONST184 = -1.12774323743054 + # ordering is really messy because I've refactored + # the higher powers in terms of the lower ones + VAR05 = x * x * x * x * x + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR00 = VAR05 * VAR05 + VAR01 = VAR05 * VAR06 + VAR02 = VAR06 * VAR06 + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR14 = y * y * y * y * y + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR09 = VAR14 * VAR14 + VAR10 = VAR14 * VAR15 + VAR11 = VAR15 * VAR15 + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR23 = z * z * z * z * z + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + VAR18 = VAR23 * VAR23 + VAR19 = VAR23 * VAR24 + VAR20 = VAR24 * VAR24 + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + # -------------------- kernel implementations + Y00 = ( + CONST023 * VAR01 * z + + CONST023 * VAR19 * x + + CONST074 * VAR05 * VAR23 + + CONST080 * VAR03 * VAR25 + + CONST080 * VAR07 * VAR21 + ) + Y01 = y * ( + CONST002 * VAR07 * VAR22 + + CONST010 * 
VAR01 + + CONST045 * VAR03 * VAR26 + + CONST061 * VAR20 * x + + CONST181 * VAR05 * VAR24 + ) + Y02 = ( + CONST013 * VAR01 * z + + CONST054 * VAR07 * VAR21 + + CONST151 * VAR03 * VAR25 + + CONST174 * VAR19 * x + + VAR17 + * ( + -CONST039 * VAR05 * VAR25 + + CONST039 * VAR07 * VAR23 + + CONST099 * VAR03 * z + - CONST099 * VAR21 * x + ) + ) + Y03 = VAR16 * ( + CONST024 * VAR22 * x + + CONST051 * VAR05 * VAR26 + + CONST133 * VAR07 * VAR24 + + CONST159 * VAR03 + ) + y * ( + CONST095 * VAR03 * VAR26 + - CONST119 * VAR05 * VAR24 + + CONST145 * VAR07 * VAR22 + + CONST148 * VAR20 * x + - CONST178 * VAR01 + ) + Y04 = ( + CONST009 * VAR01 * z + + VAR03 * (CONST076 * VAR17 * z + CONST175 * VAR25) + + VAR05 * (CONST106 * VAR15 * z + CONST107 * VAR17 * VAR25 + CONST167 * VAR23) + + VAR07 + * (CONST106 * VAR17 * VAR23 + CONST162 * VAR15 * VAR25 + CONST175 * VAR21) + + x * (CONST009 * VAR19 + CONST075 * VAR17 * VAR21 + CONST106 * VAR15 * VAR23) + ) + Y05 = ( + VAR14 * (CONST077 * VAR05 + CONST112 * VAR07 * VAR26 + CONST135 * VAR24 * x) + + VAR16 + * ( + -CONST114 * VAR07 * VAR24 + + CONST114 * VAR22 * x + + CONST117 * VAR05 * VAR26 + + CONST134 * VAR03 + ) + + y + * ( + CONST014 * VAR01 + + CONST047 * VAR20 * x + + CONST116 * VAR05 * VAR24 + + CONST141 * VAR03 * VAR26 + ) + ) + Y06 = ( + CONST005 * VAR01 * z + + VAR03 * (CONST011 * VAR25 + CONST102 * VAR17 * z) + + VAR05 * (CONST101 * VAR17 * VAR25 - CONST152 * VAR15 * z) + + VAR07 * (CONST108 * VAR17 * VAR23 + CONST109 * VAR13 * z + CONST176 * VAR21) + + x + * ( + CONST108 * VAR17 * VAR21 + - CONST109 * VAR13 * VAR25 + + CONST152 * VAR15 * VAR23 + + CONST179 * VAR19 + ) + ) + Y07 = ( + VAR12 * (-CONST041 * VAR26 * x + CONST132 * VAR07) + + VAR14 * (-CONST062 * VAR05 + CONST103 * VAR07 * VAR26 + CONST126 * VAR24 * x) + + VAR16 + * ( + CONST083 * VAR05 * VAR26 + + CONST121 * VAR03 + - CONST158 * VAR22 * x + + CONST169 * VAR07 * VAR24 + ) + + y + * ( + CONST015 * VAR01 + + CONST138 * VAR07 * VAR22 + + CONST149 * VAR05 * VAR24 + + 
CONST168 * VAR20 * x + ) + ) + Y08 = ( + -CONST182 * VAR01 * z + + VAR03 * (CONST016 * VAR25 + CONST129 * VAR17 * z) + + VAR05 * (CONST017 * VAR23 + CONST032 * VAR17 * VAR25 + CONST105 * VAR15 * z) + + VAR07 + * ( + CONST008 * VAR15 * VAR25 + + CONST016 * VAR21 + + CONST032 * VAR17 * VAR23 + + CONST085 * VAR13 * z + ) + + x + * ( + CONST078 * VAR11 * z + + CONST085 * VAR13 * VAR25 + + CONST105 * VAR15 * VAR23 + + CONST129 * VAR17 * VAR21 + - CONST182 * VAR19 + ) + ) + Y09 = ( + CONST018 * VAR01 * y + + VAR03 * (CONST042 * VAR26 * y + CONST113 * VAR16) + + VAR05 * (CONST020 * VAR14 + CONST056 * VAR24 * y + CONST128 * VAR16 * VAR26) + + VAR07 + * ( + CONST031 * VAR14 * VAR26 + + CONST042 * VAR22 * y + + CONST088 * VAR12 + + CONST127 * VAR16 * VAR24 + ) + + x + * ( + CONST018 * VAR20 * y + + CONST020 * VAR14 * VAR24 + + CONST026 * VAR10 + + CONST088 * VAR12 * VAR26 + + CONST113 * VAR16 * VAR22 + ) + ) + Y10 = ( + CONST004 * VAR09 + + CONST037 * VAR17 * VAR20 + + CONST093 * VAR15 * VAR22 + + CONST131 * VAR13 * VAR24 + + CONST147 * VAR11 * VAR26 + + CONST184 * VAR00 + + CONST184 * VAR18 + + VAR02 * (CONST036 * VAR17 + CONST059 * VAR26) + + VAR04 * (CONST092 * VAR15 + CONST098 * VAR17 * VAR26 + CONST177 * VAR24) + + VAR06 + * ( + CONST063 * VAR15 * VAR26 + + CONST125 * VAR17 * VAR24 + + CONST131 * VAR13 + + CONST177 * VAR22 + ) + + VAR08 + * ( + CONST063 * VAR15 * VAR24 + + CONST090 * VAR13 * VAR26 + + CONST097 * VAR17 * VAR22 + + CONST146 * VAR11 + + CONST180 * VAR20 + ) + ) + Y11 = ( + CONST018 * VAR19 * y + + VAR21 * (CONST042 * VAR08 * y + CONST113 * VAR16) + + VAR23 * (CONST020 * VAR14 + CONST056 * VAR06 * y + CONST128 * VAR08 * VAR16) + + VAR25 + * ( + CONST031 * VAR08 * VAR14 + + CONST042 * VAR04 * y + + CONST091 * VAR12 + + CONST127 * VAR06 * VAR16 + ) + + z + * ( + CONST018 * VAR02 * y + + CONST020 * VAR06 * VAR14 + + CONST027 * VAR10 + + CONST089 * VAR08 * VAR12 + + CONST113 * VAR04 * VAR16 + ) + ) + Y12 = ( + CONST057 * VAR13 * VAR24 + - CONST066 * VAR15 * 
VAR22 + + CONST081 * VAR00 + - CONST081 * VAR18 + - CONST153 * VAR11 * VAR26 + + CONST160 * VAR17 * VAR20 + + VAR02 * (CONST030 * VAR26 + CONST046 * VAR17) + + VAR04 * (CONST066 * VAR15 - CONST129 * VAR17 * VAR26 + CONST182 * VAR24) + + VAR06 * (CONST065 * VAR15 * VAR26 + CONST150 * VAR13 - CONST182 * VAR22) + + VAR08 + * ( + CONST006 * VAR20 + - CONST066 * VAR15 * VAR24 + + CONST130 * VAR17 * VAR22 + + CONST153 * VAR11 + ) + ) + Y13 = ( + VAR12 * (CONST041 * VAR08 * z + CONST071 * VAR25) + + VAR14 * (CONST062 * VAR23 + CONST107 * VAR08 * VAR25 - CONST126 * VAR06 * z) + + VAR16 + * ( + CONST029 * VAR06 * VAR25 + - CONST121 * VAR21 + + CONST122 * VAR08 * VAR23 + + CONST158 * VAR04 * z + ) + + y + * ( + -CONST138 * VAR04 * VAR25 + - CONST149 * VAR06 * VAR23 + - CONST168 * VAR02 * z + + CONST173 * VAR19 + ) + ) + Y14 = ( + CONST044 * VAR17 * VAR20 + + CONST079 * VAR13 * VAR24 + + CONST101 * VAR15 * VAR22 + + CONST110 * VAR00 + + CONST120 * VAR18 + + VAR02 * (CONST043 * VAR17 + CONST070 * VAR26) + + VAR04 * (CONST021 * VAR24 + CONST101 * VAR15 + CONST101 * VAR17 * VAR26) + + VAR06 + * ( + CONST021 * VAR22 + + CONST079 * VAR13 + + CONST123 * VAR17 * VAR24 + + CONST137 * VAR15 * VAR26 + ) + + VAR08 + * ( + CONST007 * VAR20 + + CONST101 * VAR17 * VAR22 + + CONST136 * VAR15 * VAR24 + + CONST152 * VAR13 * VAR26 + ) + ) + Y15 = ( + VAR14 * (CONST077 * VAR23 + CONST112 * VAR08 * VAR25 + CONST135 * VAR06 * z) + + VAR16 + * ( + CONST114 * VAR04 * z + - CONST114 * VAR06 * VAR25 + + CONST118 * VAR08 * VAR23 + + CONST134 * VAR21 + ) + + y + * ( + CONST014 * VAR19 + + CONST047 * VAR02 * z + + CONST116 * VAR06 * VAR23 + + CONST142 * VAR08 * VAR21 + ) + ) + Y16 = ( + CONST001 * VAR18 + + CONST094 * VAR00 + - CONST139 * VAR15 * VAR22 + + CONST166 * VAR17 * VAR20 + + VAR02 * (CONST019 * VAR26 - CONST166 * VAR17) + + VAR04 * (CONST022 * VAR24 + CONST104 * VAR17 * VAR26 + CONST139 * VAR15) + + VAR06 * (-CONST049 * VAR15 * VAR26 + CONST171 * VAR22) + + VAR08 + * (CONST049 * VAR15 * VAR24 
+ CONST106 * VAR17 * VAR22 + CONST172 * VAR20) + ) + Y17 = VAR16 * ( + CONST050 * VAR21 + - CONST133 * VAR06 * VAR25 + + CONST154 * VAR08 * VAR23 + + CONST170 * VAR04 * z + ) + y * ( + CONST058 * VAR02 * z + + CONST060 * VAR04 * VAR25 + - CONST095 * VAR08 * VAR21 + + CONST119 * VAR06 * VAR23 + + CONST178 * VAR19 + ) + Y18 = ( + CONST034 * VAR02 * VAR26 + + CONST035 * VAR08 * VAR20 + + CONST082 * VAR00 + + CONST087 * VAR18 + + CONST155 * VAR04 * VAR24 + + CONST156 * VAR06 * VAR22 + + VAR17 + * ( + CONST025 * VAR04 * VAR26 + + CONST025 * VAR08 * VAR22 + + CONST028 * VAR02 + + CONST028 * VAR20 + + CONST161 * VAR06 * VAR24 + ) + ) + Y19 = y * ( + CONST002 * VAR04 * VAR25 + + CONST010 * VAR19 + + CONST045 * VAR08 * VAR21 + + CONST061 * VAR02 * z + + CONST181 * VAR06 * VAR23 + ) + Y20 = ( + -CONST143 * VAR02 * VAR26 + + CONST143 * VAR08 * VAR20 + + CONST165 * VAR04 * VAR24 + - CONST165 * VAR06 * VAR22 + + CONST183 * VAR00 + - CONST183 * VAR18 + ) + # not the prettiest way to concatenate, but better than + # messing with the linter + tensors = [ + Y00, + Y01, + Y02, + Y03, + Y04, + Y05, + Y06, + Y07, + Y08, + Y09, + Y10, + Y11, + Y12, + Y13, + Y14, + Y15, + Y16, + Y17, + Y18, + Y19, + Y20, + ] + return torch.cat(tensors, dim=-1) + + +@triton.jit +def tenth_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # these are hardcoded because they are predetermined; + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = 
tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + # -------------------- variable and constant definitions + CONST001 = 1.75869118663323 + CONST002 = -1021.92317475320 + CONST004 = 4.58257569495584 + CONST005 = 6.63243980843400 + CONST006 = 4.82870805793735 + CONST007 = 4.97432985632550 + CONST008 = 1545.18657853995 + CONST009 = 10.5521471197994 + CONST010 = 12.1657520803952 + CONST011 = 13.2648796168680 + CONST013 = 15.7883647328499 + CONST014 = 15.7302121789667 + CONST015 = 16.4144510752435 + CONST016 = 12.8765548211663 + CONST017 = 19.3148322317494 + CONST018 = 16.7271353825295 + CONST019 = 22.8629854262320 + CONST020 = 535.268332240943 + CONST021 = 23.2135393295190 + CONST022 = 24.6216766128653 + CONST023 = 27.2034486491732 + CONST024 = 541.428124558099 + CONST025 = -994.666978169547 + CONST026 = 33.9852909359329 + CONST027 = 33.9852909359329 + CONST028 = 35.5238206489124 + CONST029 = -984.867064514610 + CONST030 = -4.82870805793735 + CONST031 = 1070.53666448189 + CONST032 = -463.555973561985 + CONST034 = 53.2857309733686 + CONST035 = 53.2857309733686 + CONST036 = 56.3871618715269 + CONST037 = 56.3871618715269 + CONST039 = -1989.33395633909 + CONST041 = -450.224943778107 + CONST042 = 66.9085415301178 + CONST043 = 69.6406179885570 + CONST044 = 69.6406179885570 + CONST045 = -437.967074894228 + CONST046 = 77.2593289269976 + CONST047 = 78.6510608948335 + CONST049 = -1969.73412902922 + CONST050 = 77.3468749368712 + CONST051 = 1624.28437367430 + CONST054 = 94.7301883970997 + CONST056 = 100.362812295177 + CONST057 = -412.049754277320 + CONST058 = 101.517773354644 + CONST059 = -5.63871618715269 + CONST060 = -406.071093418574 + CONST061 = 109.491768723557 + CONST062 = -393.946825805844 + CONST063 = -902.194589944431 + CONST065 = -386.296644634988 + CONST066 = -386.296644634988 + CONST070 = 4.97432985632550 + CONST071 = 150.074981259369 + CONST074 = 685.526905959165 + CONST075 = -337.668707833581 + CONST076 = -337.668707833581 + 
CONST077 = 176.178376404427 + CONST078 = 176.592751833137 + CONST079 = 185.708314636152 + CONST080 = -326.441383790078 + CONST081 = -1.60956935264578 + CONST082 = -1.97354559160624 + CONST083 = 196.973412902922 + CONST085 = -824.099508554641 + CONST087 = -1.97354559160624 + CONST088 = -305.867618423396 + CONST089 = -305.867618423396 + CONST090 = 721.755671955545 + CONST091 = -305.867618423396 + CONST092 = -300.731529981477 + CONST093 = -300.731529981477 + CONST094 = -1.75869118663323 + CONST095 = -290.050781013267 + CONST097 = 225.548647486108 + CONST098 = 225.548647486108 + CONST099 = -284.190565191299 + CONST101 = -278.562471954228 + CONST102 = -278.562471954228 + CONST103 = -787.893651611688 + CONST104 = -787.893651611688 + CONST105 = 772.593289269975 + CONST106 = 787.893651611688 + CONST107 = 787.893651611688 + CONST108 = 278.562471954228 + CONST109 = -742.833258544608 + CONST110 = -1.65810995210850 + CONST112 = -1761.78376404427 + CONST113 = -223.028471767059 + CONST114 = -734.076568351780 + CONST116 = -220.222970505534 + CONST117 = 1321.33782303320 + CONST118 = 1321.33782303320 + CONST119 = -203.035546709287 + CONST120 = -1.65810995210850 + CONST121 = -196.973412902922 + CONST122 = -196.973412902922 + CONST123 = -696.406179885570 + CONST125 = 338.322971229162 + CONST126 = -1181.84047741753 + CONST127 = -669.085415301178 + CONST128 = -669.085415301178 + CONST129 = -154.518657853995 + CONST130 = -154.518657853995 + CONST131 = 360.877835977772 + CONST132 = -150.074981259369 + CONST133 = -2707.14062279049 + CONST134 = -146.815313670356 + CONST135 = 880.891882022136 + CONST136 = 1392.81235977114 + CONST137 = 1392.81235977114 + CONST138 = -131.315608601948 + CONST139 = -131.315608601948 + CONST141 = -125.841697431734 + CONST142 = -125.841697431734 + CONST143 = -122.415518921279 + CONST145 = 406.071093418574 + CONST146 = -103.107953136506 + CONST147 = -103.107953136506 + CONST148 = -101.517773354644 + CONST149 = -98.4867064514610 + CONST150 = 412.049754277320 + 
CONST151 = -94.7301883970997 + CONST152 = -1114.24988781691 + CONST153 = -88.2963759165686 + CONST154 = -1624.28437367430 + CONST155 = -82.8889148474622 + CONST156 = -82.8889148474622 + CONST158 = -590.920238708766 + CONST159 = -77.3468749368713 + CONST160 = -77.2593289269975 + CONST161 = 2486.66744542387 + CONST162 = -2626.31217203896 + CONST165 = -571.272421632637 + CONST166 = -56.2781179722634 + CONST167 = -49.2433532257305 + CONST168 = -49.2433532257305 + CONST169 = 984.867064514610 + CONST170 = -541.428124558099 + CONST171 = -24.6216766128653 + CONST172 = -22.8629854262320 + CONST173 = -16.4144510752435 + CONST174 = -15.7883647328499 + CONST175 = -14.0695294930659 + CONST176 = -13.2648796168680 + CONST177 = -11.2774323743054 + CONST178 = -14.5025390506634 + CONST179 = -6.63243980843400 + CONST180 = -5.63871618715269 + CONST181 = 1532.88476212980 + CONST182 = -3.21913870529156 + CONST183 = -2.72034486491732 + CONST184 = -1.12774323743054 + # ordering is really messy because I've refactored + # the higher powers in terms of the lower ones + VAR05 = x * x * x * x * x + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR00 = VAR05 * VAR05 + VAR01 = VAR05 * VAR06 + VAR02 = VAR06 * VAR06 + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR14 = y * y * y * y * y + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR09 = VAR14 * VAR14 + VAR10 = VAR14 * VAR15 + VAR11 = VAR15 * VAR15 + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR23 = z * z * z * z * z + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + VAR18 = VAR23 * VAR23 + VAR19 = VAR23 * VAR24 + VAR20 = VAR24 * VAR24 + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + # -------------------- kernel implementations + Y00 = ( + CONST023 * VAR01 * z + + CONST023 * VAR19 * x + + CONST074 * VAR05 * VAR23 + + CONST080 * VAR03 * VAR25 + + CONST080 * VAR07 * VAR21 + ) + Y01 = y * ( + CONST002 * VAR07 * VAR22 + + CONST010 * VAR01 + + CONST045 * VAR03 * VAR26 + + CONST061 * VAR20 * x + + 
CONST181 * VAR05 * VAR24 + ) + Y02 = ( + CONST013 * VAR01 * z + + CONST054 * VAR07 * VAR21 + + CONST151 * VAR03 * VAR25 + + CONST174 * VAR19 * x + + VAR17 + * ( + -CONST039 * VAR05 * VAR25 + + CONST039 * VAR07 * VAR23 + + CONST099 * VAR03 * z + - CONST099 * VAR21 * x + ) + ) + Y03 = VAR16 * ( + CONST024 * VAR22 * x + + CONST051 * VAR05 * VAR26 + + CONST133 * VAR07 * VAR24 + + CONST159 * VAR03 + ) + y * ( + CONST095 * VAR03 * VAR26 + - CONST119 * VAR05 * VAR24 + + CONST145 * VAR07 * VAR22 + + CONST148 * VAR20 * x + - CONST178 * VAR01 + ) + Y04 = ( + CONST009 * VAR01 * z + + VAR03 * (CONST076 * VAR17 * z + CONST175 * VAR25) + + VAR05 * (CONST106 * VAR15 * z + CONST107 * VAR17 * VAR25 + CONST167 * VAR23) + + VAR07 + * (CONST106 * VAR17 * VAR23 + CONST162 * VAR15 * VAR25 + CONST175 * VAR21) + + x * (CONST009 * VAR19 + CONST075 * VAR17 * VAR21 + CONST106 * VAR15 * VAR23) + ) + Y05 = ( + VAR14 * (CONST077 * VAR05 + CONST112 * VAR07 * VAR26 + CONST135 * VAR24 * x) + + VAR16 + * ( + -CONST114 * VAR07 * VAR24 + + CONST114 * VAR22 * x + + CONST117 * VAR05 * VAR26 + + CONST134 * VAR03 + ) + + y + * ( + CONST014 * VAR01 + + CONST047 * VAR20 * x + + CONST116 * VAR05 * VAR24 + + CONST141 * VAR03 * VAR26 + ) + ) + Y06 = ( + CONST005 * VAR01 * z + + VAR03 * (CONST011 * VAR25 + CONST102 * VAR17 * z) + + VAR05 * (CONST101 * VAR17 * VAR25 - CONST152 * VAR15 * z) + + VAR07 * (CONST108 * VAR17 * VAR23 + CONST109 * VAR13 * z + CONST176 * VAR21) + + x + * ( + CONST108 * VAR17 * VAR21 + - CONST109 * VAR13 * VAR25 + + CONST152 * VAR15 * VAR23 + + CONST179 * VAR19 + ) + ) + Y07 = ( + VAR12 * (-CONST041 * VAR26 * x + CONST132 * VAR07) + + VAR14 * (-CONST062 * VAR05 + CONST103 * VAR07 * VAR26 + CONST126 * VAR24 * x) + + VAR16 + * ( + CONST083 * VAR05 * VAR26 + + CONST121 * VAR03 + - CONST158 * VAR22 * x + + CONST169 * VAR07 * VAR24 + ) + + y + * ( + CONST015 * VAR01 + + CONST138 * VAR07 * VAR22 + + CONST149 * VAR05 * VAR24 + + CONST168 * VAR20 * x + ) + ) + Y08 = ( + -CONST182 * VAR01 * z + + 
VAR03 * (CONST016 * VAR25 + CONST129 * VAR17 * z) + + VAR05 * (CONST017 * VAR23 + CONST032 * VAR17 * VAR25 + CONST105 * VAR15 * z) + + VAR07 + * ( + CONST008 * VAR15 * VAR25 + + CONST016 * VAR21 + + CONST032 * VAR17 * VAR23 + + CONST085 * VAR13 * z + ) + + x + * ( + CONST078 * VAR11 * z + + CONST085 * VAR13 * VAR25 + + CONST105 * VAR15 * VAR23 + + CONST129 * VAR17 * VAR21 + - CONST182 * VAR19 + ) + ) + Y09 = ( + CONST018 * VAR01 * y + + VAR03 * (CONST042 * VAR26 * y + CONST113 * VAR16) + + VAR05 * (CONST020 * VAR14 + CONST056 * VAR24 * y + CONST128 * VAR16 * VAR26) + + VAR07 + * ( + CONST031 * VAR14 * VAR26 + + CONST042 * VAR22 * y + + CONST088 * VAR12 + + CONST127 * VAR16 * VAR24 + ) + + x + * ( + CONST018 * VAR20 * y + + CONST020 * VAR14 * VAR24 + + CONST026 * VAR10 + + CONST088 * VAR12 * VAR26 + + CONST113 * VAR16 * VAR22 + ) + ) + Y10 = ( + CONST004 * VAR09 + + CONST037 * VAR17 * VAR20 + + CONST093 * VAR15 * VAR22 + + CONST131 * VAR13 * VAR24 + + CONST147 * VAR11 * VAR26 + + CONST184 * VAR00 + + CONST184 * VAR18 + + VAR02 * (CONST036 * VAR17 + CONST059 * VAR26) + + VAR04 * (CONST092 * VAR15 + CONST098 * VAR17 * VAR26 + CONST177 * VAR24) + + VAR06 + * ( + CONST063 * VAR15 * VAR26 + + CONST125 * VAR17 * VAR24 + + CONST131 * VAR13 + + CONST177 * VAR22 + ) + + VAR08 + * ( + CONST063 * VAR15 * VAR24 + + CONST090 * VAR13 * VAR26 + + CONST097 * VAR17 * VAR22 + + CONST146 * VAR11 + + CONST180 * VAR20 + ) + ) + Y11 = ( + CONST018 * VAR19 * y + + VAR21 * (CONST042 * VAR08 * y + CONST113 * VAR16) + + VAR23 * (CONST020 * VAR14 + CONST056 * VAR06 * y + CONST128 * VAR08 * VAR16) + + VAR25 + * ( + CONST031 * VAR08 * VAR14 + + CONST042 * VAR04 * y + + CONST091 * VAR12 + + CONST127 * VAR06 * VAR16 + ) + + z + * ( + CONST018 * VAR02 * y + + CONST020 * VAR06 * VAR14 + + CONST027 * VAR10 + + CONST089 * VAR08 * VAR12 + + CONST113 * VAR04 * VAR16 + ) + ) + Y12 = ( + CONST057 * VAR13 * VAR24 + - CONST066 * VAR15 * VAR22 + + CONST081 * VAR00 + - CONST081 * VAR18 + - CONST153 * VAR11 * 
VAR26 + + CONST160 * VAR17 * VAR20 + + VAR02 * (CONST030 * VAR26 + CONST046 * VAR17) + + VAR04 * (CONST066 * VAR15 - CONST129 * VAR17 * VAR26 + CONST182 * VAR24) + + VAR06 * (CONST065 * VAR15 * VAR26 + CONST150 * VAR13 - CONST182 * VAR22) + + VAR08 + * ( + CONST006 * VAR20 + - CONST066 * VAR15 * VAR24 + + CONST130 * VAR17 * VAR22 + + CONST153 * VAR11 + ) + ) + Y13 = ( + VAR12 * (CONST041 * VAR08 * z + CONST071 * VAR25) + + VAR14 * (CONST062 * VAR23 + CONST107 * VAR08 * VAR25 - CONST126 * VAR06 * z) + + VAR16 + * ( + CONST029 * VAR06 * VAR25 + - CONST121 * VAR21 + + CONST122 * VAR08 * VAR23 + + CONST158 * VAR04 * z + ) + + y + * ( + -CONST138 * VAR04 * VAR25 + - CONST149 * VAR06 * VAR23 + - CONST168 * VAR02 * z + + CONST173 * VAR19 + ) + ) + Y14 = ( + CONST044 * VAR17 * VAR20 + + CONST079 * VAR13 * VAR24 + + CONST101 * VAR15 * VAR22 + + CONST110 * VAR00 + + CONST120 * VAR18 + + VAR02 * (CONST043 * VAR17 + CONST070 * VAR26) + + VAR04 * (CONST021 * VAR24 + CONST101 * VAR15 + CONST101 * VAR17 * VAR26) + + VAR06 + * ( + CONST021 * VAR22 + + CONST079 * VAR13 + + CONST123 * VAR17 * VAR24 + + CONST137 * VAR15 * VAR26 + ) + + VAR08 + * ( + CONST007 * VAR20 + + CONST101 * VAR17 * VAR22 + + CONST136 * VAR15 * VAR24 + + CONST152 * VAR13 * VAR26 + ) + ) + Y15 = ( + VAR14 * (CONST077 * VAR23 + CONST112 * VAR08 * VAR25 + CONST135 * VAR06 * z) + + VAR16 + * ( + CONST114 * VAR04 * z + - CONST114 * VAR06 * VAR25 + + CONST118 * VAR08 * VAR23 + + CONST134 * VAR21 + ) + + y + * ( + CONST014 * VAR19 + + CONST047 * VAR02 * z + + CONST116 * VAR06 * VAR23 + + CONST142 * VAR08 * VAR21 + ) + ) + Y16 = ( + CONST001 * VAR18 + + CONST094 * VAR00 + - CONST139 * VAR15 * VAR22 + + CONST166 * VAR17 * VAR20 + + VAR02 * (CONST019 * VAR26 - CONST166 * VAR17) + + VAR04 * (CONST022 * VAR24 + CONST104 * VAR17 * VAR26 + CONST139 * VAR15) + + VAR06 * (-CONST049 * VAR15 * VAR26 + CONST171 * VAR22) + + VAR08 + * (CONST049 * VAR15 * VAR24 + CONST106 * VAR17 * VAR22 + CONST172 * VAR20) + ) + Y17 = VAR16 * ( + 
CONST050 * VAR21 + - CONST133 * VAR06 * VAR25 + + CONST154 * VAR08 * VAR23 + + CONST170 * VAR04 * z + ) + y * ( + CONST058 * VAR02 * z + + CONST060 * VAR04 * VAR25 + - CONST095 * VAR08 * VAR21 + + CONST119 * VAR06 * VAR23 + + CONST178 * VAR19 + ) + Y18 = ( + CONST034 * VAR02 * VAR26 + + CONST035 * VAR08 * VAR20 + + CONST082 * VAR00 + + CONST087 * VAR18 + + CONST155 * VAR04 * VAR24 + + CONST156 * VAR06 * VAR22 + + VAR17 + * ( + CONST025 * VAR04 * VAR26 + + CONST025 * VAR08 * VAR22 + + CONST028 * VAR02 + + CONST028 * VAR20 + + CONST161 * VAR06 * VAR24 + ) + ) + Y19 = y * ( + CONST002 * VAR04 * VAR25 + + CONST010 * VAR19 + + CONST045 * VAR08 * VAR21 + + CONST061 * VAR02 * z + + CONST181 * VAR06 * VAR23 + ) + Y20 = ( + -CONST143 * VAR02 * VAR26 + + CONST143 * VAR08 * VAR20 + + CONST165 * VAR04 * VAR24 + - CONST165 * VAR06 * VAR22 + + CONST183 * VAR00 + - CONST183 * VAR18 + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y01, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y02, + mask=output_row_offset + 2 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 3, + Y03, + mask=output_row_offset + 3 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 4, + Y04, + mask=output_row_offset + 4 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 5, + Y05, + mask=output_row_offset + 5 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 6, + Y06, + mask=output_row_offset + 6 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 7, + Y07, + mask=output_row_offset + 7 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 8, + Y08, + mask=output_row_offset + 8 < output_numel, + ) + tl.store( + 
output_ptr + output_row_offset + 9, + Y09, + mask=output_row_offset + 9 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 10, + Y10, + mask=output_row_offset + 10 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 11, + Y11, + mask=output_row_offset + 11 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 12, + Y12, + mask=output_row_offset + 12 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 13, + Y13, + mask=output_row_offset + 13 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 14, + Y14, + mask=output_row_offset + 14 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 15, + Y15, + mask=output_row_offset + 15 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 16, + Y16, + mask=output_row_offset + 16 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 17, + Y17, + mask=output_row_offset + 17 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 18, + Y18, + mask=output_row_offset + 18 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 19, + Y19, + mask=output_row_offset + 19 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 20, + Y20, + mask=output_row_offset + 20 < output_numel, + ) + + +@triton.jit +def tenth_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # the coordinate stride is hardcoded: each point is a contiguous (x, y, z) triplet + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # row offsets into the flattened coordinates: one entry per node/atom in this block + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + 
coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + # load in gradients w.r.t. spherical harmonic projections + g_0 = tl.load( + sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel + ) + g_1 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_2 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + g_3 = tl.load( + sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel + ) + g_4 = tl.load( + sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel + ) + g_5 = tl.load( + sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel + ) + g_6 = tl.load( + sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel + ) + g_7 = tl.load( + sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel + ) + g_8 = tl.load( + sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel + ) + g_9 = tl.load( + sph_grad_ptr + output_row_offset + 9, mask=output_row_offset + 9 < output_numel + ) + g_10 = tl.load( + sph_grad_ptr + output_row_offset + 10, + mask=output_row_offset + 10 < output_numel, + ) + g_11 = tl.load( + sph_grad_ptr + output_row_offset + 11, + mask=output_row_offset + 11 < output_numel, + ) + g_12 = tl.load( + sph_grad_ptr + output_row_offset + 12, + mask=output_row_offset + 12 < output_numel, + ) + g_13 = tl.load( + sph_grad_ptr + output_row_offset + 13, + mask=output_row_offset + 13 < output_numel, + ) + g_14 = tl.load( + sph_grad_ptr + output_row_offset + 14, + mask=output_row_offset + 14 < output_numel, + ) + g_15 = tl.load( + sph_grad_ptr + 
output_row_offset + 15, + mask=output_row_offset + 15 < output_numel, + ) + g_16 = tl.load( + sph_grad_ptr + output_row_offset + 16, + mask=output_row_offset + 16 < output_numel, + ) + g_17 = tl.load( + sph_grad_ptr + output_row_offset + 17, + mask=output_row_offset + 17 < output_numel, + ) + g_18 = tl.load( + sph_grad_ptr + output_row_offset + 18, + mask=output_row_offset + 18 < output_numel, + ) + g_19 = tl.load( + sph_grad_ptr + output_row_offset + 19, + mask=output_row_offset + 19 < output_numel, + ) + g_20 = tl.load( + sph_grad_ptr + output_row_offset + 20, + mask=output_row_offset + 20 < output_numel, + ) + # -------------------- variable and constant definitions + CONST000 = 2.00000000000000 + CONST002 = 4.00000000000000 + CONST003 = 4.82870805793735 + CONST004 = 6.00000000000000 + CONST005 = 4.97432985632550 + CONST006 = 8.00000000000000 + CONST007 = 4.97432985632550 + CONST008 = 10.5521471197994 + CONST009 = 3.00000000000000 + CONST010 = 5.00000000000000 + CONST011 = 7.00000000000000 + CONST012 = 13.2648796168680 + CONST014 = 12.1657520803952 + CONST015 = 16.7271353825295 + CONST016 = -2030.35546709287 + CONST017 = 19.3148322317494 + CONST018 = -6131.53904851919 + CONST019 = 22.8629854262320 + CONST020 = 23.2135393295190 + CONST021 = 24.6216766128653 + CONST022 = 17.5869118663323 + CONST024 = 28.9722483476241 + CONST025 = 33.9852909359329 + CONST026 = 33.9852909359329 + CONST027 = 35.5238206489124 + CONST028 = 6180.74631415980 + CONST029 = 38.6296644634988 + CONST030 = 39.7946388506040 + CONST031 = 38.6296644634988 + CONST032 = -2007.25624590353 + CONST033 = -2007.25624590353 + CONST034 = 45.8257569495584 + CONST035 = 45.7259708524640 + CONST037 = 56.3871618715269 + CONST038 = 56.2781179722634 + CONST039 = -1989.33395633909 + CONST040 = -1989.33395633909 + CONST041 = 59.6919582759060 + CONST042 = 66.9085415301178 + CONST043 = 69.6406179885570 + CONST044 = -8121.42186837148 + CONST045 = 77.2593289269976 + CONST046 = 78.6510608948335 + CONST047 = 
-1969.73412902922 + CONST048 = 77.3468749368712 + CONST049 = -1969.73412902922 + CONST050 = -9.65741611587469 + CONST051 = 90.1358837481638 + CONST053 = 94.9693240781945 + CONST055 = 96.5741611587469 + CONST057 = 98.4867064514610 + CONST058 = 100.362812295177 + CONST059 = 101.517773354644 + CONST060 = 106.571461946737 + CONST061 = 106.571461946737 + CONST062 = 109.491768723557 + CONST063 = 109.491768723557 + CONST064 = 112.774323743054 + CONST065 = 112.774323743054 + CONST067 = 2165.26701586663 + CONST070 = 133.817083060236 + CONST071 = 139.281235977114 + CONST072 = 139.281235977114 + CONST073 = 141.571909610700 + CONST074 = 142.095282595650 + CONST075 = 147.730059677192 + CONST076 = 150.544218442765 + CONST077 = 150.074981259369 + CONST079 = 2202.22970505534 + CONST080 = -3939.46825805844 + CONST081 = -5968.00186901728 + CONST082 = 176.592751833137 + CONST083 = 176.178376404427 + CONST085 = 185.708314636152 + CONST087 = 196.973412902922 + CONST089 = 225.548647486108 + CONST090 = 225.548647486108 + CONST091 = 4330.53403173327 + CONST093 = 244.831037842559 + CONST094 = -1804.38917988886 + CONST095 = -1804.38917988886 + CONST097 = 2317.77986780993 + CONST098 = 278.562471954228 + CONST100 = 284.190565191299 + CONST101 = -1761.78376404427 + CONST103 = -9946.66978169547 + CONST104 = 9.94865971265100 + CONST108 = -7878.93651611688 + CONST111 = 338.322971229162 + CONST112 = 360.877835977772 + CONST114 = -1671.37483172537 + CONST116 = 2436.42656051144 + CONST119 = 393.946825805844 + CONST120 = -1648.19901710928 + CONST121 = 401.451249180707 + CONST122 = 406.071093418574 + CONST123 = 412.049754277320 + CONST125 = -1624.28437367430 + CONST126 = 426.285847786949 + CONST127 = 426.285847786948 + CONST128 = 2486.66744542387 + CONST130 = 451.097294972216 + CONST131 = 451.097294972216 + CONST132 = 451.097294972215 + CONST133 = 6606.68911516602 + CONST134 = 6606.68911516602 + CONST135 = -1575.78730322338 + CONST136 = -1575.78730322338 + CONST137 = -3608.77835977772 + CONST139 = 
-1545.18657853995 + CONST140 = -1545.18657853995 + CONST142 = 535.268332240943 + CONST143 = 4635.55973561985 + CONST144 = 541.428124558099 + CONST145 = -3545.52143225260 + CONST146 = 557.124943908456 + CONST147 = -3523.56752808854 + CONST148 = -5571.24943908456 + CONST151 = 15.7883647328499 + CONST153 = 2642.67564606641 + CONST154 = 2642.67564606641 + CONST155 = 2676.34166120471 + CONST156 = 629.208487158668 + CONST158 = 4727.36190967013 + CONST159 = -1392.81235977114 + CONST160 = -1390.66792068596 + CONST162 = 663.111318779698 + CONST163 = -3427.63452979582 + CONST164 = -1378.81389032045 + CONST165 = 676.645942458323 + CONST167 = -1338.17083060236 + CONST168 = -1338.17083060236 + CONST169 = 721.755671955545 + CONST171 = 2785.62471954228 + CONST173 = 772.593289269975 + CONST175 = 787.893651611688 + CONST176 = 787.893651611688 + CONST177 = 6.63243980843400 + CONST178 = 812.142186837148 + CONST180 = -1218.21328025572 + CONST181 = -1202.92611992591 + CONST182 = -1202.92611992591 + CONST183 = -3248.56874734859 + CONST184 = -3248.56874734859 + CONST185 = -5285.35129213281 + CONST186 = -1181.84047741753 + CONST190 = 2936.30627340712 + CONST192 = 2954.60119354383 + CONST193 = -1114.24988781691 + CONST194 = -16.5810995210850 + CONST195 = -1101.11485252767 + CONST196 = -1081.63060497797 + CONST197 = 15.7302121789667 + CONST199 = 984.867064514610 + CONST202 = -1027.70719569249 + CONST203 = -1021.92317475320 + CONST204 = -3065.76952425960 + CONST205 = -1015.17773354644 + CONST206 = 3090.37315707990 + CONST207 = -994.666978169547 + CONST208 = -984.867064514610 + CONST209 = -984.867064514610 + CONST210 = -979.324151370235 + CONST211 = 1070.53666448189 + CONST212 = -979.324151370235 + CONST213 = 3151.57460644675 + CONST216 = -927.111947123971 + CONST217 = -927.111947123970 + CONST218 = -5.63871618715269 + CONST219 = -2954.60119354383 + CONST220 = -902.194589944431 + CONST221 = -900.449887556215 + CONST222 = -880.891882022136 + CONST223 = -880.891882022136 + CONST224 = 
-875.934149788456 + CONST226 = -4944.59705132784 + CONST228 = 3248.56874734859 + CONST229 = -835.687415862684 + CONST230 = 1218.21328025572 + CONST231 = -824.099508554641 + CONST232 = -824.863625092051 + CONST233 = -824.863625092051 + CONST234 = -812.142186837148 + CONST235 = 5352.68332240943 + CONST236 = -787.893651611688 + CONST237 = -787.893651611688 + CONST238 = -772.593289269976 + CONST239 = -742.833258544608 + CONST240 = -2785.62471954228 + CONST241 = -734.076568351780 + CONST242 = 1321.33782303320 + CONST243 = 1321.33782303320 + CONST244 = -706.371007332549 + CONST245 = -696.406179885570 + CONST246 = 1353.29188491665 + CONST247 = -675.337415667161 + CONST248 = -675.337415667161 + CONST250 = 3427.63452979582 + CONST251 = -669.085415301178 + CONST252 = -669.085415301178 + CONST253 = -669.085415301178 + CONST255 = -663.111318779698 + CONST256 = -2707.14062279049 + CONST258 = 1392.81235977114 + CONST259 = 1412.74201466510 + CONST260 = -4727.36190967013 + CONST261 = -2676.34166120471 + CONST262 = -618.074631415980 + CONST263 = -611.735236846792 + CONST264 = -611.735236846792 + CONST265 = 1443.51134391109 + CONST266 = -590.920238708766 + CONST267 = -10828.5624911620 + CONST268 = -580.101562026534 + CONST269 = -2626.31217203896 + CONST272 = 5571.24943908456 + CONST273 = -12.8765548211663 + CONST274 = -557.124943908456 + CONST275 = -557.124943908456 + CONST277 = -541.428124558099 + CONST278 = -6685.49932690147 + CONST279 = 7664.42381064899 + CONST280 = -525.262434407792 + CONST281 = 1532.88476212980 + CONST283 = -497.333489084773 + CONST284 = -497.333489084773 + CONST285 = -492.433532257305 + CONST286 = 1575.78730322338 + CONST287 = 1575.78730322338 + CONST288 = -463.555973561985 + CONST289 = -450.224943778107 + CONST290 = -450.224943778107 + CONST291 = -450.224943778108 + CONST292 = -437.967074894228 + CONST293 = -2472.29852566392 + CONST294 = 1624.28437367430 + CONST295 = -2472.29852566392 + CONST296 = -406.071093418574 + CONST297 = -393.946825805844 + CONST298 = 
-393.946825805844 + CONST299 = -2436.42656051144 + CONST300 = -386.296644634988 + CONST301 = -386.296644634988 + CONST302 = -4456.99955126765 + CONST303 = -337.668707833581 + CONST304 = -337.668707833581 + CONST305 = -331.555659389849 + CONST306 = -331.555659389849 + CONST307 = -2363.68095483506 + CONST309 = -309.037315707990 + CONST310 = -4404.45941011068 + CONST311 = -309.037315707990 + CONST312 = -305.867618423396 + CONST313 = -305.867618423396 + CONST314 = -305.867618423396 + CONST315 = -300.731529981477 + CONST316 = 9946.66978169547 + CONST318 = -290.050781013267 + CONST319 = -284.190565191299 + CONST320 = -278.562471954228 + CONST321 = -278.562471954228 + CONST322 = -2317.77986780993 + CONST323 = -10505.2486881558 + CONST324 = -251.683394863467 + CONST325 = -251.683394863467 + CONST326 = -246.216766128653 + CONST327 = -244.831037842559 + CONST328 = -2285.08968653055 + CONST329 = -2285.08968653055 + CONST330 = 3862.96644634988 + CONST331 = -223.028471767059 + CONST332 = -220.222970505534 + CONST333 = -206.215906273013 + CONST334 = -203.035546709287 + CONST335 = -196.973412902922 + CONST336 = -196.973412902922 + CONST337 = -182.903883409856 + CONST338 = -2228.49977563382 + CONST340 = 16.4144510752435 + CONST341 = 3939.46825805844 + CONST342 = 3939.46825805844 + CONST343 = -154.518657853995 + CONST344 = -154.518657853995 + CONST345 = -150.074981259369 + CONST346 = -147.730059677191 + CONST347 = -146.815313670356 + CONST348 = -142.095282595650 + CONST349 = -131.315608601948 + CONST350 = -131.315608601948 + CONST351 = -130.522851455970 + CONST352 = -125.841697431734 + CONST353 = -125.841697431734 + CONST354 = -112.556235944527 + CONST355 = -103.107953136506 + CONST356 = -101.517773354644 + CONST357 = 1949.93730367960 + CONST358 = -98.4867064514610 + CONST359 = -98.4867064514610 + CONST360 = -2141.07332896377 + CONST361 = -2141.07332896377 + CONST362 = -92.8541573180760 + CONST363 = -88.2963759165686 + CONST366 = -77.3468749368713 + CONST367 = 8121.42186837148 + 
CONST369 = -67.6645942458323 + CONST372 = -59.6919582759060 + CONST373 = -49.2433532257305 + CONST374 = -49.2433532257305 + CONST375 = -45.1097294972216 + CONST376 = -45.1097294972216 + CONST377 = -42.2085884791976 + CONST378 = -27.2034486491732 + CONST379 = -24.6216766128653 + CONST380 = -22.8629854262320 + CONST381 = -19.7354559160624 + CONST383 = -17.5869118663323 + CONST384 = -16.4144510752435 + CONST385 = -16.0956935264578 + CONST386 = -14.5025390506634 + CONST388 = -16.5810995210850 + CONST389 = -15.7883647328499 + CONST390 = -14.0695294930659 + CONST391 = -11.2774323743054 + CONST392 = -11.2774323743054 + CONST393 = -13.2648796168680 + CONST394 = -6.63243980843400 + CONST395 = -5.63871618715269 + CONST396 = -4.82870805793735 + CONST397 = -3.21913870529156 + CONST398 = -11.2774323743054 + VAR05 = x * x * x * x * x + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR01 = VAR05 * VAR06 + VAR02 = VAR06 * VAR06 + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR14 = y * y * y * y * y + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR10 = VAR14 * VAR15 + VAR11 = VAR15 * VAR15 + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR23 = z * z * z * z * z + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + VAR19 = VAR23 * VAR24 + VAR20 = VAR24 * VAR24 + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + # -------------------- kernel implementations + g_x = tl.load( + coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel + ) + g_y = tl.load( + coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + g_z = tl.load( + coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + g_x += ( + g_0 + * ( + CONST093 * VAR02 * z + + CONST210 * VAR08 * VAR21 + + CONST250 * VAR06 * VAR23 + + CONST328 * VAR04 * VAR25 + - CONST378 * VAR19 + ) + + g_1 + * y + * ( + CONST062 * VAR20 + + CONST063 * VAR02 + + CONST204 * VAR04 * VAR26 + + CONST204 * VAR08 * VAR22 + + CONST279 * VAR06 
* VAR24 + ) + + g_10 + * ( + CONST000 + * x + * ( + CONST089 * VAR17 * VAR22 + + CONST169 * VAR13 * VAR26 + + CONST220 * VAR15 * VAR24 + + CONST355 * VAR11 + + CONST395 * VAR20 + ) + + CONST002 + * VAR07 + * ( + CONST111 * VAR17 * VAR24 + + CONST112 * VAR13 + + CONST220 * VAR15 * VAR26 + + CONST392 * VAR22 + ) + + CONST004 + * VAR05 + * (CONST090 * VAR17 * VAR26 + CONST315 * VAR15 + CONST392 * VAR24) + + CONST006 * VAR03 * (CONST037 * VAR17 + CONST218 * VAR26) + + CONST391 * VAR01 + ) + + g_11 + * ( + CONST070 * VAR21 * x * y + + VAR23 * (CONST121 * VAR07 * y + CONST168 * VAR16 * x) + + VAR25 + * (CONST121 * VAR05 * y + CONST261 * VAR07 * VAR16 - CONST361 * VAR14 * x) + + z + * ( + CONST070 * VAR03 * y + + CONST167 * VAR05 * VAR16 + + CONST263 * VAR12 * x + - CONST361 * VAR07 * VAR14 + ) + ) + + g_12 + * ( + CONST000 + * x + * ( + CONST003 * VAR20 + - CONST301 * VAR15 * VAR24 + + CONST343 * VAR17 * VAR22 + + CONST363 * VAR11 + ) + + CONST002 + * VAR07 + * (CONST123 * VAR13 + CONST300 * VAR15 * VAR26 - CONST397 * VAR22) + + CONST004 + * VAR05 + * (CONST301 * VAR15 - CONST344 * VAR17 * VAR26 + CONST397 * VAR24) + + CONST006 * VAR03 * (CONST045 * VAR17 + CONST396 * VAR26) + + CONST385 * VAR01 + ) + + g_13 + * ( + CONST221 * VAR12 * x * z + + VAR14 * (-CONST260 * VAR07 * z + CONST286 * VAR25 * x) + + VAR16 + * (CONST080 * VAR07 * VAR25 + CONST145 * VAR05 * z + CONST297 * VAR23 * x) + + y + * ( + -CONST237 * VAR05 * VAR25 + - CONST297 * VAR07 * VAR23 + - CONST298 * VAR03 * z + ) + ) + + g_14 + * ( + CONST000 + * x + * ( + CONST005 * VAR20 + - CONST159 * VAR15 * VAR24 + + CONST193 * VAR13 * VAR26 + + CONST320 * VAR17 * VAR22 + ) + + CONST002 + * VAR07 + * ( + CONST020 * VAR22 + + CONST085 * VAR13 + + CONST245 * VAR17 * VAR24 + + CONST258 * VAR15 * VAR26 + ) + + CONST004 + * VAR05 + * (CONST020 * VAR24 + CONST320 * VAR15 + CONST320 * VAR17 * VAR26) + + CONST006 * VAR03 * (CONST007 * VAR26 + CONST043 * VAR17) + + CONST388 * VAR01 + ) + + g_15 + * ( + VAR14 * (-CONST147 * 
VAR07 * z + CONST147 * VAR25 * x) + + VAR16 + * (CONST153 * VAR23 * x + CONST190 * VAR07 * VAR25 + CONST310 * VAR05 * z) + + y + * (CONST156 * VAR03 * z + CONST222 * VAR07 * VAR23 + CONST324 * VAR21 * x) + ) + + g_16 + * ( + CONST000 + * x + * (CONST047 * VAR15 * VAR24 + CONST175 * VAR17 * VAR22 + CONST380 * VAR20) + + CONST002 * VAR07 * (-CONST047 * VAR15 * VAR26 + CONST379 * VAR22) + + CONST004 + * VAR05 + * (CONST021 * VAR24 + CONST236 * VAR17 * VAR26 + CONST349 * VAR15) + + CONST006 * VAR03 * (CONST019 * VAR26 + CONST038 * VAR17) + + CONST383 * VAR01 + ) + + g_17 + * ( + VAR16 + * (CONST183 * VAR23 * x + CONST184 * VAR05 * z - CONST267 * VAR07 * VAR25) + + y + * ( + CONST178 * VAR03 * z + + CONST234 * VAR07 * VAR23 + - CONST268 * VAR21 * x + + CONST299 * VAR05 * VAR25 + ) + ) + + g_18 + * ( + CONST060 * VAR20 * x + + CONST126 * VAR03 * VAR26 + + CONST283 * VAR05 * VAR24 + + CONST305 * VAR07 * VAR22 + + CONST381 * VAR01 + + VAR17 + * ( + CONST039 * VAR22 * x + + CONST081 * VAR05 * VAR26 + + CONST316 * VAR07 * VAR24 + - CONST319 * VAR03 + ) + ) + + g_19 + * y + * ( + CONST018 * VAR05 * VAR25 + - CONST018 * VAR07 * VAR23 + - CONST224 * VAR03 * z + + CONST224 * VAR21 * x + ) + + g_2 + * ( + CONST074 * VAR02 * z + + CONST100 * VAR08 * VAR21 + + CONST255 * VAR04 * VAR25 + + CONST389 * VAR19 + + VAR17 + * ( + CONST040 * VAR04 * z + + CONST081 * VAR08 * VAR23 + - CONST103 * VAR06 * VAR25 + - CONST319 * VAR21 + ) + ) + + g_20 + * ( + CONST163 * VAR05 * VAR24 + - CONST212 * VAR03 * VAR26 + + CONST327 * VAR20 * x + - CONST329 * VAR07 * VAR22 + + CONST378 * VAR01 + ) + + g_3 + * ( + VAR16 + * ( + CONST044 * VAR08 * VAR24 + + CONST144 * VAR22 + + CONST277 * VAR04 + + CONST367 * VAR06 * VAR26 + ) + + y + * ( + CONST016 * VAR04 * VAR26 + - CONST205 * VAR06 * VAR24 + + CONST230 * VAR08 * VAR22 + - CONST351 * VAR02 + + CONST356 * VAR20 + ) + ) + + g_4 + * ( + CONST008 * VAR19 + + CONST009 + * VAR08 + * (CONST175 * VAR17 * VAR23 + CONST269 * VAR15 * VAR25 + CONST390 * VAR21) + + 
CONST010 + * VAR06 + * (CONST175 * VAR15 * z + CONST176 * VAR17 * VAR25 + CONST373 * VAR23) + + CONST011 * VAR04 * (CONST303 * VAR17 * z + CONST390 * VAR25) + + CONST053 * VAR02 * z + + CONST175 * VAR15 * VAR23 + + CONST304 * VAR17 * VAR21 + ) + + g_5 + * ( + VAR14 * (CONST185 * VAR08 * VAR26 - CONST222 * VAR06 - CONST223 * VAR24) + + VAR16 + * ( + CONST079 * VAR08 * VAR24 + + CONST133 * VAR06 * VAR26 + + CONST202 * VAR04 + + CONST241 * VAR22 + ) + + y + * ( + CONST046 * VAR20 + + CONST073 * VAR02 + + CONST195 * VAR06 * VAR24 + + CONST222 * VAR04 * VAR26 + ) + ) + + g_6 + * ( + CONST009 + * VAR08 + * (CONST098 * VAR17 * VAR23 + CONST239 * VAR13 * z + CONST393 * VAR21) + + CONST010 * VAR06 * (-CONST193 * VAR15 * z + CONST320 * VAR17 * VAR25) + + CONST011 * VAR04 * (CONST012 * VAR25 + CONST321 * VAR17 * z) + + CONST041 * VAR02 * z + + CONST098 * VAR17 * VAR21 + + CONST193 * VAR15 * VAR23 + - CONST239 * VAR13 * VAR25 + + CONST394 * VAR19 + ) + + g_7 + * ( + VAR12 * (CONST289 * VAR08 - CONST290 * VAR26) + + VAR14 * (-CONST049 * VAR06 + CONST186 * VAR24 + CONST307 * VAR08 * VAR26) + + VAR16 + * ( + CONST164 * VAR04 + + CONST192 * VAR08 * VAR24 + + CONST199 * VAR06 * VAR26 + - CONST266 * VAR22 + ) + + y + * ( + CONST075 * VAR02 + + CONST285 * VAR06 * VAR24 + + CONST297 * VAR08 * VAR22 + + CONST374 * VAR20 + ) + ) + + g_8 + * ( + CONST009 + * VAR08 + * ( + -CONST140 * VAR15 * VAR25 + + CONST231 * VAR13 * z + - CONST273 * VAR21 + + CONST288 * VAR17 * VAR23 + ) + + CONST010 + * VAR06 + * (CONST017 * VAR23 + CONST173 * VAR15 * z + CONST288 * VAR17 * VAR25) + + CONST011 * VAR04 * (-CONST273 * VAR25 + CONST344 * VAR17 * z) + + CONST024 * VAR02 * z + + CONST082 * VAR11 * z + + CONST173 * VAR15 * VAR23 + + CONST231 * VAR13 * VAR25 + + CONST344 * VAR17 * VAR21 + - CONST397 * VAR19 + ) + + g_9 + * ( + CONST009 + * VAR08 + * ( + CONST042 * VAR22 * y + + CONST211 * VAR14 * VAR26 + + CONST251 * VAR16 * VAR24 + + CONST312 * VAR12 + ) + + CONST010 + * VAR06 + * (CONST058 * VAR24 * y + 
CONST142 * VAR14 + CONST252 * VAR16 * VAR26) + + CONST011 * VAR04 * (CONST042 * VAR26 * y + CONST331 * VAR16) + + CONST015 * VAR20 * y + + CONST025 * VAR10 + + CONST076 * VAR02 * y + + CONST142 * VAR14 * VAR24 + + CONST312 * VAR12 * VAR26 + + CONST331 * VAR16 * VAR22 + ) + ) + g_y += ( + CONST000 + * g_18 + * y + * ( + CONST027 * VAR02 + + CONST027 * VAR20 + + CONST128 * VAR06 * VAR24 + + CONST207 * VAR04 * VAR26 + + CONST207 * VAR08 * VAR22 + ) + + CONST000 + * g_2 + * y + * ( + -CONST039 * VAR05 * VAR25 + + CONST039 * VAR07 * VAR23 + + CONST319 * VAR03 * z + - CONST319 * VAR21 * x + ) + + g_1 + * ( + CONST014 * VAR01 + + CONST062 * VAR20 * x + + CONST203 * VAR07 * VAR22 + + CONST281 * VAR05 * VAR24 + + CONST292 * VAR03 * VAR26 + ) + + g_10 + * ( + CONST034 * VAR10 + + CONST064 * VAR20 * y + + CONST065 * VAR02 * y + + CONST067 * VAR14 * VAR24 + + CONST182 * VAR16 * VAR22 + + CONST233 * VAR12 * VAR26 + + VAR04 * (CONST131 * VAR26 * y + CONST181 * VAR16) + + VAR06 + * (CONST067 * VAR14 + CONST137 * VAR16 * VAR26 + CONST165 * VAR24 * y) + + VAR08 + * ( + CONST091 * VAR14 * VAR26 + + CONST130 * VAR22 * y + + CONST137 * VAR16 * VAR24 + + CONST232 * VAR12 + ) + ) + + g_11 + * ( + CONST015 * VAR19 + + VAR21 * (CONST042 * VAR08 + CONST253 * VAR17) + + VAR23 * (CONST033 * VAR08 * VAR17 + CONST058 * VAR06 + CONST155 * VAR15) + + VAR25 + * ( + CONST032 * VAR06 * VAR17 + + CONST042 * VAR04 + + CONST235 * VAR08 * VAR15 + + CONST361 * VAR13 + ) + + z + * ( + CONST015 * VAR02 + + CONST155 * VAR06 * VAR15 + + CONST253 * VAR04 * VAR17 + - CONST312 * VAR11 + + CONST360 * VAR08 * VAR13 + ) + ) + + g_12 + * ( + -CONST140 * VAR16 * VAR22 + - CONST244 * VAR12 * VAR26 + + CONST293 * VAR14 * VAR24 + + CONST343 * VAR20 * y + - CONST344 * VAR02 * y + + VAR04 * (CONST140 * VAR16 - CONST311 * VAR26 * y) + + VAR06 * (CONST139 * VAR16 * VAR26 - CONST295 * VAR14) + + VAR08 + * (-CONST140 * VAR16 * VAR24 + CONST244 * VAR12 + CONST309 * VAR22 * y) + ) + + g_13 + * ( + CONST009 + * VAR17 + * ( + 
CONST208 * VAR06 * VAR25 + + CONST266 * VAR04 * z + + CONST335 * VAR08 * VAR23 + - CONST336 * VAR21 + ) + + CONST010 + * VAR15 + * (CONST176 * VAR08 * VAR25 - CONST186 * VAR06 * z + CONST298 * VAR23) + + CONST011 * VAR13 * (CONST077 * VAR25 + CONST290 * VAR08 * z) + - CONST350 * VAR04 * VAR25 + - CONST358 * VAR06 * VAR23 + - CONST374 * VAR02 * z + + CONST384 * VAR19 + ) + + g_14 + * ( + CONST071 * VAR02 * y + + CONST072 * VAR20 * y + - CONST193 * VAR14 * VAR24 + + CONST193 * VAR16 * VAR22 + + VAR04 * (CONST193 * VAR16 + CONST274 * VAR26 * y) + + VAR06 + * (CONST159 * VAR24 * y - CONST193 * VAR14 + CONST272 * VAR16 * VAR26) + + VAR08 + * ( + -CONST148 * VAR16 * VAR24 + + CONST274 * VAR22 * y + + CONST278 * VAR14 * VAR26 + ) + ) + + g_15 + * ( + CONST009 + * VAR17 + * ( + CONST241 * VAR04 * z + - CONST241 * VAR06 * VAR25 + + CONST242 * VAR08 * VAR23 + + CONST347 * VAR21 + ) + + CONST010 + * VAR15 + * (CONST083 * VAR23 + CONST101 * VAR08 * VAR25 - CONST223 * VAR06 * z) + + CONST046 * VAR02 * z + + CONST197 * VAR19 + + CONST332 * VAR06 * VAR23 + + CONST352 * VAR08 * VAR21 + ) + + g_16 + * ( + -CONST108 * VAR06 * VAR16 * VAR26 + - CONST280 * VAR16 * VAR22 + - CONST354 * VAR02 * y + + CONST354 * VAR20 * y + + VAR04 * (CONST135 * VAR26 * y + CONST280 * VAR16) + + VAR08 * (CONST108 * VAR16 * VAR24 + CONST287 * VAR22 * y) + ) + + g_17 + * ( + CONST009 + * VAR17 + * ( + CONST048 * VAR21 + + CONST125 * VAR08 * VAR23 + - CONST256 * VAR06 * VAR25 + + CONST277 * VAR04 * z + ) + + CONST059 * VAR02 * z + + CONST296 * VAR04 * VAR25 + - CONST318 * VAR08 * VAR21 + + CONST334 * VAR06 * VAR23 + + CONST386 * VAR19 + ) + + g_19 + * ( + CONST014 * VAR19 + + CONST062 * VAR02 * z + + CONST203 * VAR04 * VAR25 + + CONST281 * VAR06 * VAR23 + + CONST292 * VAR08 * VAR21 + ) + + g_3 + * ( + CONST009 + * VAR17 + * ( + CONST144 * VAR22 * x + + CONST256 * VAR07 * VAR24 + + CONST294 * VAR05 * VAR26 + + CONST366 * VAR03 + ) + + CONST122 * VAR07 * VAR22 + + CONST318 * VAR03 * VAR26 + - CONST334 * VAR05 
* VAR24 + + CONST356 * VAR20 * x + - CONST386 * VAR01 + ) + + g_4 + * ( + CONST248 * VAR03 * y * z + + VAR05 * (CONST213 * VAR16 * z + CONST286 * VAR25 * y) + + VAR07 * (CONST287 * VAR23 * y + CONST323 * VAR16 * VAR25) + + x * (CONST213 * VAR16 * VAR23 + CONST247 * VAR21 * y) + ) + + g_5 + * ( + CONST009 + * VAR17 + * ( + -CONST241 * VAR07 * VAR24 + + CONST241 * VAR22 * x + + CONST243 * VAR05 * VAR26 + + CONST347 * VAR03 + ) + + CONST010 + * VAR15 + * (CONST083 * VAR05 + CONST101 * VAR07 * VAR26 - CONST223 * VAR24 * x) + + CONST046 * VAR20 * x + + CONST197 * VAR01 + + CONST332 * VAR05 * VAR24 + + CONST353 * VAR03 * VAR26 + ) + + g_6 + * ( + CONST275 * VAR03 * y * z + + VAR05 * (CONST274 * VAR25 * y - CONST302 * VAR16 * z) + + VAR07 * (CONST146 * VAR23 * y + CONST302 * VAR14 * z) + + x + * ( + CONST146 * VAR21 * y + - CONST302 * VAR14 * VAR25 + + CONST302 * VAR16 * VAR23 + ) + ) + + g_7 + * ( + CONST009 + * VAR17 + * ( + CONST087 * VAR05 * VAR26 + - CONST209 * VAR07 * VAR24 + - CONST266 * VAR22 * x + + CONST336 * VAR03 + ) + + CONST010 + * VAR15 + * (CONST186 * VAR24 * x + CONST237 * VAR07 * VAR26 - CONST298 * VAR05) + + CONST011 * VAR13 * (-CONST290 * VAR26 * x + CONST345 * VAR07) + + CONST340 * VAR01 + + CONST350 * VAR07 * VAR22 + + CONST358 * VAR05 * VAR24 + + CONST374 * VAR20 * x + ) + + g_8 + * ( + CONST311 * VAR03 * y * z + + VAR05 * (CONST206 * VAR16 * z + CONST216 * VAR25 * y) + + VAR07 + * (CONST028 * VAR16 * VAR25 + CONST216 * VAR23 * y + CONST226 * VAR14 * z) + + x + * ( + CONST206 * VAR16 * VAR23 + + CONST226 * VAR14 * VAR25 + + CONST259 * VAR12 * z + + CONST311 * VAR21 * y + ) + ) + + g_9 + * ( + CONST015 * VAR01 + + VAR03 * (CONST042 * VAR26 + CONST253 * VAR17) + + VAR05 * (CONST033 * VAR17 * VAR26 + CONST058 * VAR24 + CONST155 * VAR15) + + VAR07 + * ( + CONST032 * VAR17 * VAR24 + + CONST042 * VAR22 + + CONST235 * VAR15 * VAR26 + + CONST361 * VAR13 + ) + + x + * ( + CONST015 * VAR20 + + CONST155 * VAR15 * VAR24 + + CONST253 * VAR17 * VAR22 + - CONST314 
* VAR11 + + CONST361 * VAR13 * VAR26 + ) + ) + ) + g_z += ( + g_0 + * ( + CONST093 * VAR20 * x + + CONST210 * VAR03 * VAR26 + + CONST250 * VAR05 * VAR24 + + CONST328 * VAR07 * VAR22 + - CONST378 * VAR01 + ) + + g_1 + * y + * ( + -CONST018 * VAR05 * VAR25 + + CONST018 * VAR07 * VAR23 + + CONST224 * VAR03 * z + - CONST224 * VAR21 * x + ) + + g_10 + * ( + CONST095 * VAR15 * VAR23 + + CONST132 * VAR17 * VAR21 + + CONST265 * VAR13 * VAR25 + + CONST333 * VAR11 * z + + CONST391 * VAR19 + + CONST398 * VAR02 * z + + VAR04 * (CONST131 * VAR17 * z + CONST376 * VAR25) + + VAR06 + * (CONST094 * VAR15 * z + CONST246 * VAR17 * VAR25 + CONST369 * VAR23) + + VAR08 + * ( + CONST137 * VAR15 * VAR25 + + CONST246 * VAR17 * VAR23 + + CONST265 * VAR13 * z + + CONST375 * VAR21 + ) + ) + + g_11 + * ( + CONST009 + * VAR26 + * ( + CONST042 * VAR04 * y + + CONST211 * VAR08 * VAR14 + + CONST251 * VAR06 * VAR16 + + CONST313 * VAR12 + ) + + CONST010 + * VAR24 + * (CONST058 * VAR06 * y + CONST142 * VAR14 + CONST252 * VAR08 * VAR16) + + CONST011 * VAR22 * (CONST042 * VAR08 * y + CONST331 * VAR16) + + CONST015 * VAR02 * y + + CONST026 * VAR10 + + CONST076 * VAR20 * y + + CONST142 * VAR06 * VAR14 + + CONST314 * VAR08 * VAR12 + + CONST331 * VAR04 * VAR16 + ) + + g_12 + * ( + CONST050 * VAR02 * z + + CONST082 * VAR11 * z + + CONST097 * VAR15 * VAR23 + + CONST120 * VAR13 * VAR25 + + CONST262 * VAR17 * VAR21 + - CONST385 * VAR19 + + VAR04 * (CONST273 * VAR25 - CONST311 * VAR17 * z) + + VAR06 * (CONST017 * VAR23 + CONST238 * VAR15 * z) + + VAR08 + * (CONST029 * VAR21 - CONST140 * VAR15 * VAR25 + CONST217 * VAR17 * VAR23) + ) + + g_13 + * ( + VAR12 * (CONST290 * VAR08 - CONST290 * VAR26) + + VAR14 * (CONST049 * VAR24 - CONST186 * VAR06 - CONST307 * VAR08 * VAR26) + + VAR16 + * ( + -CONST164 * VAR22 + + CONST209 * VAR08 * VAR24 + + CONST219 * VAR06 * VAR26 + + CONST266 * VAR04 + ) + + y + * ( + -CONST285 * VAR06 * VAR24 + - CONST297 * VAR04 * VAR26 + + CONST346 * VAR20 + - CONST374 * VAR02 + ) + ) + + g_14 
+ * ( + CONST104 * VAR02 * z + + CONST114 * VAR15 * VAR23 + + CONST146 * VAR17 * VAR21 + + CONST194 * VAR19 + - CONST239 * VAR13 * VAR25 + + VAR04 * (CONST274 * VAR17 * z - CONST362 * VAR25) + + VAR06 + * (CONST072 * VAR23 + CONST171 * VAR15 * z + CONST240 * VAR17 * VAR25) + + VAR08 + * ( + CONST030 * VAR21 + + CONST114 * VAR17 * VAR23 + - CONST148 * VAR15 * VAR25 + + CONST338 * VAR13 * z + ) + ) + + g_15 + * ( + VAR14 * (CONST185 * VAR08 * VAR26 - CONST222 * VAR24 - CONST223 * VAR06) + + VAR16 + * ( + CONST079 * VAR06 * VAR26 + + CONST134 * VAR08 * VAR24 + + CONST202 * VAR22 + + CONST241 * VAR04 + ) + + y + * ( + CONST046 * VAR02 + + CONST073 * VAR20 + + CONST195 * VAR06 * VAR24 + + CONST223 * VAR08 * VAR22 + ) + ) + + g_16 + * ( + CONST022 * VAR19 + + CONST035 * VAR02 * z + + CONST175 * VAR15 * VAR23 + + CONST291 * VAR17 * VAR21 + + VAR04 * (CONST057 * VAR25 + CONST135 * VAR17 * z) + + VAR06 * (CONST341 * VAR15 * z + CONST346 * VAR23) + + VAR08 + * (CONST108 * VAR15 * VAR25 + CONST158 * VAR17 * VAR23 + CONST337 * VAR21) + ) + + g_17 + * ( + VAR16 + * ( + -CONST044 * VAR06 * VAR26 + + CONST044 * VAR08 * VAR24 + + CONST144 * VAR22 + + CONST277 * VAR04 + ) + + y + * ( + -CONST016 * VAR08 * VAR22 + + CONST059 * VAR02 + + CONST180 * VAR04 * VAR26 + + CONST205 * VAR06 * VAR24 + + CONST351 * VAR20 + ) + ) + + g_18 + * ( + CONST061 * VAR02 * z + + CONST127 * VAR08 * VAR21 + + CONST284 * VAR06 * VAR23 + + CONST306 * VAR04 * VAR25 + + CONST381 * VAR19 + + VAR17 + * ( + CONST039 * VAR04 * z + + CONST081 * VAR08 * VAR23 + + CONST316 * VAR06 * VAR25 + - CONST319 * VAR21 + ) + ) + + g_19 + * y + * ( + CONST062 * VAR02 + + CONST063 * VAR20 + + CONST204 * VAR04 * VAR26 + + CONST204 * VAR08 * VAR22 + + CONST279 * VAR06 * VAR24 + ) + + g_2 + * ( + CONST151 * VAR01 + + CONST162 * VAR07 * VAR22 + + CONST319 * VAR03 * VAR26 + + CONST348 * VAR20 * x + + VAR17 + * ( + -CONST040 * VAR22 * x + - CONST081 * VAR05 * VAR26 + + CONST103 * VAR07 * VAR24 + + CONST319 * VAR03 + ) + ) + + g_20 + 
* ( + -CONST163 * VAR06 * VAR23 + + CONST212 * VAR08 * VAR21 + - CONST327 * VAR02 * z + + CONST329 * VAR04 * VAR25 + - CONST378 * VAR19 + ) + + g_3 + * ( + VAR16 + * (-CONST183 * VAR23 * x + CONST228 * VAR05 * z + CONST267 * VAR07 * VAR25) + + y + * ( + CONST116 * VAR07 * VAR23 + - CONST234 * VAR05 * VAR25 + + CONST234 * VAR21 * x + + CONST268 * VAR03 * z + ) + ) + + g_4 + * ( + CONST008 * VAR01 + + VAR03 * (CONST303 * VAR17 + CONST377 * VAR26) + + VAR05 * (CONST175 * VAR15 - CONST307 * VAR17 * VAR26 + CONST326 * VAR24) + + VAR07 + * (CONST108 * VAR15 * VAR26 + CONST341 * VAR17 * VAR24 + CONST359 * VAR22) + + x + * (CONST053 * VAR20 + CONST307 * VAR17 * VAR22 + CONST341 * VAR15 * VAR24) + ) + + g_5 + * ( + VAR14 * (CONST147 * VAR07 * z - CONST147 * VAR25 * x) + + VAR16 + * (CONST154 * VAR05 * z + CONST190 * VAR07 * VAR25 + CONST310 * VAR23 * x) + + y + * (CONST156 * VAR21 * x + CONST222 * VAR05 * VAR25 + CONST325 * VAR03 * z) + ) + + g_6 + * ( + CONST177 * VAR01 + + VAR03 * (CONST030 * VAR26 + CONST321 * VAR17) + + VAR05 * (-CONST193 * VAR15 + CONST229 * VAR17 * VAR26) + + VAR07 * (CONST239 * VAR13 + CONST258 * VAR17 * VAR24 + CONST362 * VAR22) + + x + * ( + CONST148 * VAR15 * VAR24 + - CONST338 * VAR13 * VAR26 + + CONST357 * VAR17 * VAR22 + + CONST372 * VAR20 + ) + ) + + g_7 + * ( + -CONST221 * VAR12 * x * z + + VAR14 * (CONST136 * VAR07 * z + CONST260 * VAR25 * x) + + VAR16 + * (CONST119 * VAR05 * z - CONST145 * VAR23 * x + CONST342 * VAR07 * VAR25) + + y + * ( + CONST237 * VAR07 * VAR23 + + CONST297 * VAR05 * VAR25 + + CONST298 * VAR21 * x + ) + ) + + g_8 + * ( + -CONST397 * VAR01 + + VAR03 * (CONST031 * VAR26 + CONST344 * VAR17) + + VAR05 * (CONST055 * VAR24 + CONST160 * VAR17 * VAR26 + CONST173 * VAR15) + + VAR07 + * ( + CONST051 * VAR22 + + CONST143 * VAR15 * VAR26 + + CONST231 * VAR13 + + CONST322 * VAR17 * VAR24 + ) + + x + * ( + CONST024 * VAR20 + + CONST082 * VAR11 + + CONST196 * VAR17 * VAR22 + + CONST295 * VAR13 * VAR26 + + CONST330 * VAR15 * VAR24 + ) 
+ ) + + g_9 + * ( + CONST070 * VAR03 * y * z + + VAR05 * (CONST121 * VAR25 * y + CONST168 * VAR16 * z) + + VAR07 + * (CONST121 * VAR23 * y + CONST261 * VAR16 * VAR25 - CONST361 * VAR14 * z) + + x + * ( + CONST070 * VAR21 * y + + CONST167 * VAR16 * VAR23 + + CONST264 * VAR12 * z + - CONST361 * VAR14 * VAR25 + ) + ) + ) + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + ) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/sph_harm/direct/y_2.py b/src/equitriton/sph_harm/direct/y_2.py new file mode 100644 index 0000000..3448158 --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_2.py @@ -0,0 +1,248 @@ +import triton +import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["SecondOrderSphericalHarmonic"] + + +class SecondOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + output_tensor: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + num_projections = 5 # 2l + 1 + # allocate a tensor if one isn't given + if not isinstance(output_tensor, torch.Tensor): + output_tensor = torch.empty( + (*coords.shape[:-1], num_projections), + dtype=coords.dtype, + device=coords.device, + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + second_order_fwd[num_blocks,]( + coords, + output_tensor, + block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, + sph_grad_tensor: torch.Tensor, + block_size: int = 64, + 
col_offset: int = 0, + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + second_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + This function is generically named to make it easier for + it to be called programmatically: it is _not_ intended + to be used manually. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. + + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the second order spherical harmonic. 
+ """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + CONST_00 = 3.87298334620742 + CONST_01 = 2.23606797749979 + CONST_02 = -1.11803398874989 + CONST_03 = 1.93649167310371 + Y20 = CONST_00 * x * z + Y21 = CONST_00 * x * y + Y23 = CONST_00 * y * z # looks jarring but just helping the compiler ;) + Y22 = CONST_02 * x * x + CONST_01 * y * y + CONST_02 * z * z + Y24 = -CONST_03 * x * x + CONST_03 * z * z + return torch.cat([Y20, Y21, Y22, Y23, Y24], dim=-1) + + +@triton.jit +def second_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # these are hardcoded because they are predetermined; + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + CONST_00 = 3.87298334620742 + CONST_01 = 2.23606797749979 + CONST_02 = -1.11803398874989 + CONST_03 = 1.93649167310371 + Y20 = CONST_00 * x * z + Y21 = CONST_00 * x * y + Y23 = CONST_00 * y * z # looks jarring but just helping the compiler ;) + Y22 = CONST_02 * x * x + CONST_01 * y * y + CONST_02 * z * z + Y24 = -CONST_03 * x * x + CONST_03 * z * z + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, Y20, mask=output_row_offset < 
output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y21, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y22, + mask=output_row_offset + 2 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 3, + Y23, + mask=output_row_offset + 3 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 4, + Y24, + mask=output_row_offset + 4 < output_numel, + ) + + +@triton.jit +def second_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # these are hardcoded because they are predetermined; + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + CONST_00 = 3.87298334620742 + CONST_01 = 2.23606797749979 + CONST_02 = 4.47213595499958 + # load in gradients w.r.t. 
spherical harmonic projections + g_Y20 = tl.load( + sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel + ) + g_Y21 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_Y22 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + g_Y23 = tl.load( + sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel + ) + g_Y24 = tl.load( + sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel + ) + g_x = tl.load( + coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel + ) + g_y = tl.load( + coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + g_z = tl.load( + coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + g_x += ( + CONST_00 * g_Y20 * z + + CONST_00 * g_Y21 * y + - CONST_01 * g_Y22 * x + - CONST_00 * g_Y24 * x + ) + g_y += CONST_00 * g_Y21 * x + CONST_02 * g_Y22 * y + CONST_00 * g_Y23 * z + g_z += ( + CONST_00 * g_Y20 * x + - CONST_01 * g_Y22 * z + + CONST_00 * g_Y23 * y + + CONST_00 * g_Y24 * z + ) + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + ) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/sph_harm/direct/y_3.py b/src/equitriton/sph_harm/direct/y_3.py new file mode 100644 index 0000000..1d39314 --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_3.py @@ -0,0 +1,321 @@ +import triton +import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["ThirdOrderSphericalHarmonic"] + + +class ThirdOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + 
output_tensor: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + # allocate a tensor if one isn't given + if not isinstance(output_tensor, torch.Tensor): + output_tensor = torch.empty( + (*coords.shape[:-1], 7), dtype=coords.dtype, device=coords.device + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + third_order_fwd[num_blocks,]( + coords, + output_tensor, + block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, + sph_grad_tensor: torch.Tensor, + coord_grad_output: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + if not isinstance(coord_grad_output, torch.Tensor): + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + third_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + This function is generically named to make it easier for + it to be called programmatically: it is _not_ intended + to be used manually. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. + + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the third order spherical harmonic. 
+ """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + # -------------------- variable and constant definitions + CONST000 = 2.64575131106459 + CONST002 = 5.12347538297980 + CONST004 = 6.48074069840786 + CONST005 = 10.2469507659596 + CONST006 = -2.09165006633519 + CONST007 = -1 + CONST008 = -6.27495019900557 + CONST009 = -3.96862696659689 + CONST010 = -1.62018517460197 + VAR07 = x * x * x + VAR08 = x * x + VAR16 = y * y * y + VAR17 = y * y + VAR25 = z * z * z + VAR26 = z * z + # -------------------- kernel implementations + Y00 = CONST006 * VAR07 - CONST008 * VAR26 * x + Y01 = CONST005 * x * y * z + Y02 = CONST010 * VAR07 + x * (CONST004 * VAR17 + CONST010 * VAR26) + Y03 = CONST000 * VAR16 + CONST009 * VAR08 * y + CONST009 * VAR26 * y + Y04 = CONST010 * VAR25 + z * (CONST004 * VAR17 + CONST010 * VAR08) + Y05 = CONST002 * y * (CONST007 * VAR08 + VAR26) + Y06 = -CONST006 * VAR25 + CONST008 * VAR08 * z + tensors = [Y00, Y01, Y02, Y03, Y04, Y05, Y06] + return torch.cat(tensors, dim=-1) + + +@triton.jit +def third_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # these are hardcoded because they are predetermined; + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + # -------------------- variable and constant definitions + 
CONST000 = 2.64575131106459 + CONST002 = 5.12347538297980 + CONST004 = 6.48074069840786 + CONST005 = 10.2469507659596 + CONST006 = -2.09165006633519 + CONST007 = -1 + CONST008 = -6.27495019900557 + CONST009 = -3.96862696659689 + CONST010 = -1.62018517460197 + VAR07 = x * x * x + VAR08 = x * x + VAR16 = y * y * y + VAR17 = y * y + VAR25 = z * z * z + VAR26 = z * z + # -------------------- kernel implementations + Y00 = CONST006 * VAR07 - CONST008 * VAR26 * x + Y01 = CONST005 * x * y * z + Y02 = CONST010 * VAR07 + x * (CONST004 * VAR17 + CONST010 * VAR26) + Y03 = CONST000 * VAR16 + CONST009 * VAR08 * y + CONST009 * VAR26 * y + Y04 = CONST010 * VAR25 + z * (CONST004 * VAR17 + CONST010 * VAR08) + Y05 = CONST002 * y * (CONST007 * VAR08 + VAR26) + Y06 = -CONST006 * VAR25 + CONST008 * VAR08 * z + output_striding = tl.arange(0, block_size) * output_stride + # zero on the row offset is the first spherical harmonic term of this order + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y01, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y02, + mask=output_row_offset + 2 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 3, + Y03, + mask=output_row_offset + 3 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 4, + Y04, + mask=output_row_offset + 4 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 5, + Y05, + mask=output_row_offset + 5 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 6, + Y06, + mask=output_row_offset + 6 < output_numel, + ) + + +@triton.jit +def third_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + 
output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # these are hardcoded because they are predetermined; + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + output_striding = tl.arange(0, block_size) * output_stride + # zero on the row offset is the first spherical harmonic term of this order + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + # load in gradients w.r.t. spherical harmonic projections + g_0 = tl.load( + sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel + ) + g_1 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_2 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + g_3 = tl.load( + sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel + ) + g_4 = tl.load( + sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel + ) + g_5 = tl.load( + sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel + ) + g_6 = tl.load( + sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel + ) + # -------------------- variable and constant definitions + CONST002 = 6.48074069840786 + CONST005 = 12.9614813968157 + CONST007 = -3.96862696659689 + CONST008 = -12.5499003980111 + CONST009 = -10.2469507659596 + CONST010 = -7.93725393319377 + CONST011 = -6.27495019900557 + CONST012 = -5.12347538297980 + CONST013 = -4.86055552380590 + CONST014 = 
-3.24037034920393 + CONST015 = -1.62018517460197 + VAR08 = x * x + VAR17 = y * y + VAR26 = z * z + # -------------------- kernel implementations + g_x = tl.load( + coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel + ) + g_y = tl.load( + coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + g_z = tl.load( + coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + g_x += ( + CONST008 * g_6 * x * z + - CONST009 * g_1 * y * z + + CONST009 * g_5 * x * y + + CONST010 * g_3 * x * y + + CONST014 * g_4 * x * z + + g_0 * (CONST011 * VAR08 - CONST011 * VAR26) + + g_2 * (CONST002 * VAR17 + CONST013 * VAR08 + CONST015 * VAR26) + ) + g_y += ( + CONST005 * g_2 * x * y + + CONST005 * g_4 * y * z + - CONST009 * g_1 * x * z + + g_3 * (CONST007 * VAR08 + CONST007 * VAR26 - CONST010 * VAR17) + + g_5 * (CONST012 * VAR08 - CONST012 * VAR26) + ) + g_z += ( + -CONST008 * g_0 * x * z + - CONST009 * g_1 * x * y + - CONST009 * g_5 * y * z + + CONST010 * g_3 * y * z + + CONST014 * g_2 * x * z + + g_4 * (CONST002 * VAR17 + CONST013 * VAR26 + CONST015 * VAR08) + + g_6 * (CONST011 * VAR08 - CONST011 * VAR26) + ) + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + ) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/sph_harm/direct/y_4.py b/src/equitriton/sph_harm/direct/y_4.py new file mode 100644 index 0000000..ba964c3 --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_4.py @@ -0,0 +1,394 @@ +import triton +import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["FourthOrderSphericalHarmonic"] + + +class FourthOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + 
coords: torch.Tensor, + output_tensor: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + if not isinstance(output_tensor, torch.Tensor): + output_tensor = torch.empty( + (*coords.shape[:-1], 9), dtype=coords.dtype, device=coords.device + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + fourth_order_fwd[num_blocks,]( + coords, + output_tensor, + block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, + sph_grad_tensor: torch.Tensor, + block_size: int = 64, + col_offset: int = 0, + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + fourth_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + This function is generically named to make it easier for + it to be called programmatically: it is _not_ intended + to be used manually. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. + + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the fourth order spherical harmonic. 
+ """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + # -------------------- variable and constant definitions + CONST000 = 1.12500000000000 + CONST001 = 2.25000000000000 + CONST002 = 3.00000000000000 + CONST005 = 2.21852991866236 + CONST007 = 9.48683298050514 + CONST010 = 20.1246117974981 + CONST011 = -18.8248505970167 + CONST012 = -13.3111795119741 + CONST013 = -10.0623058987491 + CONST014 = -9.00000000000000 + CONST015 = -8.87411967464942 + CONST016 = -7.11512473537885 + CONST017 = -6.27495019900557 + CONST018 = -3.35410196624968 + CONST019 = -1.67705098312484 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + # -------------------- kernel implementations + Y00 = CONST015 * VAR07 * z - CONST015 * VAR25 * x + Y01 = y * (-CONST011 * VAR26 * x + CONST017 * VAR07) + Y02 = CONST018 * VAR07 * z + x * (CONST010 * VAR17 * z + CONST018 * VAR25) + Y03 = CONST016 * VAR07 * y + x * (CONST007 * VAR16 + CONST016 * VAR26 * y) + Y04 = ( + CONST000 * VAR06 + + CONST000 * VAR24 + + CONST002 * VAR15 + + CONST014 * VAR17 * VAR26 + + VAR08 * (CONST001 * VAR26 + CONST014 * VAR17) + ) + Y05 = CONST016 * VAR25 * y + z * (CONST007 * VAR16 + CONST016 * VAR08 * y) + Y06 = ( + -CONST019 * VAR06 + + CONST019 * VAR24 + + VAR17 * (CONST013 * VAR08 - CONST013 * VAR26) + ) + Y07 = y * (CONST011 * VAR08 * z - CONST017 * VAR25) + Y08 = CONST005 * VAR06 + CONST005 * VAR24 + CONST012 * VAR08 * VAR26 + tensors = [Y00, Y01, Y02, Y03, Y04, Y05, Y06, Y07, Y08] + return torch.cat(tensors, dim=-1) + + +@triton.jit +def fourth_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # these are hardcoded because they are 
predetermined; + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + # -------------------- variable and constant definitions + CONST000 = 1.12500000000000 + CONST001 = 2.25000000000000 + CONST002 = 3.00000000000000 + CONST005 = 2.21852991866236 + CONST007 = 9.48683298050514 + CONST010 = 20.1246117974981 + CONST011 = -18.8248505970167 + CONST012 = -13.3111795119741 + CONST013 = -10.0623058987491 + CONST014 = -9.00000000000000 + CONST015 = -8.87411967464942 + CONST016 = -7.11512473537885 + CONST017 = -6.27495019900557 + CONST018 = -3.35410196624968 + CONST019 = -1.67705098312484 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + # -------------------- kernel implementations + Y00 = CONST015 * VAR07 * z - CONST015 * VAR25 * x + Y01 = y * (-CONST011 * VAR26 * x + CONST017 * VAR07) + Y02 = CONST018 * VAR07 * z + x * (CONST010 * VAR17 * z + CONST018 * VAR25) + Y03 = CONST016 * VAR07 * y + x * (CONST007 * VAR16 + CONST016 * VAR26 * y) + Y04 = ( + CONST000 * VAR06 + + CONST000 * VAR24 + + CONST002 * VAR15 + + CONST014 * VAR17 * VAR26 + + VAR08 * (CONST001 * VAR26 + CONST014 * VAR17) + ) + Y05 = CONST016 * VAR25 * y + z * (CONST007 * VAR16 + CONST016 * VAR08 * y) + Y06 = ( + -CONST019 * VAR06 + + CONST019 * VAR24 + + VAR17 * (CONST013 * VAR08 - CONST013 * VAR26) + ) + Y07 = y * (CONST011 * VAR08 * z - CONST017 * VAR25) + Y08 = CONST005 * VAR06 + CONST005 * 
VAR24 + CONST012 * VAR08 * VAR26 + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y01, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y02, + mask=output_row_offset + 2 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 3, + Y03, + mask=output_row_offset + 3 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 4, + Y04, + mask=output_row_offset + 4 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 5, + Y05, + mask=output_row_offset + 5 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 6, + Y06, + mask=output_row_offset + 6 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 7, + Y07, + mask=output_row_offset + 7 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 8, + Y08, + mask=output_row_offset + 8 < output_numel, + ) + + +@triton.jit +def fourth_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # these are hardcoded because they are predetermined; + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel 
+ ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + # load in gradients w.r.t. spherical harmonic projections + g_0 = tl.load( + sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel + ) + g_1 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_2 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + g_3 = tl.load( + sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel + ) + g_4 = tl.load( + sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel + ) + g_5 = tl.load( + sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel + ) + g_6 = tl.load( + sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel + ) + g_7 = tl.load( + sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel + ) + g_8 = tl.load( + sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel + ) + # -------------------- variable and constant definitions + CONST000 = 2.00000000000000 + CONST001 = 4.50000000000000 + CONST002 = 2.25000000000000 + CONST006 = 9.48683298050514 + CONST008 = 12.0000000000000 + CONST012 = 28.4604989415154 + CONST014 = 40.2492235949962 + CONST015 = -37.6497011940334 + CONST016 = -6.70820393249937 + CONST017 = -26.6223590239483 + CONST018 = -21.3453742061366 + CONST019 = -20.1246117974981 + CONST020 = -18.8248505970167 + CONST021 = -18.0000000000000 + CONST022 = -14.2302494707577 + CONST023 = -10.0623058987491 + CONST024 = -9.00000000000000 + CONST025 = -8.87411967464942 + CONST026 = -7.11512473537885 + CONST027 = -6.27495019900557 + CONST028 = -3.35410196624968 + VAR07 = x * x * x + VAR08 = x * x + VAR16 = y * y * y + VAR17 = y * y + VAR25 = z * z * z + VAR26 = z * z + # -------------------- kernel 
implementations + g_x = tl.load( + coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel + ) + g_y = tl.load( + coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + g_z = tl.load( + coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + g_x += ( + CONST015 * g_7 * x * y * z + + CONST022 * g_5 * x * y * z + + g_0 * (CONST017 * VAR08 * z - CONST025 * VAR25) + + g_1 * y * (CONST020 * VAR08 - CONST020 * VAR26) + + g_2 * (-CONST019 * VAR17 * z + CONST023 * VAR08 * z + CONST028 * VAR25) + + g_3 * (CONST006 * VAR16 + CONST018 * VAR08 * y + CONST026 * VAR26 * y) + + g_4 + * (CONST000 * x * (CONST002 * VAR26 + CONST024 * VAR17) + CONST001 * VAR07) + + g_6 * (-CONST016 * VAR07 + CONST019 * VAR17 * x) + + g_8 * (CONST017 * VAR26 * x - CONST025 * VAR07) + ) + g_y += ( + CONST000 * g_6 * y * (CONST023 * VAR08 - CONST023 * VAR26) + + CONST014 * g_2 * x * y * z + + g_1 * (-CONST020 * VAR26 * x + CONST027 * VAR07) + + g_3 * (CONST026 * VAR07 + x * (CONST012 * VAR17 + CONST026 * VAR26)) + + g_4 * (CONST008 * VAR16 + CONST021 * VAR08 * y + CONST021 * VAR26 * y) + + g_5 * (CONST026 * VAR25 + z * (CONST012 * VAR17 + CONST026 * VAR08)) + + g_7 * (CONST020 * VAR08 * z - CONST027 * VAR25) + ) + g_z += ( + -CONST015 * g_1 * x * y * z + + CONST022 * g_3 * x * y * z + + g_0 * (-CONST017 * VAR26 * x + CONST025 * VAR07) + + g_2 * (CONST028 * VAR07 + x * (-CONST019 * VAR17 + CONST023 * VAR26)) + + g_4 * (CONST001 * VAR08 * z + CONST001 * VAR25 + CONST021 * VAR17 * z) + + g_5 * (CONST006 * VAR16 + CONST018 * VAR26 * y + CONST026 * VAR08 * y) + + g_6 * (CONST016 * VAR25 - CONST019 * VAR17 * z) + + g_7 * y * (CONST020 * VAR08 - CONST020 * VAR26) + + g_8 * (CONST017 * VAR08 * z - CONST025 * VAR25) + ) + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + 
) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/sph_harm/direct/y_5.py b/src/equitriton/sph_harm/direct/y_5.py new file mode 100644 index 0000000..22065af --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_5.py @@ -0,0 +1,552 @@ +import triton +import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["FifthOrderSphericalHarmonic"] + + +class FifthOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + output_tensor: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + if not isinstance(output_tensor, torch.Tensor): + output_tensor = torch.empty( + (*coords.shape[:-1], 11), dtype=coords.dtype, device=coords.device + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + fifth_order_fwd[num_blocks,]( + coords, + output_tensor, + block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, + sph_grad_tensor: torch.Tensor, + block_size: int = 64, + col_offset: int = 0, + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + fifth_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. 
This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. + + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the fifth order spherical harmonic. + """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + # -------------------- variable and constant definitions + CONST000 = 1.73430461568895 + CONST001 = 2.32681380862329 + CONST002 = 1.60565407233314 + CONST003 = 3.21130814466628 + CONST004 = 3.31662479035540 + CONST005 = 6.21867148191637 + CONST006 = 6.21867148191637 + CONST007 = 1.60565407233314 + CONST009 = 11.6340690431164 + CONST010 = 12.8452325786651 + CONST011 = 12.4373429638327 + CONST012 = 12.8452325786651 + CONST013 = 13.8744369255116 + CONST017 = 33.9852909359329 + CONST018 = 7.35803132638072 + CONST020 = -44.1481879582843 + CONST021 = -41.6233107765348 + CONST022 = -29.4321253055229 + CONST023 = -23.2681380862329 + CONST024 = -19.2678488679977 + CONST025 = -19.2678488679977 + CONST026 = -16.9926454679664 + CONST027 = -16.9926454679664 + CONST028 = -13.8744369255116 + CONST029 = -16.5831239517770 + CONST030 = 3.46860923137790 + CONST031 = -8.49632273398321 + CONST032 = -5.20291384706685 + CONST033 = -3.46860923137790 + CONST034 = -1.73430461568895 + VAR05 = x**5 + VAR06 = x**4 + VAR07 = x**3 + VAR08 = x**2 + VAR14 = y**5 + VAR15 = y**4 + VAR16 = y**3 + VAR17 = y**2 + VAR23 = z**5 + VAR24 = z**4 + VAR25 = z**3 + VAR26 = z**2 + # -------------------- kernel implementations + Y00 = CONST001 * VAR05 + CONST009 * VAR24 * x + CONST023 * VAR07 * VAR26 + Y01 = y * (CONST022 * VAR07 * z - CONST022 * VAR25 * x) + Y02 = ( + CONST000 * VAR05 + + VAR07 * (CONST028 * VAR17 + CONST033 * VAR26) + + x * (-CONST021 * VAR17 * VAR26 + CONST032 * VAR24) +
) + Y03 = CONST027 * VAR07 * y * z + x * (CONST017 * VAR16 * z + CONST026 * VAR25 * y) + Y04 = ( + CONST002 * VAR05 + + VAR07 * (CONST003 * VAR26 + CONST025 * VAR17) + + x * (CONST002 * VAR24 + CONST010 * VAR15 + CONST024 * VAR17 * VAR26) + ) + Y05 = ( + CONST004 * VAR14 + + VAR16 * (CONST029 * VAR08 + CONST029 * VAR26) + + y * (CONST005 * VAR06 + CONST006 * VAR24 + CONST011 * VAR08 * VAR26) + ) + Y06 = ( + CONST002 * VAR23 + + VAR25 * (CONST003 * VAR08 + CONST024 * VAR17) + + z * (CONST007 * VAR06 + CONST012 * VAR15 + CONST024 * VAR08 * VAR17) + ) + Y07 = VAR16 * (CONST026 * VAR08 - CONST026 * VAR26) + y * ( + -CONST031 * VAR06 + CONST031 * VAR24 + ) + Y08 = ( + CONST034 * VAR23 + + VAR25 * (CONST013 * VAR17 + CONST030 * VAR08) + + z * (CONST021 * VAR08 * VAR17 - CONST032 * VAR06) + ) + Y09 = y * (CONST018 * VAR06 + CONST018 * VAR24 + CONST020 * VAR08 * VAR26) + Y10 = CONST001 * VAR23 + CONST009 * VAR06 * z + CONST023 * VAR08 * VAR25 + # not the prettiest way to concatenate, but better than + # messing with the linter + tensors = [Y00, Y01, Y02, Y03, Y04, Y05, Y06, Y07, Y08, Y09, Y10] + return torch.cat(tensors, dim=-1) + + +@triton.jit +def fifth_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # these are hardcoded because they are predetermined; + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + # 
-------------------- variable and constant definitions + CONST000 = 1.73430461568895 + CONST001 = 2.32681380862329 + CONST002 = 1.60565407233314 + CONST003 = 3.21130814466628 + CONST004 = 3.31662479035540 + CONST005 = 6.21867148191637 + CONST006 = 6.21867148191637 + CONST007 = 1.60565407233314 + CONST009 = 11.6340690431164 + CONST010 = 12.8452325786651 + CONST011 = 12.4373429638327 + CONST012 = 12.8452325786651 + CONST013 = 13.8744369255116 + CONST017 = 33.9852909359329 + CONST018 = 7.35803132638072 + CONST020 = -44.1481879582843 + CONST021 = -41.6233107765348 + CONST022 = -29.4321253055229 + CONST023 = -23.2681380862329 + CONST024 = -19.2678488679977 + CONST025 = -19.2678488679977 + CONST026 = -16.9926454679664 + CONST027 = -16.9926454679664 + CONST028 = -13.8744369255116 + CONST029 = -16.5831239517770 + CONST030 = 3.46860923137790 + CONST031 = -8.49632273398321 + CONST032 = -5.20291384706685 + CONST033 = -3.46860923137790 + CONST034 = -1.73430461568895 + VAR05 = x * x * x * x * x + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR14 = y * y * y * y * y + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR23 = z * z * z * z * z + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + # -------------------- kernel implementations + Y00 = CONST001 * VAR05 + CONST009 * VAR24 * x + CONST023 * VAR07 * VAR26 + Y01 = y * (CONST022 * VAR07 * z - CONST022 * VAR25 * x) + Y02 = ( + CONST000 * VAR05 + + VAR07 * (CONST028 * VAR17 + CONST033 * VAR26) + + x * (-CONST021 * VAR17 * VAR26 + CONST032 * VAR24) + ) + Y03 = CONST027 * VAR07 * y * z + x * (CONST017 * VAR16 * z + CONST026 * VAR25 * y) + Y04 = ( + CONST002 * VAR05 + + VAR07 * (CONST003 * VAR26 + CONST025 * VAR17) + + x * (CONST002 * VAR24 + CONST010 * VAR15 + CONST024 * VAR17 * VAR26) + ) + Y05 = ( + CONST004 * VAR14 + + VAR16 * (CONST029 * VAR08 + CONST029 * VAR26) + + y * (CONST005 * VAR06 + CONST006 * VAR24 + CONST011 * VAR08 * VAR26) + ) + Y06 = ( + CONST002 * VAR23 + + VAR25 * (CONST003 
* VAR08 + CONST024 * VAR17) + + z * (CONST007 * VAR06 + CONST012 * VAR15 + CONST024 * VAR08 * VAR17) + ) + Y07 = VAR16 * (CONST026 * VAR08 - CONST026 * VAR26) + y * ( + -CONST031 * VAR06 + CONST031 * VAR24 + ) + Y08 = ( + CONST034 * VAR23 + + VAR25 * (CONST013 * VAR17 + CONST030 * VAR08) + + z * (CONST021 * VAR08 * VAR17 - CONST032 * VAR06) + ) + Y09 = y * (CONST018 * VAR06 + CONST018 * VAR24 + CONST020 * VAR08 * VAR26) + Y10 = CONST001 * VAR23 + CONST009 * VAR06 * z + CONST023 * VAR08 * VAR25 + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y01, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y02, + mask=output_row_offset + 2 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 3, + Y03, + mask=output_row_offset + 3 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 4, + Y04, + mask=output_row_offset + 4 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 5, + Y05, + mask=output_row_offset + 5 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 6, + Y06, + mask=output_row_offset + 6 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 7, + Y07, + mask=output_row_offset + 7 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 8, + Y08, + mask=output_row_offset + 8 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 9, + Y09, + mask=output_row_offset + 9 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 10, + Y10, + mask=output_row_offset + 10 < output_numel, + ) + + +@triton.jit +def fifth_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + 
output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # these are hardcoded because they are predetermined; + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + # load in gradients w.r.t. spherical harmonic projections + g_0 = tl.load( + sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel + ) + g_1 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_2 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + g_3 = tl.load( + sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel + ) + g_4 = tl.load( + sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel + ) + g_5 = tl.load( + sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel + ) + g_6 = tl.load( + sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel + ) + g_7 = tl.load( + sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel + ) + g_8 = tl.load( + sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel + ) + g_9 = tl.load( + sph_grad_ptr + output_row_offset + 9, mask=output_row_offset + 9 < output_numel + ) + g_10 = tl.load( + sph_grad_ptr + output_row_offset + 10, + 
mask=output_row_offset + 10 < output_numel, + ) + # -------------------- variable and constant definitions + CONST000 = 1.60565407233314 + CONST001 = 3.00000000000000 + CONST002 = 3.21130814466628 + CONST003 = 1.60565407233314 + CONST004 = 6.42261628933256 + CONST005 = 6.42261628933256 + CONST006 = 8.67152307844476 + CONST007 = 8.02827036166571 + CONST008 = 6.93721846275580 + CONST009 = 11.6340690431164 + CONST010 = 12.8452325786651 + CONST011 = 6.21867148191637 + CONST012 = 6.21867148191637 + CONST014 = 12.4373429638327 + CONST017 = 12.8452325786651 + CONST018 = 13.8744369255116 + CONST019 = 24.8746859276655 + CONST020 = 24.8746859276655 + CONST021 = 27.7488738510232 + CONST024 = 29.4321253055229 + CONST027 = 7.35803132638072 + CONST029 = 46.5362761724657 + CONST030 = 51.3809303146605 + CONST031 = 51.3809303146605 + CONST034 = 101.955872807799 + CONST036 = -8.67152307844475 + CONST037 = 3.46860923137790 + CONST038 = -88.2963759165686 + CONST039 = -83.2466215530696 + CONST040 = -69.8044142586986 + CONST041 = -50.9779364038993 + CONST042 = -50.9779364038993 + CONST043 = -46.5362761724657 + CONST044 = -44.1481879582843 + CONST045 = -41.6233107765348 + CONST046 = -38.5356977359954 + CONST047 = -38.5356977359954 + CONST048 = -33.1662479035540 + CONST049 = -33.9852909359329 + CONST050 = 6.42261628933257 + CONST051 = -33.9852909359329 + CONST052 = -29.4321253055229 + CONST053 = -27.7488738510232 + CONST054 = -20.8116553882674 + CONST055 = -19.2678488679977 + CONST056 = -19.2678488679977 + CONST057 = -16.9926454679664 + CONST058 = -16.9926454679664 + CONST059 = -13.8744369255116 + CONST060 = -16.5831239517770 + CONST061 = -8.49632273398321 + CONST062 = -6.93721846275580 + CONST063 = -5.20291384706685 + CONST064 = -3.46860923137790 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + # -------------------- kernel implementations + g_x = tl.load( 
+ coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel + ) + g_y = tl.load( + coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + g_z = tl.load( + coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + g_x += ( + g_0 * (CONST009 * VAR06 + CONST009 * VAR24 + CONST040 * VAR08 * VAR26) + + g_1 * y * (CONST038 * VAR08 * z - CONST052 * VAR25) + + g_10 * (CONST029 * VAR07 * z + CONST043 * VAR25 * x) + + g_2 + * ( + CONST001 * VAR08 * (CONST059 * VAR17 + CONST064 * VAR26) + + CONST006 * VAR06 + - CONST045 * VAR17 * VAR26 + + CONST063 * VAR24 + ) + + g_3 * (CONST041 * VAR08 * y * z - CONST049 * VAR16 * z + CONST057 * VAR25 * y) + + g_4 + * ( + CONST000 * VAR24 + + CONST001 * VAR08 * (CONST002 * VAR26 + CONST055 * VAR17) + + CONST007 * VAR06 + + CONST010 * VAR15 + + CONST056 * VAR17 * VAR26 + ) + + g_5 * (CONST048 * VAR16 * x + y * (CONST019 * VAR07 + CONST019 * VAR26 * x)) + + g_6 * (CONST005 * VAR25 * x + z * (CONST004 * VAR07 + CONST046 * VAR17 * x)) + + g_7 * (CONST049 * VAR16 * x - CONST051 * VAR07 * y) + + g_8 * (CONST008 * VAR25 * x + z * (CONST039 * VAR17 * x - CONST054 * VAR07)) + + g_9 * y * (CONST024 * VAR07 + CONST038 * VAR26 * x) + ) + g_y += ( + g_1 * (CONST052 * VAR07 * z - CONST052 * VAR25 * x) + + g_2 * (-CONST039 * VAR26 * x * y + CONST053 * VAR07 * y) + + g_3 * (CONST058 * VAR07 * z + x * (CONST034 * VAR17 * z + CONST057 * VAR25)) + + g_4 * (CONST047 * VAR07 * y + x * (CONST030 * VAR16 + CONST046 * VAR26 * y)) + + g_5 + * ( + CONST001 * VAR17 * (CONST060 * VAR08 + CONST060 * VAR26) + + CONST011 * VAR06 + + CONST012 * VAR24 + + CONST014 * VAR08 * VAR26 + - CONST060 * VAR15 + ) + + g_6 * (CONST046 * VAR25 * y + z * (CONST031 * VAR16 + CONST046 * VAR08 * y)) + + g_7 + * ( + CONST001 * VAR17 * (CONST057 * VAR08 - CONST057 * VAR26) + - CONST061 * VAR06 + + CONST061 * VAR24 + ) + + g_8 * (CONST021 * VAR25 * y + CONST039 * VAR08 * y * z) + + g_9 * (CONST027 * VAR06 + CONST027 * 
VAR24 + CONST044 * VAR08 * VAR26) + ) + g_z += ( + g_0 * (CONST029 * VAR25 * x + CONST043 * VAR07 * z) + + g_1 * y * (-CONST038 * VAR26 * x + CONST052 * VAR07) + + g_10 * (CONST009 * VAR06 + CONST009 * VAR24 + CONST040 * VAR08 * VAR26) + + g_2 * (CONST062 * VAR07 * z + x * (-CONST039 * VAR17 * z + CONST054 * VAR25)) + + g_3 * (CONST058 * VAR07 * y + x * (CONST042 * VAR26 * y - CONST049 * VAR16)) + + g_4 * (CONST005 * VAR07 * z + x * (CONST046 * VAR17 * z + CONST050 * VAR25)) + + g_5 * (CONST048 * VAR16 * z + y * (CONST019 * VAR08 * z + CONST020 * VAR25)) + + g_6 + * ( + CONST001 * VAR26 * (CONST002 * VAR08 + CONST056 * VAR17) + + CONST003 * VAR06 + + CONST007 * VAR24 + + CONST017 * VAR15 + + CONST056 * VAR08 * VAR17 + ) + + g_7 * (-CONST049 * VAR16 * z + CONST051 * VAR25 * y) + + g_8 + * ( + CONST001 * VAR26 * (CONST018 * VAR17 + CONST037 * VAR08) + + CONST036 * VAR24 + + CONST045 * VAR08 * VAR17 + - CONST063 * VAR06 + ) + + g_9 * y * (CONST024 * VAR25 + CONST038 * VAR08 * z) + ) + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + ) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/sph_harm/direct/y_6.py b/src/equitriton/sph_harm/direct/y_6.py new file mode 100644 index 0000000..c376c61 --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_6.py @@ -0,0 +1,758 @@ +import triton +import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["SixthOrderSphericalHarmonic"] + + +class SixthOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + output_tensor: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + if not 
isinstance(output_tensor, torch.Tensor): + output_tensor = torch.empty( + (*coords.shape[:-1], 13), dtype=coords.dtype, device=coords.device + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + sixth_order_fwd[num_blocks,]( + coords, + output_tensor, + block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, sph_grad_tensor: torch.Tensor, block_size: int = 64, col_offset: int = 0 + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + sixth_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. + + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the sixth order spherical harmonic.
+ """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + # -------------------- variable and constant definitions + CONST002 = 3.26558761940328 + CONST003 = 3.26558761940328 + CONST004 = 6.53117523880657 + CONST006 = 8.38944649544891 + CONST007 = 9.79676285820985 + CONST008 = 10.3266947761614 + CONST009 = 3.60555127546399 + CONST010 = -1.78863600265677 + CONST011 = 14.5309475774982 + CONST012 = 8.94318001328386 + CONST013 = 16.5227116418583 + CONST014 = 16.5227116418583 + CONST015 = 17.8863600265677 + CONST017 = 20.6533895523229 + CONST018 = 20.2812259244849 + CONST019 = -107.318160159406 + CONST020 = 17.8863600265677 + CONST022 = 29.3902885746295 + CONST024 = 40.5624518489699 + CONST025 = 41.9472324772445 + CONST026 = -1.63279380970164 + CONST027 = -83.8944649544891 + CONST028 = -78.3741028656788 + CONST030 = -71.5454401062709 + CONST032 = -52.2494019104525 + CONST033 = -52.2494019104525 + CONST035 = -48.4364919249939 + CONST036 = -41.3067791046458 + CONST037 = -36.3273689437454 + CONST038 = -29.3902885746295 + CONST039 = -27.0416345659799 + CONST040 = -26.1247009552263 + CONST041 = -26.1247009552263 + CONST042 = -19.5935257164197 + CONST043 = -2.42182459624970 + CONST044 = -9.79676285820985 + CONST045 = -7.15454401062709 + CONST046 = -3.38020432074749 + CONST047 = -1.12673477358250 + VAR07 = x * x * x + VAR08 = x * x + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR06 = VAR08 * VAR08 + VAR16 = y * y * y + VAR17 = y * y + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR15 = VAR17 * VAR17 + VAR25 = z * z * z + VAR26 = z * z + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + VAR24 = VAR26 * VAR26 + # -------------------- kernel implementations + Y00 = CONST011 * VAR05 * z + CONST011 * VAR23 * x + CONST035 * VAR07 * VAR25 + Y01 = y * (CONST006 * VAR05 + CONST025 * VAR24 * x + CONST027 * VAR07 * VAR26) + Y02 = ( + -CONST045 * VAR05 * z + + CONST045 * VAR23 * x + + 
VAR17 * (CONST030 * VAR07 * z - CONST030 * VAR25 * x) + ) + Y03 = VAR16 * (-CONST028 * VAR26 * x + CONST040 * VAR07) + y * ( + CONST007 * VAR05 + CONST038 * VAR24 * x + CONST042 * VAR07 * VAR26 + ) + Y04 = ( + CONST003 * VAR05 * z + + VAR07 * (CONST004 * VAR25 + CONST033 * VAR17 * z) + + x * (CONST002 * VAR23 - CONST032 * VAR15 * z + CONST032 * VAR17 * VAR25) + ) + Y05 = ( + CONST008 * VAR05 * y + + VAR07 * (CONST017 * VAR26 * y + CONST036 * VAR16) + + x * (CONST008 * VAR24 * y + CONST013 * VAR14 + CONST036 * VAR16 * VAR26) + ) + Y06 = ( + CONST009 * VAR13 + + CONST018 * VAR17 * VAR24 + + CONST039 * VAR15 * VAR26 + + CONST047 * VAR04 + + CONST047 * VAR22 + + VAR06 * (CONST018 * VAR17 + CONST046 * VAR26) + + VAR08 * (CONST024 * VAR17 * VAR26 + CONST039 * VAR15 + CONST046 * VAR24) + ) + Y07 = ( + CONST008 * VAR23 * y + + VAR25 * (CONST017 * VAR08 * y + CONST036 * VAR16) + + z * (CONST008 * VAR06 * y + CONST014 * VAR14 + CONST036 * VAR08 * VAR16) + ) + Y08 = ( + CONST026 * VAR04 + - CONST026 * VAR22 + + CONST040 * VAR17 * VAR24 + - CONST041 * VAR15 * VAR26 + + VAR06 * (CONST026 * VAR26 - CONST041 * VAR17) + + VAR08 * (-CONST026 * VAR24 + CONST041 * VAR15) + ) + Y09 = VAR16 * (CONST028 * VAR08 * z - CONST041 * VAR25) + y * ( + CONST022 * VAR06 * z - CONST042 * VAR08 * VAR25 + CONST044 * VAR23 + ) + Y10 = ( + CONST010 * VAR04 + + CONST010 * VAR22 + + CONST020 * VAR17 * VAR24 + + VAR06 * (CONST012 * VAR26 + CONST015 * VAR17) + + VAR08 * (CONST012 * VAR24 + CONST019 * VAR17 * VAR26) + ) + Y11 = y * (CONST006 * VAR23 + CONST025 * VAR06 * z + CONST027 * VAR08 * VAR25) + Y12 = ( + -CONST037 * VAR06 * VAR26 + + CONST037 * VAR08 * VAR24 + + CONST043 * VAR04 + - CONST043 * VAR22 + ) + # not the prettiest way to concatenate, but better than + # messing with the linter + tensors = [Y00, Y01, Y02, Y03, Y04, Y05, Y06, Y07, Y08, Y09, Y10, Y11, Y12] + return torch.cat(tensors, dim=-1) + + +@triton.jit +def sixth_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + 
block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # these are hardcoded because they are predetermined; + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + # -------------------- variable and constant definitions + CONST002 = 3.26558761940328 + CONST003 = 3.26558761940328 + CONST004 = 6.53117523880657 + CONST006 = 8.38944649544891 + CONST007 = 9.79676285820985 + CONST008 = 10.3266947761614 + CONST009 = 3.60555127546399 + CONST010 = -1.78863600265677 + CONST011 = 14.5309475774982 + CONST012 = 8.94318001328386 + CONST013 = 16.5227116418583 + CONST014 = 16.5227116418583 + CONST015 = 17.8863600265677 + CONST017 = 20.6533895523229 + CONST018 = 20.2812259244849 + CONST019 = -107.318160159406 + CONST020 = 17.8863600265677 + CONST022 = 29.3902885746295 + CONST024 = 40.5624518489699 + CONST025 = 41.9472324772445 + CONST026 = -1.63279380970164 + CONST027 = -83.8944649544891 + CONST028 = -78.3741028656788 + CONST030 = -71.5454401062709 + CONST032 = -52.2494019104525 + CONST033 = -52.2494019104525 + CONST035 = -48.4364919249939 + CONST036 = -41.3067791046458 + CONST037 = -36.3273689437454 + CONST038 = -29.3902885746295 + CONST039 = -27.0416345659799 + CONST040 = -26.1247009552263 + CONST041 = -26.1247009552263 + CONST042 = -19.5935257164197 + CONST043 = -2.42182459624970 + CONST044 = -9.79676285820985 + CONST045 = -7.15454401062709 + CONST046 = -3.38020432074749 + CONST047 = 
-1.12673477358250 + VAR07 = x * x * x + VAR08 = x * x + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR06 = VAR08 * VAR08 + VAR16 = y * y * y + VAR17 = y * y + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR15 = VAR17 * VAR17 + VAR25 = z * z * z + VAR26 = z * z + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + VAR24 = VAR26 * VAR26 + # -------------------- kernel implementations + Y00 = CONST011 * VAR05 * z + CONST011 * VAR23 * x + CONST035 * VAR07 * VAR25 + Y01 = y * (CONST006 * VAR05 + CONST025 * VAR24 * x + CONST027 * VAR07 * VAR26) + Y02 = ( + -CONST045 * VAR05 * z + + CONST045 * VAR23 * x + + VAR17 * (CONST030 * VAR07 * z - CONST030 * VAR25 * x) + ) + Y03 = VAR16 * (-CONST028 * VAR26 * x + CONST040 * VAR07) + y * ( + CONST007 * VAR05 + CONST038 * VAR24 * x + CONST042 * VAR07 * VAR26 + ) + Y04 = ( + CONST003 * VAR05 * z + + VAR07 * (CONST004 * VAR25 + CONST033 * VAR17 * z) + + x * (CONST002 * VAR23 - CONST032 * VAR15 * z + CONST032 * VAR17 * VAR25) + ) + Y05 = ( + CONST008 * VAR05 * y + + VAR07 * (CONST017 * VAR26 * y + CONST036 * VAR16) + + x * (CONST008 * VAR24 * y + CONST013 * VAR14 + CONST036 * VAR16 * VAR26) + ) + Y06 = ( + CONST009 * VAR13 + + CONST018 * VAR17 * VAR24 + + CONST039 * VAR15 * VAR26 + + CONST047 * VAR04 + + CONST047 * VAR22 + + VAR06 * (CONST018 * VAR17 + CONST046 * VAR26) + + VAR08 * (CONST024 * VAR17 * VAR26 + CONST039 * VAR15 + CONST046 * VAR24) + ) + Y07 = ( + CONST008 * VAR23 * y + + VAR25 * (CONST017 * VAR08 * y + CONST036 * VAR16) + + z * (CONST008 * VAR06 * y + CONST014 * VAR14 + CONST036 * VAR08 * VAR16) + ) + Y08 = ( + CONST026 * VAR04 + - CONST026 * VAR22 + + CONST040 * VAR17 * VAR24 + - CONST041 * VAR15 * VAR26 + + VAR06 * (CONST026 * VAR26 - CONST041 * VAR17) + + VAR08 * (-CONST026 * VAR24 + CONST041 * VAR15) + ) + Y09 = VAR16 * (CONST028 * VAR08 * z - CONST041 * VAR25) + y * ( + CONST022 * VAR06 * z - CONST042 * VAR08 * VAR25 + CONST044 * VAR23 + ) + Y10 = ( + CONST010 * VAR04 + + CONST010 * VAR22 + + CONST020 * VAR17 * 
VAR24 + + VAR06 * (CONST012 * VAR26 + CONST015 * VAR17) + + VAR08 * (CONST012 * VAR24 + CONST019 * VAR17 * VAR26) + ) + Y11 = y * (CONST006 * VAR23 + CONST025 * VAR06 * z + CONST027 * VAR08 * VAR25) + Y12 = ( + -CONST037 * VAR06 * VAR26 + + CONST037 * VAR08 * VAR24 + + CONST043 * VAR04 + - CONST043 * VAR22 + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y01, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y02, + mask=output_row_offset + 2 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 3, + Y03, + mask=output_row_offset + 3 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 4, + Y04, + mask=output_row_offset + 4 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 5, + Y05, + mask=output_row_offset + 5 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 6, + Y06, + mask=output_row_offset + 6 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 7, + Y07, + mask=output_row_offset + 7 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 8, + Y08, + mask=output_row_offset + 8 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 9, + Y09, + mask=output_row_offset + 9 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 10, + Y10, + mask=output_row_offset + 10 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 11, + Y11, + mask=output_row_offset + 11 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 12, + Y12, + mask=output_row_offset + 12 < output_numel, + ) + + +@triton.jit +def sixth_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + 
coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # these are hardcoded because they are predetermined; + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + # load in gradients w.r.t. spherical harmonic projections + g_0 = tl.load( + sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel + ) + g_1 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_2 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + g_3 = tl.load( + sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel + ) + g_4 = tl.load( + sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel + ) + g_5 = tl.load( + sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel + ) + g_6 = tl.load( + sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel + ) + g_7 = tl.load( + sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel + ) + g_8 = tl.load( + sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel + ) + g_9 = tl.load( + sph_grad_ptr + output_row_offset + 9, mask=output_row_offset + 9 < output_numel + ) + g_10 = tl.load( + sph_grad_ptr 
+ output_row_offset + 10, + mask=output_row_offset + 10 < output_numel, + ) + g_11 = tl.load( + sph_grad_ptr + output_row_offset + 11, + mask=output_row_offset + 11 < output_numel, + ) + g_12 = tl.load( + sph_grad_ptr + output_row_offset + 12, + mask=output_row_offset + 12 < output_numel, + ) + # -------------------- variable and constant definitions + CONST000 = 2.00000000000000 + CONST002 = 4.00000000000000 + CONST003 = 3.00000000000000 + CONST004 = 6.53117523880657 + CONST006 = 8.94318001328386 + CONST007 = 8.38944649544891 + CONST008 = 10.3266947761614 + CONST009 = 9.79676285820985 + CONST013 = 16.3279380970164 + CONST014 = 17.8863600265677 + CONST015 = 16.5227116418583 + CONST016 = 20.6533895523229 + CONST017 = 20.2812259244849 + CONST018 = 21.6333076527839 + CONST020 = 17.8863600265677 + CONST022 = 29.3902885746295 + CONST024 = 35.7727200531355 + CONST026 = 40.5624518489699 + CONST028 = 41.9472324772445 + CONST029 = 48.9838142910493 + CONST030 = 51.6334738808072 + CONST035 = 71.5454401062709 + CONST037 = 81.1249036979398 + CONST039 = 82.6135582092915 + CONST040 = -3.26558761940328 + CONST042 = 117.561154298518 + CONST046 = 208.997607641810 + CONST048 = -251.683394863467 + CONST049 = -214.636320318813 + CONST050 = -214.636320318813 + CONST051 = 16.5227116418583 + CONST052 = -167.788929908978 + CONST053 = -156.748205731358 + CONST054 = -145.309475774982 + CONST055 = -123.920337313937 + CONST056 = -117.561154298518 + CONST057 = 3.26558761940328 + CONST058 = -108.166538263920 + CONST059 = -107.318160159406 + CONST060 = -104.498803820905 + CONST061 = -104.498803820905 + CONST062 = -83.8944649544891 + CONST063 = -82.6135582092915 + CONST064 = -78.3741028656788 + CONST065 = -72.6547378874909 + CONST066 = -71.5454401062709 + CONST067 = -58.7805771492591 + CONST068 = -54.0832691319598 + CONST069 = -52.2494019104525 + CONST070 = -52.2494019104525 + CONST071 = -48.9838142910492 + CONST072 = -41.3067791046458 + CONST073 = -39.1870514328394 + CONST074 = -35.7727200531355 
+ CONST075 = -29.3902885746295 + CONST076 = -27.0416345659799 + CONST077 = -26.1247009552263 + CONST078 = -26.1247009552263 + CONST079 = -19.5935257164197 + CONST080 = -14.5309475774982 + CONST081 = -13.5208172829900 + CONST082 = -10.7318160159406 + CONST083 = -9.79676285820985 + CONST084 = -7.15454401062709 + CONST085 = -6.76040864149498 + CONST086 = -3.38020432074749 + CONST087 = -1.63279380970164 + VAR07 = x * x * x + VAR08 = x * x + VAR05 = VAR07 * VAR08 + VAR06 = VAR08 * VAR08 + VAR16 = y * y * y + VAR17 = y * y + VAR14 = VAR16 * VAR17 + VAR15 = VAR17 * VAR17 + VAR25 = z * z * z + VAR26 = z * z + VAR23 = VAR25 * VAR26 + VAR24 = VAR26 * VAR26 + # -------------------- kernel implementations + g_x = tl.load( + coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel + ) + g_y = tl.load( + coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + g_z = tl.load( + coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + g_x += ( + g_0 * (CONST054 * VAR08 * VAR25 - CONST065 * VAR06 * z - CONST080 * VAR23) + + g_1 * y * (CONST028 * VAR06 + CONST028 * VAR24 + CONST048 * VAR08 * VAR26) + + g_10 + * ( + CONST000 * x * (CONST006 * VAR24 + CONST059 * VAR17 * VAR26) + + CONST002 * VAR07 * (CONST006 * VAR26 + CONST014 * VAR17) + + CONST082 * VAR05 + ) + + g_11 * y * (-CONST052 * VAR07 * z + CONST052 * VAR25 * x) + + g_12 * (-CONST054 * VAR07 * VAR26 + CONST065 * VAR24 * x + CONST080 * VAR05) + + g_2 + * ( + -CONST074 * VAR06 * z + + CONST084 * VAR23 + + VAR17 * (CONST049 * VAR08 * z - CONST066 * VAR25) + ) + + g_3 + * ( + VAR16 * (CONST064 * VAR08 - CONST064 * VAR26) + + y * (CONST029 * VAR06 + CONST067 * VAR08 * VAR26 + CONST075 * VAR24) + ) + + g_4 + * ( + CONST003 * VAR08 * (CONST004 * VAR25 + CONST069 * VAR17 * z) + + CONST013 * VAR06 * z + - CONST040 * VAR23 + - CONST070 * VAR15 * z + + CONST070 * VAR17 * VAR25 + ) + + g_5 + * ( + CONST003 * VAR08 * (CONST016 * VAR26 * y + CONST072 * VAR16) + + 
CONST008 * VAR24 * y + + CONST015 * VAR14 + + CONST030 * VAR06 * y + + CONST072 * VAR16 * VAR26 + ) + + g_6 + * ( + CONST000 + * x + * (CONST026 * VAR17 * VAR26 + CONST076 * VAR15 + CONST086 * VAR24) + + CONST002 * VAR07 * (CONST017 * VAR17 + CONST086 * VAR26) + + CONST085 * VAR05 + ) + + g_7 + * ( + -CONST072 * VAR25 * x * y + + z * (CONST063 * VAR16 * x - CONST072 * VAR07 * y) + ) + + g_8 + * ( + CONST000 * x * (CONST077 * VAR15 - CONST087 * VAR24) + + CONST002 * VAR07 * (-CONST077 * VAR17 + CONST087 * VAR26) + + CONST083 * VAR05 + ) + + g_9 + * (CONST053 * VAR16 * x * z + y * (CONST042 * VAR07 * z - CONST073 * VAR25 * x)) + ) + g_y += ( + CONST000 * g_2 * y * (CONST066 * VAR07 * z - CONST066 * VAR25 * x) + + g_1 * (CONST007 * VAR05 + CONST028 * VAR24 * x + CONST062 * VAR07 * VAR26) + + g_10 + * (CONST024 * VAR06 * y + CONST050 * VAR08 * VAR26 * y - CONST074 * VAR24 * y) + + g_11 * (CONST007 * VAR23 + CONST028 * VAR06 * z + CONST062 * VAR08 * VAR25) + + g_3 + * ( + CONST003 * VAR17 * (-CONST064 * VAR26 * x + CONST078 * VAR07) + + CONST009 * VAR05 + + CONST075 * VAR24 * x + + CONST079 * VAR07 * VAR26 + ) + + g_4 + * (CONST061 * VAR07 * y * z + x * (CONST046 * VAR16 * z + CONST060 * VAR25 * y)) + + g_5 + * ( + CONST008 * VAR05 + + VAR07 * (CONST016 * VAR26 + CONST055 * VAR17) + + x * (CONST008 * VAR24 + CONST055 * VAR17 * VAR26 - CONST063 * VAR15) + ) + + g_6 + * ( + CONST018 * VAR14 + + CONST026 * VAR06 * y + + CONST026 * VAR24 * y + + CONST058 * VAR16 * VAR26 + + VAR08 * (CONST037 * VAR26 * y + CONST058 * VAR16) + ) + + g_7 + * ( + CONST008 * VAR23 + + VAR25 * (CONST016 * VAR08 + CONST055 * VAR17) + + z * (CONST008 * VAR06 + CONST039 * VAR15 + CONST055 * VAR08 * VAR17) + ) + + g_8 + * ( + CONST060 * VAR08 * VAR16 + - CONST060 * VAR16 * VAR26 + + CONST069 * VAR24 * y + - CONST070 * VAR06 * y + ) + + g_9 + * ( + CONST003 * VAR17 * (CONST064 * VAR08 * z - CONST077 * VAR25) + + CONST022 * VAR06 * z + - CONST079 * VAR08 * VAR25 + + CONST083 * VAR23 + ) + ) + g_z += ( 
+ g_0 * (CONST054 * VAR07 * VAR26 - CONST065 * VAR24 * x - CONST080 * VAR05) + + g_1 * y * (CONST052 * VAR07 * z - CONST052 * VAR25 * x) + + g_10 + * ( + CONST020 * VAR06 * z + + CONST035 * VAR17 * VAR25 + + CONST082 * VAR23 + + VAR08 * (CONST050 * VAR17 * z - CONST074 * VAR25) + ) + + g_11 * y * (CONST028 * VAR06 + CONST028 * VAR24 + CONST048 * VAR08 * VAR26) + + g_12 * (CONST054 * VAR08 * VAR25 - CONST065 * VAR06 * z - CONST080 * VAR23) + + g_2 + * ( + CONST074 * VAR24 * x + - CONST084 * VAR05 + + VAR17 * (-CONST049 * VAR26 * x + CONST066 * VAR07) + ) + + g_3 + * ( + -CONST053 * VAR16 * x * z + + y * (CONST056 * VAR25 * x + CONST073 * VAR07 * z) + ) + + g_4 + * ( + CONST057 * VAR05 + + VAR07 * (CONST069 * VAR17 - CONST079 * VAR26) + + x * (CONST013 * VAR24 + CONST053 * VAR17 * VAR26 - CONST070 * VAR15) + ) + + g_5 + * ( + -CONST072 * VAR07 * y * z + + x * (CONST063 * VAR16 * z - CONST072 * VAR25 * y) + ) + + g_6 + * ( + CONST037 * VAR17 * VAR25 + + CONST068 * VAR15 * z + + CONST085 * VAR06 * z + + CONST085 * VAR23 + + VAR08 * (CONST037 * VAR17 * z + CONST081 * VAR25) + ) + + g_7 + * ( + CONST003 * VAR26 * (CONST016 * VAR08 * y + CONST072 * VAR16) + + CONST008 * VAR06 * y + + CONST030 * VAR24 * y + + CONST051 * VAR14 + + CONST072 * VAR08 * VAR16 + ) + + g_8 + * ( + CONST004 * VAR08 * VAR25 + + CONST040 * VAR06 * z + + CONST061 * VAR17 * VAR25 + - CONST070 * VAR15 * z + - CONST083 * VAR23 + ) + + g_9 + * ( + VAR16 * (CONST064 * VAR08 - CONST064 * VAR26) + + y * (CONST022 * VAR06 - CONST067 * VAR08 * VAR26 + CONST071 * VAR24) + ) + ) + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + ) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/sph_harm/direct/y_7.py b/src/equitriton/sph_harm/direct/y_7.py new file mode 
100644 index 0000000..deece19 --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_7.py @@ -0,0 +1,1088 @@ +import triton +import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["SeventhOrderSphericalHarmonic"] + + +class SeventhOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + output_tensor: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + if not isinstance(output_tensor, torch.Tensor): + output_tensor = torch.empty( + (*coords.shape[:-1], 15), dtype=coords.dtype, device=coords.device + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + seventh_order_fwd[num_blocks,]( + coords, + output_tensor, + block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, + sph_grad_tensor: torch.Tensor, + block_size: int = 64, + col_offset: int = 0, + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + seventh_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. 
+ + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the seventh order spherical harmonic. + """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + # -------------------- variable and constant definitions + CONST002 = 3.87298334620742 + CONST008 = 11.7655316231354 + CONST010 = 16.5555704843566 + CONST012 = 20.4939015319192 + CONST013 = 20.4939015319192 + CONST014 = 22.0740939791422 + CONST015 = 23.5310632462709 + CONST017 = 36.7901566319036 + CONST019 = 38.4260653723485 + CONST020 = 38.4260653723485 + CONST021 = 38.4260653723485 + CONST023 = -4.99169231699030 + CONST025 = 47.0621264925418 + CONST026 = 50.8329064189723 + CONST028 = 55.1852349478554 + CONST029 = 56.2781179722634 + CONST030 = 56.2781179722634 + CONST032 = 66.5558975598707 + CONST033 = 75.2994023880668 + CONST037 = 101.665812837945 + CONST038 = 110.370469895711 + CONST041 = 147.160626527614 + CONST042 = -1.66389743899677 + CONST043 = -9.37968632871057 + CONST044 = -1.66389743899677 + CONST045 = -220.740939791422 + CONST046 = -220.740939791422 + CONST047 = -1.60108605718119 + CONST048 = -187.593726574211 + CONST049 = -9.19753915797590 + CONST050 = -1.83950783159518 + CONST051 = -1.83950783159518 + CONST052 = -4.80325817154356 + CONST053 = -147.160626527614 + CONST054 = -140.695294930659 + CONST055 = -133.111795119741 + CONST056 = -125.499003980111 + CONST057 = -125.499003980111 + CONST058 = -99.8338463398060 + CONST059 = -87.7389315936062 + CONST060 = -76.8521307446970 + CONST061 = -66.5558975598707 + CONST062 = -62.7495019900557 + CONST063 = -52.6433589561637 + CONST064 = -44.1481879582843 + CONST065 = -44.3705983732471 + CONST066 = -40.6663251351779 + CONST067 = -40.6663251351779 + CONST068 = -8.31948719498384 + CONST069 = -37.6497011940334 + CONST070 = -33.2779487799353 + CONST071 = -25.4164532094862 + CONST072 = -25.4164532094862 + CONST073 =
-17.5477863187212 + CONST074 = -11.7655316231354 + CONST075 = -11.0370469895711 + CONST076 = -9.19753915797590 + CONST077 = -8.47215106982872 + CONST078 = -4.80325817154356 + CONST079 = -2.50682661696018 + CONST080 = -1.60108605718119 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + # -------------------- kernel implementations + Y00 = ( + CONST059 * VAR07 * VAR24 + - CONST063 * VAR05 * VAR26 + - CONST073 * VAR22 * x + + CONST079 * VAR03 + ) + Y01 = y * (CONST029 * VAR23 * x + CONST030 * VAR05 * z + CONST048 * VAR07 * VAR25) + Y02 = ( + CONST050 * VAR03 + + VAR05 * (CONST010 * VAR26 + CONST014 * VAR17) + + VAR07 * (CONST045 * VAR17 * VAR26 - CONST076 * VAR24) + + x * (CONST038 * VAR17 * VAR24 + CONST076 * VAR22) + ) + Y03 = VAR16 * (CONST041 * VAR25 * x + CONST053 * VAR07 * z) + y * ( + -CONST064 * VAR05 * z + CONST064 * VAR23 * x + ) + Y04 = ( + CONST042 * VAR03 + + VAR05 * (-CONST042 * VAR26 - CONST070 * VAR17) + + VAR07 * (CONST061 * VAR17 * VAR26 + CONST065 * VAR15 - CONST068 * VAR24) + + x * (-CONST023 * VAR22 - CONST055 * VAR15 * VAR26 + CONST058 * VAR17 * VAR24) + ) + Y05 = ( + CONST015 * VAR05 * y * z + + VAR07 * (CONST025 * VAR25 * y + CONST057 * VAR16 * z) + + x * (CONST015 * VAR23 * y + CONST033 * VAR14 * z + CONST056 * VAR16 * VAR25) + ) + Y06 = ( + CONST047 * VAR03 + + VAR05 * (CONST020 * VAR17 + CONST078 * VAR26) + + VAR07 * (CONST052 * VAR24 + CONST060 * VAR15 - CONST060 * VAR17 * VAR26) + + x + * ( + CONST012 * VAR13 + + CONST019 * VAR17 * VAR24 + + CONST060 * VAR15 * VAR26 + + CONST080 * VAR22 + ) + ) + Y07 = ( + CONST002 * VAR12 + + VAR14 * (CONST066 * VAR08 + CONST067 * VAR26) + + VAR16 * (CONST026 * 
VAR06 + CONST026 * VAR24 + CONST037 * VAR08 * VAR26) + + y + * ( + CONST071 * VAR06 * VAR26 + + CONST072 * VAR08 * VAR24 + + CONST077 * VAR04 + + CONST077 * VAR22 + ) + ) + Y08 = ( + CONST047 * VAR21 + + VAR23 * (CONST020 * VAR17 + CONST052 * VAR08) + + VAR25 * (CONST052 * VAR06 - CONST060 * VAR08 * VAR17 + CONST060 * VAR15) + + z + * ( + CONST013 * VAR13 + + CONST021 * VAR06 * VAR17 + + CONST047 * VAR04 + + CONST060 * VAR08 * VAR15 + ) + ) + Y09 = ( + VAR14 * (CONST069 * VAR08 - CONST069 * VAR26) + + VAR16 * (-CONST062 * VAR06 + CONST062 * VAR24) + + y + * ( + CONST008 * VAR08 * VAR24 + + CONST074 * VAR04 + + CONST074 * VAR06 * VAR26 + - CONST074 * VAR22 + ) + ) + Y10 = ( + -CONST042 * VAR21 + + VAR23 * (CONST044 * VAR08 + CONST070 * VAR17) + + VAR25 * (CONST032 * VAR08 * VAR17 - CONST065 * VAR15 + CONST068 * VAR06) + + z * (CONST023 * VAR04 + CONST055 * VAR08 * VAR15 - CONST058 * VAR06 * VAR17) + ) + Y11 = VAR16 * ( + CONST017 * VAR06 + CONST017 * VAR24 + CONST046 * VAR08 * VAR26 + ) + y * ( + CONST028 * VAR06 * VAR26 + + CONST028 * VAR08 * VAR24 + + CONST075 * VAR04 + + CONST075 * VAR22 + ) + Y12 = ( + CONST051 * VAR21 + + VAR23 * (CONST010 * VAR08 + CONST014 * VAR17) + + VAR25 * (CONST045 * VAR08 * VAR17 - CONST049 * VAR06) + + z * (CONST038 * VAR06 * VAR17 + CONST049 * VAR04) + ) + Y13 = y * ( + CONST043 * VAR04 + - CONST043 * VAR22 + - CONST054 * VAR06 * VAR26 + + CONST054 * VAR08 * VAR24 + ) + Y14 = ( + -CONST059 * VAR06 * VAR25 + + CONST063 * VAR08 * VAR23 + + CONST073 * VAR04 * z + - CONST079 * VAR21 + ) + # not the prettiest way to concatenate, but better than + # messing with the linter + tensors = [ + Y00, + Y01, + Y02, + Y03, + Y04, + Y05, + Y06, + Y07, + Y08, + Y09, + Y10, + Y11, + Y12, + Y13, + Y14, + ] + return torch.cat(tensors, dim=-1) + + +@triton.jit +def seventh_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + 
output_stride: tl.constexpr, +): + # these are hardcoded because they are predetermined; + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + # -------------------- variable and constant definitions + CONST002 = 3.87298334620742 + CONST008 = 11.7655316231354 + CONST010 = 16.5555704843566 + CONST012 = 20.4939015319192 + CONST013 = 20.4939015319192 + CONST014 = 22.0740939791422 + CONST015 = 23.5310632462709 + CONST017 = 36.7901566319036 + CONST019 = 38.4260653723485 + CONST020 = 38.4260653723485 + CONST021 = 38.4260653723485 + CONST023 = -4.99169231699030 + CONST025 = 47.0621264925418 + CONST026 = 50.8329064189723 + CONST028 = 55.1852349478554 + CONST029 = 56.2781179722634 + CONST030 = 56.2781179722634 + CONST032 = 66.5558975598707 + CONST033 = 75.2994023880668 + CONST037 = 101.665812837945 + CONST038 = 110.370469895711 + CONST041 = 147.160626527614 + CONST042 = -1.66389743899677 + CONST043 = -9.37968632871057 + CONST044 = -1.66389743899677 + CONST045 = -220.740939791422 + CONST046 = -220.740939791422 + CONST047 = -1.60108605718119 + CONST048 = -187.593726574211 + CONST049 = -9.19753915797590 + CONST050 = -1.83950783159518 + CONST051 = -1.83950783159518 + CONST052 = -4.80325817154356 + CONST053 = -147.160626527614 + CONST054 = -140.695294930659 + CONST055 = -133.111795119741 + CONST056 = -125.499003980111 + CONST057 = -125.499003980111 + CONST058 = -99.8338463398060 + CONST059 = -87.7389315936062 + CONST060 = -76.8521307446970 + CONST061 = -66.5558975598707 + CONST062 = 
-62.7495019900557 + CONST063 = -52.6433589561637 + CONST064 = -44.1481879582843 + CONST065 = -44.3705983732471 + CONST066 = -40.6663251351779 + CONST067 = -40.6663251351779 + CONST068 = -8.31948719498384 + CONST069 = -37.6497011940334 + CONST070 = -33.2779487799353 + CONST071 = -25.4164532094862 + CONST072 = -25.4164532094862 + CONST073 = -17.5477863187212 + CONST074 = -11.7655316231354 + CONST075 = -11.0370469895711 + CONST076 = -9.19753915797590 + CONST077 = -8.47215106982872 + CONST078 = -4.80325817154356 + CONST079 = -2.50682661696018 + CONST080 = -1.60108605718119 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + # -------------------- kernel implementations + Y00 = ( + CONST059 * VAR07 * VAR24 + - CONST063 * VAR05 * VAR26 + - CONST073 * VAR22 * x + + CONST079 * VAR03 + ) + Y01 = y * (CONST029 * VAR23 * x + CONST030 * VAR05 * z + CONST048 * VAR07 * VAR25) + Y02 = ( + CONST050 * VAR03 + + VAR05 * (CONST010 * VAR26 + CONST014 * VAR17) + + VAR07 * (CONST045 * VAR17 * VAR26 - CONST076 * VAR24) + + x * (CONST038 * VAR17 * VAR24 + CONST076 * VAR22) + ) + Y03 = VAR16 * (CONST041 * VAR25 * x + CONST053 * VAR07 * z) + y * ( + -CONST064 * VAR05 * z + CONST064 * VAR23 * x + ) + Y04 = ( + CONST042 * VAR03 + + VAR05 * (-CONST042 * VAR26 - CONST070 * VAR17) + + VAR07 * (CONST061 * VAR17 * VAR26 + CONST065 * VAR15 - CONST068 * VAR24) + + x * (-CONST023 * VAR22 - CONST055 * VAR15 * VAR26 + CONST058 * VAR17 * VAR24) + ) + Y05 = ( + CONST015 * VAR05 * y * z + + VAR07 * (CONST025 * VAR25 * y + CONST057 * VAR16 * z) + + x * (CONST015 * VAR23 * y + CONST033 * VAR14 * z + CONST056 * VAR16 * VAR25) + ) + Y06 = ( + CONST047 * VAR03 + + VAR05 
* (CONST020 * VAR17 + CONST078 * VAR26) + + VAR07 * (CONST052 * VAR24 + CONST060 * VAR15 - CONST060 * VAR17 * VAR26) + + x + * ( + CONST012 * VAR13 + + CONST019 * VAR17 * VAR24 + + CONST060 * VAR15 * VAR26 + + CONST080 * VAR22 + ) + ) + Y07 = ( + CONST002 * VAR12 + + VAR14 * (CONST066 * VAR08 + CONST067 * VAR26) + + VAR16 * (CONST026 * VAR06 + CONST026 * VAR24 + CONST037 * VAR08 * VAR26) + + y + * ( + CONST071 * VAR06 * VAR26 + + CONST072 * VAR08 * VAR24 + + CONST077 * VAR04 + + CONST077 * VAR22 + ) + ) + Y08 = ( + CONST047 * VAR21 + + VAR23 * (CONST020 * VAR17 + CONST052 * VAR08) + + VAR25 * (CONST052 * VAR06 - CONST060 * VAR08 * VAR17 + CONST060 * VAR15) + + z + * ( + CONST013 * VAR13 + + CONST021 * VAR06 * VAR17 + + CONST047 * VAR04 + + CONST060 * VAR08 * VAR15 + ) + ) + Y09 = ( + VAR14 * (CONST069 * VAR08 - CONST069 * VAR26) + + VAR16 * (-CONST062 * VAR06 + CONST062 * VAR24) + + y + * ( + CONST008 * VAR08 * VAR24 + + CONST074 * VAR04 + + CONST074 * VAR06 * VAR26 + - CONST074 * VAR22 + ) + ) + Y10 = ( + -CONST042 * VAR21 + + VAR23 * (CONST044 * VAR08 + CONST070 * VAR17) + + VAR25 * (CONST032 * VAR08 * VAR17 - CONST065 * VAR15 + CONST068 * VAR06) + + z * (CONST023 * VAR04 + CONST055 * VAR08 * VAR15 - CONST058 * VAR06 * VAR17) + ) + Y11 = VAR16 * ( + CONST017 * VAR06 + CONST017 * VAR24 + CONST046 * VAR08 * VAR26 + ) + y * ( + CONST028 * VAR06 * VAR26 + + CONST028 * VAR08 * VAR24 + + CONST075 * VAR04 + + CONST075 * VAR22 + ) + Y12 = ( + CONST051 * VAR21 + + VAR23 * (CONST010 * VAR08 + CONST014 * VAR17) + + VAR25 * (CONST045 * VAR08 * VAR17 - CONST049 * VAR06) + + z * (CONST038 * VAR06 * VAR17 + CONST049 * VAR04) + ) + Y13 = y * ( + CONST043 * VAR04 + - CONST043 * VAR22 + - CONST054 * VAR06 * VAR26 + + CONST054 * VAR08 * VAR24 + ) + Y14 = ( + -CONST059 * VAR06 * VAR25 + + CONST063 * VAR08 * VAR23 + + CONST073 * VAR04 * z + - CONST079 * VAR21 + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * 
output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y01, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y02, + mask=output_row_offset + 2 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 3, + Y03, + mask=output_row_offset + 3 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 4, + Y04, + mask=output_row_offset + 4 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 5, + Y05, + mask=output_row_offset + 5 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 6, + Y06, + mask=output_row_offset + 6 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 7, + Y07, + mask=output_row_offset + 7 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 8, + Y08, + mask=output_row_offset + 8 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 9, + Y09, + mask=output_row_offset + 9 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 10, + Y10, + mask=output_row_offset + 10 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 11, + Y11, + mask=output_row_offset + 11 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 12, + Y12, + mask=output_row_offset + 12 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 13, + Y13, + mask=output_row_offset + 13 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 14, + Y14, + mask=output_row_offset + 14 < output_numel, + ) + + +@triton.jit +def seventh_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # these are hardcoded because they are 
predetermined; + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + # load in gradients w.r.t. spherical harmonic projections + g_0 = tl.load( + sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel + ) + g_1 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_2 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + g_3 = tl.load( + sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel + ) + g_4 = tl.load( + sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel + ) + g_5 = tl.load( + sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel + ) + g_6 = tl.load( + sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel + ) + g_7 = tl.load( + sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel + ) + g_8 = tl.load( + sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel + ) + g_9 = tl.load( + sph_grad_ptr + output_row_offset + 9, mask=output_row_offset + 9 < output_numel + ) + g_10 = tl.load( + sph_grad_ptr + output_row_offset + 10, + mask=output_row_offset + 10 < output_numel, + ) + g_11 = tl.load( + sph_grad_ptr + output_row_offset + 11, + mask=output_row_offset + 11 < output_numel, + ) + g_12 = tl.load( + sph_grad_ptr + 
output_row_offset + 12, + mask=output_row_offset + 12 < output_numel, + ) + g_13 = tl.load( + sph_grad_ptr + output_row_offset + 13, + mask=output_row_offset + 13 < output_numel, + ) + g_14 = tl.load( + sph_grad_ptr + output_row_offset + 14, + mask=output_row_offset + 14 < output_numel, + ) + # -------------------- variable and constant definitions + CONST000 = 1.66389743899677 + CONST001 = 3.00000000000000 + CONST003 = 5.00000000000000 + CONST004 = 3.32779487799353 + CONST009 = 11.7655316231354 + CONST012 = 16.5555704843566 + CONST014 = 20.4939015319192 + CONST016 = 22.0740939791422 + CONST018 = 23.5310632462709 + CONST019 = 20.4939015319192 + CONST020 = 27.1108834234519 + CONST022 = 33.1111409687132 + CONST024 = 36.7901566319036 + CONST025 = 36.7901566319036 + CONST026 = 38.4260653723485 + CONST027 = 38.4260653723485 + CONST029 = 38.4260653723485 + CONST030 = 44.1481879582843 + CONST032 = -4.99169231699030 + CONST037 = 47.0621264925417 + CONST039 = 56.2781179722634 + CONST044 = -441.481879582843 + CONST045 = -441.481879582843 + CONST048 = 76.8521307446970 + CONST049 = 76.8521307446970 + CONST050 = -8.47215106982872 + CONST054 = 110.370469895711 + CONST055 = 110.370469895711 + CONST056 = -399.335385359224 + CONST057 = 117.655316231354 + CONST058 = 122.963409191515 + CONST059 = 122.963409191515 + CONST061 = -376.497011940334 + CONST062 = -376.497011940334 + CONST064 = 141.186379477625 + CONST066 = 147.160626527614 + CONST067 = 153.704261489394 + CONST069 = -350.955726374425 + CONST072 = 203.331625675889 + CONST073 = 203.331625675889 + CONST074 = -307.408522978788 + CONST075 = -9.60651634308713 + CONST076 = -9.37968632871057 + CONST079 = -281.390589861317 + CONST080 = -1.66389743899677 + CONST081 = -266.223590239483 + CONST082 = -263.216794780819 + CONST084 = -263.216794780818 + CONST085 = -250.998007960223 + CONST089 = 281.390589861317 + CONST091 = -220.740939791422 + CONST092 = -220.740939791422 + CONST093 = -199.667692679612 + CONST094 = -1.60108605718119 + 
CONST095 = -187.593726574211 + CONST096 = -177.482393492989 + CONST097 = -9.60651634308712 + CONST098 = -9.19753915797590 + CONST100 = -153.704261489394 + CONST101 = -147.160626527614 + CONST102 = -140.695294930659 + CONST104 = -133.111795119741 + CONST105 = -133.111795119741 + CONST106 = -125.499003980111 + CONST107 = -125.499003980111 + CONST109 = -105.286717912327 + CONST110 = -101.665812837945 + CONST111 = -99.8338463398060 + CONST112 = -101.665812837945 + CONST113 = -4.80325817154356 + CONST114 = -81.3326502703558 + CONST115 = -81.3326502703557 + CONST116 = -76.8521307446970 + CONST117 = -75.2994023880668 + CONST119 = -70.5931897388126 + CONST121 = -66.2222819374265 + CONST122 = -66.5558975598707 + CONST123 = -66.5558975598707 + CONST124 = -62.7495019900557 + CONST125 = -56.2781179722634 + CONST126 = -55.1852349478554 + CONST127 = -55.1852349478554 + CONST128 = -50.8329064189723 + CONST129 = -50.8329064189723 + CONST130 = -562.781179722634 + CONST131 = -47.0621264925418 + CONST132 = -50.8329064189724 + CONST133 = -44.1481879582843 + CONST134 = -44.3705983732471 + CONST135 = -40.6663251351779 + CONST136 = -40.6663251351779 + CONST137 = -8.31948719498384 + CONST138 = -37.6497011940334 + CONST139 = -33.2779487799353 + CONST140 = -29.9501539019418 + CONST141 = -25.4164532094862 + CONST142 = -25.4164532094862 + CONST143 = -23.5310632462709 + CONST144 = -532.447180478965 + CONST145 = -19.2130326861743 + CONST146 = -17.5477863187212 + CONST147 = -12.8765548211663 + CONST148 = -11.6472820729774 + CONST149 = -11.2076024002683 + CONST150 = -9.19753915797590 + CONST151 = -11.0370469895711 + CONST152 = -11.7655316231354 + CONST153 = -12.8765548211663 + CONST154 = -4.80325817154356 + CONST155 = -3.32779487799353 + CONST156 = -1.60108605718119 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR16 = y * y * y + VAR17 = y * y + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR15 = VAR17 * VAR17 + VAR25 = z * z * 
z + VAR26 = z * z + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + VAR24 = VAR26 * VAR26 + # -------------------- kernel implementations + g_x = tl.load( + coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel + ) + g_y = tl.load( + coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + g_z = tl.load( + coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + g_x += ( + g_0 + * ( + CONST082 * VAR08 * VAR24 + - CONST084 * VAR06 * VAR26 + + CONST146 * VAR04 + - CONST146 * VAR22 + ) + + g_1 * y * (CONST039 * VAR23 + CONST089 * VAR06 * z + CONST130 * VAR08 * VAR25) + + g_10 + * ( + CONST155 * VAR23 * x + + VAR25 * (-CONST105 * VAR17 * x + CONST139 * VAR07) + + z * (-CONST056 * VAR07 * VAR17 + CONST081 * VAR15 * x + CONST140 * VAR05) + ) + + g_11 + * ( + VAR16 * (CONST044 * VAR26 * x - CONST101 * VAR07) + + y * (CONST054 * VAR24 * x - CONST091 * VAR07 * VAR26 + CONST121 * VAR05) + ) + + g_12 + * ( + CONST022 * VAR23 * x + + VAR25 * (CONST024 * VAR07 + CONST045 * VAR17 * x) + + z * (-CONST044 * VAR07 * VAR17 + CONST126 * VAR05) + ) + + g_13 + * y + * (CONST079 * VAR24 * x + CONST125 * VAR05 - CONST130 * VAR07 * VAR26) + + g_14 + * (-CONST069 * VAR07 * VAR25 + CONST109 * VAR05 * z + CONST109 * VAR23 * x) + + g_2 + * ( + CONST001 * VAR08 * (CONST091 * VAR17 * VAR26 - CONST150 * VAR24) + + CONST003 * VAR06 * (CONST012 * VAR26 + CONST016 * VAR17) + + CONST055 * VAR17 * VAR24 + + CONST147 * VAR04 + + CONST150 * VAR22 + ) + + g_3 + * ( + VAR16 * (CONST044 * VAR08 * z + CONST066 * VAR25) + + y * (-CONST091 * VAR06 * z + CONST133 * VAR23) + ) + + g_4 + * ( + CONST001 + * VAR08 + * (CONST122 * VAR17 * VAR26 + CONST134 * VAR15 - CONST137 * VAR24) + + CONST003 * VAR06 * (CONST000 * VAR26 - CONST139 * VAR17) + - CONST032 * VAR22 + - CONST105 * VAR15 * VAR26 + + CONST111 * VAR17 * VAR24 + + CONST148 * VAR04 + ) + + g_5 + * ( + CONST001 * VAR08 * (CONST106 * VAR16 * z - CONST131 * VAR25 * y) + + CONST057 * 
VAR06 * y * z + + CONST107 * VAR16 * VAR25 + - CONST117 * VAR14 * z + - CONST143 * VAR23 * y + ) + + g_6 + * ( + CONST001 + * VAR08 + * (CONST116 * VAR15 - CONST116 * VAR17 * VAR26 + CONST154 * VAR24) + + CONST003 * VAR06 * (CONST026 * VAR17 + CONST113 * VAR26) + + CONST014 * VAR13 + + CONST027 * VAR17 * VAR24 + + CONST116 * VAR15 * VAR26 + + CONST149 * VAR04 + + CONST156 * VAR22 + ) + + g_7 + * ( + CONST114 * VAR14 * x + + VAR16 * (CONST072 * VAR07 + CONST073 * VAR26 * x) + + y * (CONST110 * VAR07 * VAR26 + CONST128 * VAR05 + CONST129 * VAR24 * x) + ) + + g_8 + * ( + CONST075 * VAR23 * x + + VAR25 * (-CONST100 * VAR17 * x + CONST145 * VAR07) + + z * (CONST067 * VAR07 * VAR17 + CONST097 * VAR05 + CONST100 * VAR15 * x) + ) + + g_9 + * ( + -CONST085 * VAR07 * VAR16 + + CONST117 * VAR14 * x + + y * (CONST018 * VAR24 * x + CONST119 * VAR05 + CONST131 * VAR07 * VAR26) + ) + ) + g_y += ( + g_1 * (CONST039 * VAR23 * x + CONST095 * VAR07 * VAR25 - CONST125 * VAR05 * z) + + g_10 + * ( + CONST123 * VAR23 * y + + VAR25 * (-CONST096 * VAR16 - CONST105 * VAR08 * y) + + z * (-CONST093 * VAR06 * y + CONST144 * VAR08 * VAR16) + ) + + g_11 + * ( + CONST001 + * VAR17 + * (CONST025 * VAR06 + CONST025 * VAR24 + CONST092 * VAR08 * VAR26) + - CONST126 * VAR06 * VAR26 + - CONST126 * VAR08 * VAR24 + + CONST151 * VAR04 + + CONST151 * VAR22 + ) + + g_12 + * ( + CONST030 * VAR23 * y + + CONST045 * VAR08 * VAR25 * y + - CONST092 * VAR06 * y * z + ) + + g_13 + * ( + CONST076 * VAR04 + - CONST076 * VAR22 + - CONST102 * VAR06 * VAR26 + + CONST102 * VAR08 * VAR24 + ) + + g_2 + * ( + CONST030 * VAR05 * y + + CONST045 * VAR07 * VAR26 * y + - CONST092 * VAR24 * x * y + ) + + g_3 + * ( + CONST001 * VAR17 * (CONST066 * VAR25 * x + CONST101 * VAR07 * z) + - CONST133 * VAR05 * z + + CONST133 * VAR23 * x + ) + + g_4 + * ( + -CONST123 * VAR05 * y + + VAR07 * (CONST096 * VAR16 + CONST104 * VAR26 * y) + + x * (CONST093 * VAR24 * y - CONST144 * VAR16 * VAR26) + ) + + g_5 + * ( + -CONST143 * VAR05 * z + + 
VAR07 * (CONST062 * VAR17 * z - CONST131 * VAR25) + + x * (CONST061 * VAR17 * VAR25 - CONST062 * VAR15 * z - CONST143 * VAR23) + ) + + g_6 + * ( + CONST048 * VAR05 * y + + VAR07 * (CONST074 * VAR16 - CONST100 * VAR26 * y) + + x * (CONST058 * VAR14 + CONST074 * VAR16 * VAR26 - CONST116 * VAR24 * y) + ) + + g_7 + * ( + CONST001 + * VAR17 + * (-CONST112 * VAR08 * VAR26 - CONST128 * VAR06 - CONST128 * VAR24) + + CONST003 * VAR15 * (CONST135 * VAR08 + CONST136 * VAR26) + + CONST020 * VAR13 + + CONST050 * VAR04 + + CONST050 * VAR22 + + CONST141 * VAR06 * VAR26 + + CONST142 * VAR08 * VAR24 + ) + + g_8 + * ( + CONST048 * VAR23 * y + + VAR25 * (CONST074 * VAR16 - CONST100 * VAR08 * y) + + z * (CONST049 * VAR06 * y + CONST059 * VAR14 + CONST074 * VAR08 * VAR16) + ) + + g_9 + * ( + CONST001 * VAR17 * (-CONST124 * VAR06 + CONST124 * VAR24) + + CONST003 * VAR15 * (CONST138 * VAR08 - CONST138 * VAR26) + + CONST009 * VAR08 * VAR24 + + CONST152 * VAR04 + + CONST152 * VAR06 * VAR26 + - CONST152 * VAR22 + ) + ) + g_z += ( + g_0 * (CONST069 * VAR07 * VAR25 - CONST109 * VAR05 * z - CONST109 * VAR23 * x) + + g_1 + * y + * (-CONST079 * VAR24 * x - CONST125 * VAR05 + CONST130 * VAR07 * VAR26) + + g_10 + * ( + CONST001 + * VAR26 + * (-CONST123 * VAR08 * VAR17 - CONST134 * VAR15 + CONST137 * VAR06) + + CONST003 * VAR24 * (CONST080 * VAR08 + CONST139 * VAR17) + + CONST032 * VAR04 + + CONST105 * VAR08 * VAR15 + - CONST111 * VAR06 * VAR17 + - CONST148 * VAR22 + ) + + g_11 + * ( + VAR16 * (CONST044 * VAR08 * z - CONST101 * VAR25) + + y * (CONST054 * VAR06 * z - CONST091 * VAR08 * VAR25 + CONST121 * VAR23) + ) + + g_12 + * ( + CONST001 * VAR26 * (CONST091 * VAR08 * VAR17 - CONST098 * VAR06) + + CONST003 * VAR24 * (CONST012 * VAR08 + CONST016 * VAR17) + + CONST055 * VAR06 * VAR17 + + CONST098 * VAR04 + + CONST153 * VAR22 + ) + + g_13 + * y + * (-CONST079 * VAR06 * z - CONST125 * VAR23 + CONST130 * VAR08 * VAR25) + + g_14 + * ( + -CONST082 * VAR06 * VAR26 + + CONST084 * VAR08 * VAR24 + + CONST146 
* VAR04 + - CONST146 * VAR22 + ) + + g_2 + * ( + CONST022 * VAR05 * z + + VAR07 * (CONST025 * VAR25 + CONST045 * VAR17 * z) + + x * (-CONST044 * VAR17 * VAR25 + CONST127 * VAR23) + ) + + g_3 + * ( + VAR16 * (-CONST045 * VAR26 * x + CONST101 * VAR07) + + y * (CONST091 * VAR24 * x - CONST133 * VAR05) + ) + + g_4 + * ( + CONST004 * VAR05 * z + + VAR07 * (CONST104 * VAR17 * z - CONST139 * VAR25) + + x * (CONST056 * VAR17 * VAR25 - CONST081 * VAR15 * z - CONST140 * VAR23) + ) + + g_5 + * ( + -CONST143 * VAR05 * y + + VAR07 * (CONST064 * VAR26 * y + CONST106 * VAR16) + + x * (CONST057 * VAR24 * y + CONST061 * VAR16 * VAR26 - CONST117 * VAR14) + ) + + g_6 + * ( + CONST097 * VAR05 * z + + VAR07 * (-CONST100 * VAR17 * z + CONST145 * VAR25) + + x * (CONST075 * VAR23 + CONST100 * VAR15 * z - CONST100 * VAR17 * VAR25) + ) + + g_7 + * ( + CONST115 * VAR14 * z + + VAR16 * (CONST072 * VAR25 + CONST073 * VAR08 * z) + + y * (CONST112 * VAR08 * VAR25 + CONST128 * VAR23 + CONST132 * VAR06 * z) + ) + + g_8 + * ( + CONST001 + * VAR26 + * (-CONST116 * VAR08 * VAR17 + CONST116 * VAR15 + CONST154 * VAR06) + + CONST003 * VAR24 * (CONST026 * VAR17 + CONST154 * VAR08) + + CONST019 * VAR13 + + CONST029 * VAR06 * VAR17 + + CONST094 * VAR04 + + CONST116 * VAR08 * VAR15 + + CONST149 * VAR22 + ) + + g_9 + * ( + CONST085 * VAR16 * VAR25 + - CONST117 * VAR14 * z + + y * (CONST037 * VAR08 * VAR25 - CONST119 * VAR23 + CONST143 * VAR06 * z) + ) + ) + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + ) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/sph_harm/direct/y_8.py b/src/equitriton/sph_harm/direct/y_8.py new file mode 100644 index 0000000..ae987c4 --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_8.py @@ -0,0 +1,1529 @@ +import triton 
+import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["EighthOrderSphericalHarmonic"] + + +class EighthOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + output_tensor: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + if not isinstance(output_tensor, torch.Tensor): + output_tensor = torch.empty( + (*coords.shape[:-1], 17), dtype=coords.dtype, device=coords.device + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + eighth_order_fwd[num_blocks,]( + coords, + output_tensor, + block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, + sph_grad_tensor: torch.Tensor, + block_size: int = 64, + col_offset: int = 0, + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + eighth_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. + + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the eighth order spherical harmonic. 
+ """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + CONST000 = 1.12741169450483 + CONST003 = 4.12310562561766 + CONST004 = 4.50964677801932 + CONST006 = 6.76447016702898 + CONST007 = 1.69594242329302 + CONST008 = 1.88707052233084 + CONST010 = 2.58397773170915 + CONST011 = 13.1367135230810 + CONST012 = 13.1367135230810 + CONST014 = -489.184589393411 + CONST015 = 24.7386337537060 + CONST017 = 24.7386337537060 + CONST019 = 48.9184589393411 + CONST020 = 48.5105296237322 + CONST021 = 51.7445649319810 + CONST024 = 65.6835676154051 + CONST025 = 67.8376969317208 + CONST029 = 97.0210592474644 + CONST030 = -6.78376969317208 + CONST031 = 103.489129863962 + CONST032 = -407.026181590325 + CONST033 = 108.231522672464 + CONST035 = 110.066532613517 + CONST036 = 110.066532613517 + CONST037 = -396.284809689477 + CONST040 = -361.756882439281 + CONST041 = -1.88707052233084 + CONST042 = 158.513923875791 + CONST045 = 180.878441219640 + CONST046 = 194.042118494929 + CONST047 = -12.2296147348353 + CONST048 = 203.513090795162 + CONST050 = 216.463045344927 + CONST051 = 217.054129463568 + CONST052 = 216.463045344927 + CONST053 = -6.78376969317208 + CONST054 = -271.350787726883 + CONST055 = 244.592294696706 + CONST056 = 244.592294696706 + CONST057 = -262.734270461621 + CONST058 = -258.722824659905 + CONST061 = -217.054129463568 + CONST062 = -210.187416369296 + CONST063 = -175.156180307747 + CONST064 = -162.810472636130 + CONST066 = -144.702752975712 + CONST067 = -129.877827206956 + CONST068 = -129.361412329953 + CONST070 = -108.231522672464 + CONST071 = -108.231522672464 + CONST072 = -87.5780901538735 + CONST073 = -3.23403530824881 + CONST074 = -72.3513764878561 + CONST075 = -70.0624721230988 + CONST076 = -65.6835676154052 + CONST077 = -61.1480736741764 + CONST078 = -61.1480736741764 + CONST079 = -57.7234787586472 + CONST080 = -57.7234787586472 + CONST081 = -51.7445649319810 + CONST082 = 
-48.5105296237322 + CONST083 = -40.5868210021738 + CONST084 = -39.4101405692431 + CONST085 = -40.7026181590325 + CONST086 = -36.0771742241545 + CONST087 = -36.0771742241545 + CONST088 = -26.4189873126318 + CONST089 = -20.6718218536732 + CONST090 = -528.379746252636 + CONST091 = -16.9594242329302 + CONST092 = -13.1367135230810 + CONST093 = -12.2296147348353 + CONST094 = -11.3224231339851 + CONST095 = -10.3359109268366 + CONST096 = -9.70210592474644 + CONST097 = -11.3224231339851 + CONST098 = -13.5289403340579 + CONST099 = -6.78376969317208 + CONST100 = -13.5289403340579 + CONST101 = -13.1367135230810 + CONST102 = -3.23403530824881 + CONST103 = -1.61701765412441 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR02 = VAR06 * VAR06 + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR11 = VAR15 * VAR15 + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + VAR20 = VAR24 * VAR24 + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + # -------------------- kernel implementations + Y00 = ( + -CONST066 * VAR05 * VAR25 + + CONST066 * VAR07 * VAR23 + + CONST089 * VAR03 * z + - CONST089 * VAR21 * x + ) + Y01 = y * ( + CONST040 * VAR07 * VAR24 + + CONST051 * VAR05 * VAR26 + - CONST074 * VAR22 * x + + CONST095 * VAR03 + ) + Y02 = ( + CONST097 * VAR03 * z + + VAR05 * (CONST042 * VAR17 * z - CONST088 * VAR25) + + VAR07 * (-CONST088 * VAR23 + CONST090 * VAR17 * VAR25) + + x * (CONST042 * VAR17 * VAR23 + CONST094 * VAR21) + ) + Y03 = VAR16 * ( + CONST014 * VAR07 * VAR26 + CONST019 * VAR05 + CONST055 * VAR24 * x + ) + y * ( + CONST035 * VAR05 * VAR26 + + CONST077 * VAR22 * x + - CONST078 * VAR07 * VAR24 + + CONST093 * VAR03 + ) + Y04 = ( + CONST099 * VAR03 * z + + VAR05 * (-CONST064 * VAR17 * z + CONST099 * VAR25) + + VAR07 * (-CONST053 * VAR23 + CONST054 * VAR15 * z) + + x * (-CONST053 * 
VAR21 - CONST054 * VAR15 * VAR25 + CONST064 * VAR17 * VAR23) + ) + Y05 = ( + VAR14 * (-CONST062 * VAR26 * x + CONST075 * VAR07) + + VAR16 * (CONST057 * VAR24 * x + CONST063 * VAR07 * VAR26 - CONST072 * VAR05) + + y + * ( + CONST011 * VAR05 * VAR26 + + CONST024 * VAR07 * VAR24 + - CONST084 * VAR22 * x + + CONST092 * VAR03 + ) + ) + Y06 = ( + CONST102 * VAR03 * z + + VAR05 * (CONST029 * VAR17 * z + CONST096 * VAR25) + + VAR07 * (CONST046 * VAR17 * VAR25 + CONST058 * VAR15 * z + CONST096 * VAR23) + + x + * ( + CONST029 * VAR17 * VAR23 + + CONST031 * VAR13 * z + + CONST058 * VAR15 * VAR25 + + CONST102 * VAR21 + ) + ) + Y07 = ( + CONST098 * VAR03 * y + + VAR05 * (CONST033 * VAR16 + CONST083 * VAR26 * y) + + VAR07 * (CONST050 * VAR16 * VAR26 + CONST067 * VAR14 + CONST083 * VAR24 * y) + + x + * ( + CONST015 * VAR12 + + CONST067 * VAR14 * VAR26 + - CONST070 * VAR16 * VAR24 + + CONST098 * VAR22 * y + ) + ) + Y08 = ( + CONST000 * VAR02 + + CONST000 * VAR20 + + CONST003 * VAR11 + - CONST070 * VAR15 * VAR24 + + CONST080 * VAR13 * VAR26 + + CONST087 * VAR17 * VAR22 + + VAR04 * (CONST004 * VAR26 + CONST086 * VAR17) + + VAR06 * (CONST006 * VAR24 - CONST070 * VAR15 + CONST071 * VAR17 * VAR26) + + VAR08 + * ( + CONST004 * VAR22 + + CONST050 * VAR15 * VAR26 + + CONST070 * VAR17 * VAR24 + + CONST079 * VAR13 + ) + ) + Y09 = ( + CONST098 * VAR21 * y + + VAR23 * (CONST033 * VAR16 + CONST083 * VAR08 * y) + + VAR25 * (CONST052 * VAR08 * VAR16 + CONST067 * VAR14 + CONST083 * VAR06 * y) + + z + * ( + CONST017 * VAR12 + + CONST033 * VAR06 * VAR16 + + CONST067 * VAR08 * VAR14 + + CONST100 * VAR04 * y + ) + ) + Y10 = ( + CONST073 * VAR08 * VAR22 + - CONST102 * VAR04 * VAR26 + - CONST103 * VAR02 + + CONST103 * VAR20 + + VAR13 * (CONST021 * VAR26 + CONST081 * VAR08) + + VAR15 * (-CONST068 * VAR06 + CONST068 * VAR24) + + VAR17 + * ( + CONST020 * VAR08 * VAR24 + + CONST020 * VAR22 + + CONST082 * VAR04 + + CONST082 * VAR06 * VAR26 + ) + ) + Y11 = ( + VAR14 * (CONST062 * VAR08 * z - CONST075 * 
VAR25) + + VAR16 * (-CONST057 * VAR06 * z - CONST063 * VAR08 * VAR25 + CONST072 * VAR23) + + y + * ( + CONST012 * VAR21 + + CONST076 * VAR06 * VAR25 + + CONST084 * VAR04 * z + + CONST101 * VAR08 * VAR23 + ) + ) + Y12 = ( + CONST007 * VAR02 + + CONST007 * VAR20 + + CONST030 * VAR04 * VAR26 + + CONST053 * VAR08 * VAR22 + + CONST091 * VAR06 * VAR24 + + VAR15 * (CONST025 * VAR06 + CONST025 * VAR24 + CONST032 * VAR08 * VAR26) + + VAR17 + * ( + CONST048 * VAR06 * VAR26 + + CONST048 * VAR08 * VAR24 + + CONST085 * VAR04 + + CONST085 * VAR22 + ) + ) + Y13 = VAR16 * ( + CONST014 * VAR08 * VAR25 + CONST019 * VAR23 + CONST056 * VAR06 * z + ) + y * ( + CONST036 * VAR08 * VAR23 + + CONST047 * VAR21 + - CONST077 * VAR06 * VAR25 + + CONST078 * VAR04 * z + ) + Y14 = ( + CONST008 * VAR02 + + CONST041 * VAR20 + + CONST088 * VAR04 * VAR26 + - CONST088 * VAR08 * VAR22 + + VAR17 + * ( + -CONST037 * VAR06 * VAR26 + + CONST037 * VAR08 * VAR24 + + CONST088 * VAR04 + - CONST088 * VAR22 + ) + ) + Y15 = y * ( + -CONST040 * VAR06 * VAR25 + + CONST061 * VAR08 * VAR23 + + CONST074 * VAR04 * z + - CONST095 * VAR21 + ) + Y16 = ( + CONST010 * VAR02 + + CONST010 * VAR20 + + CONST045 * VAR06 * VAR24 + + CONST074 * VAR04 * VAR26 + + CONST074 * VAR08 * VAR22 + ) + # not the prettiest way to concatenate, but better than + # messing with the linter + tensors = [ + Y00, + Y01, + Y02, + Y03, + Y04, + Y05, + Y06, + Y07, + Y08, + Y09, + Y10, + Y11, + Y12, + Y13, + Y14, + Y15, + Y16, + ] + return torch.cat(tensors, dim=-1) + + +@triton.jit +def eighth_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # these are hardcoded because they are predetermined; + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + 
coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + CONST000 = 1.12741169450483 + CONST003 = 4.12310562561766 + CONST004 = 4.50964677801932 + CONST006 = 6.76447016702898 + CONST007 = 1.69594242329302 + CONST008 = 1.88707052233084 + CONST010 = 2.58397773170915 + CONST011 = 13.1367135230810 + CONST012 = 13.1367135230810 + CONST014 = -489.184589393411 + CONST015 = 24.7386337537060 + CONST017 = 24.7386337537060 + CONST019 = 48.9184589393411 + CONST020 = 48.5105296237322 + CONST021 = 51.7445649319810 + CONST024 = 65.6835676154051 + CONST025 = 67.8376969317208 + CONST029 = 97.0210592474644 + CONST030 = -6.78376969317208 + CONST031 = 103.489129863962 + CONST032 = -407.026181590325 + CONST033 = 108.231522672464 + CONST035 = 110.066532613517 + CONST036 = 110.066532613517 + CONST037 = -396.284809689477 + CONST040 = -361.756882439281 + CONST041 = -1.88707052233084 + CONST042 = 158.513923875791 + CONST045 = 180.878441219640 + CONST046 = 194.042118494929 + CONST047 = -12.2296147348353 + CONST048 = 203.513090795162 + CONST050 = 216.463045344927 + CONST051 = 217.054129463568 + CONST052 = 216.463045344927 + CONST053 = -6.78376969317208 + CONST054 = -271.350787726883 + CONST055 = 244.592294696706 + CONST056 = 244.592294696706 + CONST057 = -262.734270461621 + CONST058 = -258.722824659905 + CONST061 = -217.054129463568 + CONST062 = -210.187416369296 + CONST063 = -175.156180307747 + CONST064 = -162.810472636130 + CONST066 = -144.702752975712 + CONST067 = -129.877827206956 + CONST068 = -129.361412329953 + CONST070 = -108.231522672464 + CONST071 = -108.231522672464 + CONST072 = -87.5780901538735 + CONST073 = -3.23403530824881 + CONST074 = -72.3513764878561 + CONST075 = -70.0624721230988 + 
CONST076 = -65.6835676154052 + CONST077 = -61.1480736741764 + CONST078 = -61.1480736741764 + CONST079 = -57.7234787586472 + CONST080 = -57.7234787586472 + CONST081 = -51.7445649319810 + CONST082 = -48.5105296237322 + CONST083 = -40.5868210021738 + CONST084 = -39.4101405692431 + CONST085 = -40.7026181590325 + CONST086 = -36.0771742241545 + CONST087 = -36.0771742241545 + CONST088 = -26.4189873126318 + CONST089 = -20.6718218536732 + CONST090 = -528.379746252636 + CONST091 = -16.9594242329302 + CONST092 = -13.1367135230810 + CONST093 = -12.2296147348353 + CONST094 = -11.3224231339851 + CONST095 = -10.3359109268366 + CONST096 = -9.70210592474644 + CONST097 = -11.3224231339851 + CONST098 = -13.5289403340579 + CONST099 = -6.78376969317208 + CONST100 = -13.5289403340579 + CONST101 = -13.1367135230810 + CONST102 = -3.23403530824881 + CONST103 = -1.61701765412441 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR02 = VAR06 * VAR06 + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR11 = VAR15 * VAR15 + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + VAR20 = VAR24 * VAR24 + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + # -------------------- kernel implementations + Y00 = ( + -CONST066 * VAR05 * VAR25 + + CONST066 * VAR07 * VAR23 + + CONST089 * VAR03 * z + - CONST089 * VAR21 * x + ) + Y01 = y * ( + CONST040 * VAR07 * VAR24 + + CONST051 * VAR05 * VAR26 + - CONST074 * VAR22 * x + + CONST095 * VAR03 + ) + Y02 = ( + CONST097 * VAR03 * z + + VAR05 * (CONST042 * VAR17 * z - CONST088 * VAR25) + + VAR07 * (-CONST088 * VAR23 + CONST090 * VAR17 * VAR25) + + x * (CONST042 * VAR17 * VAR23 + CONST094 * VAR21) + ) + Y03 = VAR16 * ( + CONST014 * VAR07 * VAR26 + CONST019 * VAR05 + CONST055 * VAR24 * x + ) + y * ( + CONST035 * VAR05 * VAR26 + + CONST077 * VAR22 * x + - CONST078 * VAR07 
* VAR24 + + CONST093 * VAR03 + ) + Y04 = ( + CONST099 * VAR03 * z + + VAR05 * (-CONST064 * VAR17 * z + CONST099 * VAR25) + + VAR07 * (-CONST053 * VAR23 + CONST054 * VAR15 * z) + + x * (-CONST053 * VAR21 - CONST054 * VAR15 * VAR25 + CONST064 * VAR17 * VAR23) + ) + Y05 = ( + VAR14 * (-CONST062 * VAR26 * x + CONST075 * VAR07) + + VAR16 * (CONST057 * VAR24 * x + CONST063 * VAR07 * VAR26 - CONST072 * VAR05) + + y + * ( + CONST011 * VAR05 * VAR26 + + CONST024 * VAR07 * VAR24 + - CONST084 * VAR22 * x + + CONST092 * VAR03 + ) + ) + Y06 = ( + CONST102 * VAR03 * z + + VAR05 * (CONST029 * VAR17 * z + CONST096 * VAR25) + + VAR07 * (CONST046 * VAR17 * VAR25 + CONST058 * VAR15 * z + CONST096 * VAR23) + + x + * ( + CONST029 * VAR17 * VAR23 + + CONST031 * VAR13 * z + + CONST058 * VAR15 * VAR25 + + CONST102 * VAR21 + ) + ) + Y07 = ( + CONST098 * VAR03 * y + + VAR05 * (CONST033 * VAR16 + CONST083 * VAR26 * y) + + VAR07 * (CONST050 * VAR16 * VAR26 + CONST067 * VAR14 + CONST083 * VAR24 * y) + + x + * ( + CONST015 * VAR12 + + CONST067 * VAR14 * VAR26 + - CONST070 * VAR16 * VAR24 + + CONST098 * VAR22 * y + ) + ) + Y08 = ( + CONST000 * VAR02 + + CONST000 * VAR20 + + CONST003 * VAR11 + - CONST070 * VAR15 * VAR24 + + CONST080 * VAR13 * VAR26 + + CONST087 * VAR17 * VAR22 + + VAR04 * (CONST004 * VAR26 + CONST086 * VAR17) + + VAR06 * (CONST006 * VAR24 - CONST070 * VAR15 + CONST071 * VAR17 * VAR26) + + VAR08 + * ( + CONST004 * VAR22 + + CONST050 * VAR15 * VAR26 + + CONST070 * VAR17 * VAR24 + + CONST079 * VAR13 + ) + ) + Y09 = ( + CONST098 * VAR21 * y + + VAR23 * (CONST033 * VAR16 + CONST083 * VAR08 * y) + + VAR25 * (CONST052 * VAR08 * VAR16 + CONST067 * VAR14 + CONST083 * VAR06 * y) + + z + * ( + CONST017 * VAR12 + + CONST033 * VAR06 * VAR16 + + CONST067 * VAR08 * VAR14 + + CONST100 * VAR04 * y + ) + ) + Y10 = ( + CONST073 * VAR08 * VAR22 + - CONST102 * VAR04 * VAR26 + - CONST103 * VAR02 + + CONST103 * VAR20 + + VAR13 * (CONST021 * VAR26 + CONST081 * VAR08) + + VAR15 * (-CONST068 * VAR06 + 
CONST068 * VAR24) + + VAR17 + * ( + CONST020 * VAR08 * VAR24 + + CONST020 * VAR22 + + CONST082 * VAR04 + + CONST082 * VAR06 * VAR26 + ) + ) + Y11 = ( + VAR14 * (CONST062 * VAR08 * z - CONST075 * VAR25) + + VAR16 * (-CONST057 * VAR06 * z - CONST063 * VAR08 * VAR25 + CONST072 * VAR23) + + y + * ( + CONST012 * VAR21 + + CONST076 * VAR06 * VAR25 + + CONST084 * VAR04 * z + + CONST101 * VAR08 * VAR23 + ) + ) + Y12 = ( + CONST007 * VAR02 + + CONST007 * VAR20 + + CONST030 * VAR04 * VAR26 + + CONST053 * VAR08 * VAR22 + + CONST091 * VAR06 * VAR24 + + VAR15 * (CONST025 * VAR06 + CONST025 * VAR24 + CONST032 * VAR08 * VAR26) + + VAR17 + * ( + CONST048 * VAR06 * VAR26 + + CONST048 * VAR08 * VAR24 + + CONST085 * VAR04 + + CONST085 * VAR22 + ) + ) + Y13 = VAR16 * ( + CONST014 * VAR08 * VAR25 + CONST019 * VAR23 + CONST056 * VAR06 * z + ) + y * ( + CONST036 * VAR08 * VAR23 + + CONST047 * VAR21 + - CONST077 * VAR06 * VAR25 + + CONST078 * VAR04 * z + ) + Y14 = ( + CONST008 * VAR02 + + CONST041 * VAR20 + + CONST088 * VAR04 * VAR26 + - CONST088 * VAR08 * VAR22 + + VAR17 + * ( + -CONST037 * VAR06 * VAR26 + + CONST037 * VAR08 * VAR24 + + CONST088 * VAR04 + - CONST088 * VAR22 + ) + ) + Y15 = y * ( + -CONST040 * VAR06 * VAR25 + + CONST061 * VAR08 * VAR23 + + CONST074 * VAR04 * z + - CONST095 * VAR21 + ) + Y16 = ( + CONST010 * VAR02 + + CONST010 * VAR20 + + CONST045 * VAR06 * VAR24 + + CONST074 * VAR04 * VAR26 + + CONST074 * VAR08 * VAR22 + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y01, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y02, + mask=output_row_offset + 2 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 3, + Y03, + mask=output_row_offset + 3 < 
output_numel, + ) + tl.store( + output_ptr + output_row_offset + 4, + Y04, + mask=output_row_offset + 4 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 5, + Y05, + mask=output_row_offset + 5 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 6, + Y06, + mask=output_row_offset + 6 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 7, + Y07, + mask=output_row_offset + 7 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 8, + Y08, + mask=output_row_offset + 8 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 9, + Y09, + mask=output_row_offset + 9 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 10, + Y10, + mask=output_row_offset + 10 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 11, + Y11, + mask=output_row_offset + 11 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 12, + Y12, + mask=output_row_offset + 12 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 13, + Y13, + mask=output_row_offset + 13 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 14, + Y14, + mask=output_row_offset + 14 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 15, + Y15, + mask=output_row_offset + 15 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 16, + Y16, + mask=output_row_offset + 16 < output_numel, + ) + + +@triton.jit +def eighth_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # these are hardcoded because they are predetermined; + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride 
* block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + # load in gradients w.r.t. spherical harmonic projections + g_0 = tl.load( + sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel + ) + g_1 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_2 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + g_3 = tl.load( + sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel + ) + g_4 = tl.load( + sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel + ) + g_5 = tl.load( + sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel + ) + g_6 = tl.load( + sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel + ) + g_7 = tl.load( + sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel + ) + g_8 = tl.load( + sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel + ) + g_9 = tl.load( + sph_grad_ptr + output_row_offset + 9, mask=output_row_offset + 9 < output_numel + ) + g_10 = tl.load( + sph_grad_ptr + output_row_offset + 10, + mask=output_row_offset + 10 < output_numel, + ) + g_11 = tl.load( + sph_grad_ptr + output_row_offset + 11, + mask=output_row_offset + 11 < output_numel, + ) + g_12 = tl.load( + sph_grad_ptr + output_row_offset + 12, + mask=output_row_offset + 12 < output_numel, + ) + g_13 = tl.load( + sph_grad_ptr + output_row_offset + 13, + mask=output_row_offset + 13 < output_numel, + ) + g_14 = tl.load( + sph_grad_ptr + 
output_row_offset + 14, + mask=output_row_offset + 14 < output_numel, + ) + g_15 = tl.load( + sph_grad_ptr + output_row_offset + 15, + mask=output_row_offset + 15 < output_numel, + ) + g_16 = tl.load( + sph_grad_ptr + output_row_offset + 16, + mask=output_row_offset + 16 < output_numel, + ) + # -------------------- variable and constant definitions + CONST000 = 2.00000000000000 + CONST001 = 3.00000000000000 + CONST002 = 4.50964677801932 + CONST004 = 5.00000000000000 + CONST005 = 6.78376969317208 + CONST006 = 4.00000000000000 + CONST007 = 9.01929355603863 + CONST008 = 6.76447016702898 + CONST009 = 6.00000000000000 + CONST011 = 13.5675393863442 + CONST012 = 15.0965641786467 + CONST013 = 13.1367135230810 + CONST015 = 13.1367135230810 + CONST017 = 19.4042118494929 + CONST019 = -489.184589393411 + CONST020 = 24.7386337537060 + CONST023 = 26.2734270461621 + CONST024 = 27.0578806681159 + CONST025 = 24.7386337537060 + CONST026 = 32.9848450049413 + CONST027 = 33.9188484658604 + CONST028 = 550.332663067587 + CONST030 = -978.369178786822 + CONST031 = 48.5105296237322 + CONST033 = 51.7445649319810 + CONST035 = 48.9184589393411 + CONST041 = 65.6835676154051 + CONST043 = -1467.55376818023 + CONST045 = -12.2296147348353 + CONST047 = 582.126355484786 + CONST048 = -437.890450769368 + CONST049 = -434.108258927137 + CONST050 = -434.108258927137 + CONST052 = -432.926090689854 + CONST054 = -1447.02752975712 + CONST055 = 91.9569946615672 + CONST056 = -420.374832738593 + CONST057 = 6.46807061649763 + CONST058 = 97.0210592474644 + CONST061 = 103.489129863962 + CONST062 = -407.026181590325 + CONST063 = 108.231522672464 + CONST065 = 110.066532613517 + CONST066 = 110.066532613517 + CONST067 = 620.934779183772 + CONST068 = -396.284809689477 + CONST070 = 132.094936563159 + CONST071 = 434.108258927137 + CONST073 = 649.389136034781 + CONST076 = -366.888442045058 + CONST077 = -366.888442045058 + CONST078 = -361.756882439281 + CONST080 = -6.78376969317208 + CONST082 = -350.312360615494 + CONST083 
= -346.340872551883 + CONST084 = -346.340872551883 + CONST085 = 173.170436275942 + CONST086 = 173.170436275942 + CONST088 = 183.444221022529 + CONST089 = 183.444221022529 + CONST090 = -325.620945272260 + CONST091 = -13.5289403340579 + CONST092 = -13.5675393863442 + CONST093 = 194.042118494929 + CONST095 = 197.050702846215 + CONST096 = -11.3224231339851 + CONST097 = 203.513090795162 + CONST098 = -814.052363180650 + CONST102 = -814.052363180650 + CONST104 = 217.054129463568 + CONST105 = 216.463045344927 + CONST106 = 220.133065227035 + CONST107 = -291.063177742393 + CONST108 = 220.133065227035 + CONST109 = -792.569619378954 + CONST111 = -271.350787726883 + CONST112 = 244.592294696705 + CONST113 = 244.592294696706 + CONST114 = 244.592294696706 + CONST115 = -776.168473979715 + CONST116 = -262.734270461621 + CONST117 = -259.755654413913 + CONST118 = -258.722824659905 + CONST120 = 262.734270461621 + CONST121 = -244.215708954195 + CONST122 = 271.350787726883 + CONST124 = -236.460843415458 + CONST127 = -217.054129463568 + CONST128 = -216.463045344927 + CONST129 = -216.463045344927 + CONST130 = -216.463045344927 + CONST131 = -723.513764878561 + CONST133 = -210.187416369296 + CONST134 = -210.187416369296 + CONST135 = 814.052363180650 + CONST136 = -197.050702846215 + CONST137 = 317.027847751582 + CONST138 = -194.042118494929 + CONST139 = -13.1367135230810 + CONST140 = 324.694568017391 + CONST142 = 324.694568017391 + CONST143 = -175.156180307747 + CONST146 = -162.810472636130 + CONST147 = -162.347284008695 + CONST148 = 865.852181379709 + CONST149 = -158.513923875791 + CONST151 = -144.702752975712 + CONST152 = -649.389136034782 + CONST153 = -129.877827206956 + CONST154 = -129.361412329953 + CONST155 = 388.084236989858 + CONST157 = -115.446957517294 + CONST158 = -108.231522672464 + CONST159 = -108.231522672464 + CONST160 = 407.026181590325 + CONST161 = -103.489129863962 + CONST162 = -97.0210592474644 + CONST163 = -94.7025823384056 + CONST165 = -91.9569946615672 + CONST167 = 
-87.5780901538735 + CONST168 = -85.6073031438469 + CONST169 = -85.6073031438469 + CONST170 = -81.1736420043477 + CONST171 = 432.926090689854 + CONST172 = -79.2569619378954 + CONST173 = -81.1736420043477 + CONST177 = -79.2569619378954 + CONST178 = -72.3513764878561 + CONST179 = -72.1543484483091 + CONST180 = -70.0624721230988 + CONST181 = -72.1543484483091 + CONST182 = -67.8376969317208 + CONST183 = -65.6835676154052 + CONST184 = -61.1480736741764 + CONST185 = -1085.27064731784 + CONST186 = -61.1480736741764 + CONST187 = -1085.40315090753 + CONST188 = -57.7234787586472 + CONST189 = -12.9361412329953 + CONST190 = -1085.27064731784 + CONST191 = -52.8379746252636 + CONST192 = -51.7445649319810 + CONST193 = -1585.13923875791 + CONST194 = -48.5105296237322 + CONST195 = -47.4863878522046 + CONST197 = 978.369178786822 + CONST198 = -517.445649319810 + CONST199 = -40.7026181590325 + CONST200 = -40.5868210021738 + CONST201 = -39.4101405692431 + CONST202 = -40.7026181590325 + CONST203 = -36.0771742241545 + CONST204 = -1056.75949250527 + CONST205 = -29.1063177742393 + CONST206 = 485.105296237322 + CONST207 = -26.2734270461621 + CONST208 = -26.4189873126318 + CONST209 = -1050.93708184648 + CONST210 = -22.6382471577417 + CONST211 = -20.6718218536732 + CONST212 = -19.4042118494929 + CONST213 = -20.3513090795162 + CONST214 = -528.379746252636 + CONST215 = -15.0965641786467 + CONST216 = -13.5675393863442 + CONST217 = -525.468540923241 + CONST218 = -11.3224231339851 + CONST219 = -13.5289403340579 + CONST220 = -9.70210592474644 + CONST221 = -10.3359109268366 + CONST222 = -6.46807061649763 + CONST223 = -13.1367135230810 + CONST224 = -12.2296147348353 + CONST225 = -3.23403530824881 + CONST226 = -1034.89129863962 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR24 = z * z 
* z * z + VAR25 = z * z * z + VAR26 = z * z + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + # -------------------- kernel implementations + g_x = tl.load( + coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel + ) + g_y = tl.load( + coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + g_z = tl.load( + coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + g_x += ( + g_0 + * ( + CONST049 * VAR08 * VAR23 + - CONST131 * VAR06 * VAR25 + + CONST151 * VAR04 * z + - CONST211 * VAR21 + ) + + g_1 + * y + * ( + CONST178 * VAR04 + - CONST178 * VAR22 + + CONST185 * VAR08 * VAR24 + - CONST190 * VAR06 * VAR26 + ) + + g_10 + * ( + CONST017 * VAR05 * VAR26 + + CONST161 * VAR13 * x + - CONST189 * VAR03 + - CONST198 * VAR07 * VAR15 + + CONST222 * VAR22 * x + + VAR17 + * (CONST058 * VAR24 * x + CONST107 * VAR05 + CONST138 * VAR07 * VAR26) + ) + + g_11 + * ( + CONST056 * VAR14 * x * z + + VAR16 * (-CONST082 * VAR25 * x - CONST209 * VAR07 * z) + + y + * (CONST116 * VAR07 * VAR25 + CONST124 * VAR05 * z + CONST207 * VAR23 * x) + ) + + g_12 + * ( + CONST011 * VAR03 + + CONST182 * VAR07 * VAR24 + + CONST199 * VAR05 * VAR26 + + CONST216 * VAR22 * x + + VAR15 * (CONST098 * VAR26 * x + CONST122 * VAR07) + + VAR17 + * (-CONST102 * VAR07 * VAR26 + CONST121 * VAR05 + CONST160 * VAR24 * x) + ) + + g_13 + * ( + VAR16 * (-CONST030 * VAR07 * z + CONST030 * VAR25 * x) + + y + * (CONST076 * VAR05 * z + CONST106 * VAR23 * x + CONST112 * VAR07 * VAR25) + ) + + g_14 + * ( + CONST012 * VAR03 + + CONST149 * VAR05 * VAR26 + - CONST191 * VAR22 * x + + VAR17 + * (CONST109 * VAR24 * x + CONST149 * VAR05 - CONST193 * VAR07 * VAR26) + ) + + g_15 + * y + * (CONST050 * VAR05 * z + CONST050 * VAR23 * x - CONST054 * VAR07 * VAR25) + + g_16 + * ( + CONST050 * VAR05 * VAR26 + - CONST131 * VAR07 * VAR24 + + CONST151 * VAR22 * x + - CONST211 * VAR03 + ) + + g_2 + * ( + CONST001 * VAR08 * (-CONST208 * VAR23 + 
CONST214 * VAR17 * VAR25) + + CONST004 * VAR06 * (-CONST149 * VAR17 * z - CONST208 * VAR25) + - CONST149 * VAR17 * VAR23 + + CONST172 * VAR04 * z + + CONST218 * VAR21 + ) + + g_3 + * ( + VAR16 * (CONST043 * VAR08 * VAR26 + CONST113 * VAR06 + CONST114 * VAR24) + + y + * ( + CONST028 * VAR06 * VAR26 + + CONST088 * VAR08 * VAR24 + + CONST168 * VAR04 + + CONST184 * VAR22 + ) + ) + + g_4 + * ( + CONST001 * VAR08 * (CONST005 * VAR23 + CONST111 * VAR15 * z) + + CONST004 * VAR06 * (CONST080 * VAR25 - CONST146 * VAR17 * z) + + CONST005 * VAR21 + - CONST111 * VAR15 * VAR25 + + CONST146 * VAR17 * VAR23 + + CONST195 * VAR04 * z + ) + + g_5 + * ( + VAR14 * (CONST133 * VAR08 - CONST134 * VAR26) + + VAR16 * (-CONST048 * VAR06 + CONST116 * VAR24 + CONST217 * VAR08 * VAR26) + + y + * ( + CONST041 * VAR06 * VAR26 + + CONST095 * VAR08 * VAR24 + + CONST165 * VAR04 + - CONST201 * VAR22 + ) + ) + + g_6 + * ( + CONST001 + * VAR08 + * (CONST093 * VAR17 * VAR25 + CONST118 * VAR15 * z + CONST220 * VAR23) + + CONST004 * VAR06 * (-CONST162 * VAR17 * z + CONST220 * VAR25) + + CONST118 * VAR15 * VAR25 + - CONST161 * VAR13 * z + - CONST162 * VAR17 * VAR23 + + CONST210 * VAR04 * z + + CONST225 * VAR21 + ) + + g_7 + * ( + CONST001 + * VAR08 + * (-CONST128 * VAR16 * VAR26 + CONST153 * VAR14 + CONST200 * VAR24 * y) + + CONST004 * VAR06 * (CONST063 * VAR16 + CONST200 * VAR26 * y) + + CONST020 * VAR12 + + CONST153 * VAR14 * VAR26 + - CONST158 * VAR16 * VAR24 + + CONST163 * VAR04 * y + + CONST219 * VAR22 * y + ) + + g_8 + * ( + CONST000 + * x + * ( + CONST002 * VAR22 + - CONST128 * VAR15 * VAR26 + + CONST158 * VAR17 * VAR24 + + CONST188 * VAR13 + ) + + CONST006 + * VAR07 + * (CONST008 * VAR24 - CONST158 * VAR15 + CONST159 * VAR17 * VAR26) + + CONST007 * VAR03 + + CONST009 * VAR05 * (CONST002 * VAR26 + CONST203 * VAR17) + ) + + g_9 + * ( + CONST173 * VAR23 * x * y + + VAR25 * (CONST147 * VAR07 * y + CONST171 * VAR16 * x) + + z + * (CONST117 * VAR14 * x + CONST170 * VAR05 * y + CONST171 * VAR07 * VAR16) 
+ ) + ) + g_y += ( + CONST000 + * g_14 + * y + * ( + -CONST068 * VAR06 * VAR26 + + CONST068 * VAR08 * VAR24 + + CONST208 * VAR04 + - CONST208 * VAR22 + ) + + g_1 + * ( + CONST078 * VAR07 * VAR24 + + CONST104 * VAR05 * VAR26 + - CONST178 * VAR22 * x + + CONST221 * VAR03 + ) + + g_10 + * ( + CONST000 + * y + * ( + CONST031 * VAR08 * VAR24 + + CONST031 * VAR22 + + CONST194 * VAR04 + + CONST194 * VAR06 * VAR26 + ) + + CONST006 * VAR16 * (-CONST154 * VAR06 + CONST154 * VAR24) + + CONST009 * VAR14 * (CONST033 * VAR26 + CONST192 * VAR08) + ) + + g_11 + * ( + CONST001 + * VAR17 + * (-CONST116 * VAR06 * z - CONST143 * VAR08 * VAR25 + CONST167 * VAR23) + + CONST004 * VAR15 * (CONST134 * VAR08 * z - CONST180 * VAR25) + + CONST013 * VAR21 + + CONST183 * VAR06 * VAR25 + + CONST201 * VAR04 * z + + CONST223 * VAR08 * VAR23 + ) + + g_12 + * ( + CONST000 + * y + * ( + CONST097 * VAR06 * VAR26 + + CONST097 * VAR08 * VAR24 + + CONST199 * VAR04 + + CONST199 * VAR22 + ) + + CONST006 + * VAR16 + * (CONST062 * VAR08 * VAR26 - CONST182 * VAR06 - CONST182 * VAR24) + ) + + g_13 + * ( + CONST001 + * VAR17 + * (CONST019 * VAR08 * VAR25 + CONST035 * VAR23 + CONST113 * VAR06 * z) + + CONST065 * VAR08 * VAR23 + - CONST184 * VAR06 * VAR25 + + CONST186 * VAR04 * z + + CONST224 * VAR21 + ) + + g_15 + * ( + -CONST078 * VAR06 * VAR25 + + CONST127 * VAR08 * VAR23 + + CONST178 * VAR04 * z + - CONST221 * VAR21 + ) + + g_2 + * ( + CONST137 * VAR05 * y * z + + CONST137 * VAR23 * x * y + + CONST204 * VAR07 * VAR25 * y + ) + + g_3 + * ( + CONST001 + * VAR17 + * (CONST019 * VAR07 * VAR26 + CONST035 * VAR05 + CONST114 * VAR24 * x) + + CONST045 * VAR03 + + CONST066 * VAR05 * VAR26 + + CONST184 * VAR22 * x + - CONST186 * VAR07 * VAR24 + ) + + g_4 + * ( + -CONST090 * VAR05 * y * z + + CONST187 * VAR07 * VAR16 * z + + x * (CONST090 * VAR23 * y - CONST187 * VAR16 * VAR25) + ) + + g_5 + * ( + CONST001 + * VAR17 + * (CONST116 * VAR24 * x + CONST143 * VAR07 * VAR26 - CONST167 * VAR05) + + CONST004 * VAR15 * 
(-CONST134 * VAR26 * x + CONST180 * VAR07) + + CONST015 * VAR05 * VAR26 + + CONST041 * VAR07 * VAR24 + + CONST139 * VAR03 + - CONST201 * VAR22 * x + ) + + g_6 + * ( + -CONST138 * VAR05 * y * z + + VAR07 * (CONST155 * VAR25 * y + CONST226 * VAR16 * z) + + x + * (CONST067 * VAR14 * z - CONST138 * VAR23 * y + CONST226 * VAR16 * VAR25) + ) + + g_7 + * ( + CONST219 * VAR03 + + VAR05 * (CONST142 * VAR17 + CONST200 * VAR26) + + VAR07 * (CONST152 * VAR15 - CONST152 * VAR17 * VAR26 + CONST200 * VAR24) + + x + * ( + CONST085 * VAR13 + + CONST140 * VAR17 * VAR24 + + CONST152 * VAR15 * VAR26 + + CONST219 * VAR22 + ) + ) + + g_8 + * ( + CONST026 * VAR12 + - CONST052 * VAR16 * VAR24 + + CONST084 * VAR14 * VAR26 + + CONST179 * VAR04 * y + + CONST181 * VAR22 * y + + VAR06 * (-CONST052 * VAR16 + CONST129 * VAR26 * y) + + VAR08 + * (CONST083 * VAR14 + CONST128 * VAR24 * y + CONST148 * VAR16 * VAR26) + ) + + g_9 + * ( + CONST219 * VAR21 + + VAR23 * (CONST142 * VAR17 + CONST200 * VAR08) + + VAR25 * (CONST073 * VAR08 * VAR17 + CONST152 * VAR15 + CONST200 * VAR06) + + z + * ( + CONST086 * VAR13 + + CONST091 * VAR04 + + CONST142 * VAR06 * VAR17 + + CONST152 * VAR08 * VAR15 + ) + ) + ) + g_z += ( + g_0 + * ( + -CONST049 * VAR05 * VAR26 + + CONST131 * VAR07 * VAR24 + - CONST151 * VAR22 * x + + CONST211 * VAR03 + ) + + g_1 + * y + * (-CONST050 * VAR23 * x + CONST054 * VAR07 * VAR25 + CONST071 * VAR05 * z) + + g_10 + * ( + CONST057 * VAR04 * z + + CONST061 * VAR13 * z + + CONST189 * VAR21 + + CONST198 * VAR15 * VAR25 + + CONST212 * VAR08 * VAR23 + + VAR17 + * (CONST093 * VAR08 * VAR25 - CONST107 * VAR23 + CONST162 * VAR06 * z) + ) + + g_11 + * ( + VAR14 * (-CONST133 * VAR26 + CONST134 * VAR08) + + VAR16 * (CONST048 * VAR24 - CONST116 * VAR06 - CONST217 * VAR08 * VAR26) + + y + * ( + CONST055 * VAR22 + + CONST136 * VAR06 * VAR26 + + CONST183 * VAR08 * VAR24 + + CONST201 * VAR04 + ) + ) + + g_12 + * ( + CONST011 * VAR21 + + CONST092 * VAR04 * z + + CONST182 * VAR06 * VAR25 + + CONST202 * VAR08 
* VAR23 + + VAR15 * (CONST098 * VAR08 * z + CONST122 * VAR25) + + VAR17 + * (-CONST102 * VAR08 * VAR25 + CONST121 * VAR23 + CONST160 * VAR06 * z) + ) + + g_13 + * ( + VAR16 * (CONST043 * VAR08 * VAR26 + CONST113 * VAR06 + CONST113 * VAR24) + + y + * ( + CONST028 * VAR08 * VAR24 + + CONST089 * VAR06 * VAR26 + + CONST169 * VAR22 + + CONST186 * VAR04 + ) + ) + + g_14 + * ( + -CONST149 * VAR08 * VAR23 + + CONST191 * VAR04 * z + + CONST215 * VAR21 + + VAR17 + * (-CONST109 * VAR06 * z - CONST149 * VAR23 + CONST193 * VAR08 * VAR25) + ) + + g_15 + * y + * ( + CONST178 * VAR04 + - CONST178 * VAR22 + - CONST185 * VAR06 * VAR26 + + CONST190 * VAR08 * VAR24 + ) + + g_16 + * ( + CONST050 * VAR08 * VAR23 + - CONST131 * VAR06 * VAR25 + + CONST151 * VAR04 * z + - CONST211 * VAR21 + ) + + g_2 + * ( + CONST096 * VAR03 + + VAR05 * (-CONST149 * VAR17 - CONST177 * VAR26) + + VAR07 * (CONST070 * VAR24 + CONST193 * VAR17 * VAR26) + + x * (-CONST109 * VAR17 * VAR24 + CONST177 * VAR22) + ) + + g_3 + * ( + VAR16 * (CONST030 * VAR07 * z + CONST197 * VAR25 * x) + + y + * (CONST077 * VAR23 * x + CONST108 * VAR05 * z + CONST114 * VAR07 * VAR25) + ) + + g_4 + * ( + CONST080 * VAR03 + + VAR05 * (-CONST146 * VAR17 + CONST213 * VAR26) + + VAR07 * (CONST027 * VAR24 + CONST111 * VAR15) + + x + * (CONST102 * VAR17 * VAR24 + CONST135 * VAR15 * VAR26 - CONST195 * VAR22) + ) + + g_5 + * ( + -CONST056 * VAR14 * x * z + + VAR16 * (CONST082 * VAR07 * z + CONST209 * VAR25 * x) + + y + * (CONST023 * VAR05 * z + CONST120 * VAR07 * VAR25 - CONST124 * VAR23 * x) + ) + + g_6 + * ( + CONST225 * VAR03 + + VAR05 * (-CONST162 * VAR17 + CONST205 * VAR26) + + VAR07 * (CONST047 * VAR17 * VAR26 + CONST118 * VAR15 + CONST194 * VAR24) + + x + * ( + CONST115 * VAR15 * VAR26 + - CONST161 * VAR13 + + CONST206 * VAR17 * VAR24 + + CONST210 * VAR22 + ) + ) + + g_7 + * ( + CONST173 * VAR05 * y * z + + VAR07 * (-CONST052 * VAR16 * z + CONST147 * VAR25 * y) + + x + * (-CONST052 * VAR16 * VAR25 + CONST117 * VAR14 * z + CONST173 * 
VAR23 * y) + ) + + g_8 + * ( + CONST007 * VAR04 * z + + CONST007 * VAR21 + - CONST052 * VAR15 * VAR25 + + CONST130 * VAR17 * VAR23 + + CONST157 * VAR13 * z + + VAR06 * (CONST024 * VAR25 + CONST129 * VAR17 * z) + + VAR08 + * (CONST024 * VAR23 - CONST052 * VAR15 * z + CONST052 * VAR17 * VAR25) + ) + + g_9 + * ( + CONST001 + * VAR26 + * (CONST105 * VAR08 * VAR16 + CONST153 * VAR14 + CONST200 * VAR06 * y) + + CONST004 * VAR24 * (CONST063 * VAR16 + CONST200 * VAR08 * y) + + CONST025 * VAR12 + + CONST063 * VAR06 * VAR16 + + CONST091 * VAR04 * y + + CONST153 * VAR08 * VAR14 + + CONST163 * VAR22 * y + ) + ) + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + ) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/sph_harm/direct/y_9.py b/src/equitriton/sph_harm/direct/y_9.py new file mode 100644 index 0000000..a188f21 --- /dev/null +++ b/src/equitriton/sph_harm/direct/y_9.py @@ -0,0 +1,2088 @@ +import triton +import torch +from triton import language as tl + +from equitriton.utils import calculate_lastdim_num_blocks + +__all__ = ["NinthOrderSphericalHarmonic"] + + +class NinthOrderSphericalHarmonic(torch.autograd.Function): + @staticmethod + def forward( + ctx, + coords: torch.Tensor, + output_tensor: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + block_size: int = 64, + col_offset: int = 0, + ): + # allocate a tensor if one isn't given + if not isinstance(output_tensor, torch.Tensor): + output_tensor = torch.empty( + (*coords.shape[:-1], 19), dtype=coords.dtype, device=coords.device + ) + coord_numel = coords.numel() + output_numel = output_tensor.numel() + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # apply the kernel + ninth_order_fwd[num_blocks,]( + coords, + output_tensor, + 
block_size, + coord_numel, + output_numel, + col_offset, + output_tensor.stride(-2), + ) + ctx.save_for_backward(coords) + return output_tensor + + @staticmethod + def backward( + ctx, + sph_grad_tensor: torch.Tensor, + block_size: int = 64, + col_offset: int = 0, + ) -> torch.Tensor: + (coords,) = ctx.saved_tensors + coord_grad_output = torch.zeros_like(coords) + num_blocks = calculate_lastdim_num_blocks(coords, block_size) + # call backward kernel + ninth_order_bwd[num_blocks,]( + coords, + coord_grad_output, + sph_grad_tensor, + block_size, + coords.numel(), + sph_grad_tensor.numel(), + col_offset, + sph_grad_tensor.stride(-2), + ) + return coord_grad_output + + +def _torch_fwd(coords: torch.Tensor) -> torch.Tensor: + """ + PyTorch implementation of the kernel. This is designed + purely for unit testing to ensure that the Triton implementation + is behaving as intended. + + Parameters + ---------- + coords : torch.Tensor + N-d tensor, where the last dimension corresponds to + xyz values. + + Returns + ------- + torch.Tensor + N-d tensor, where the last dimension corresponds to + each projection of the ninth order spherical harmonic. 
+ """ + x = coords[..., 0].contiguous().unsqueeze(-1) + y = coords[..., 1].contiguous().unsqueeze(-1) + z = coords[..., 2].contiguous().unsqueeze(-1) + # -------------------- variable and constant definitions + CONST000 = 1.93163963757558 + CONST001 = 2.65478475211798 + CONST002 = 1.72771101506082 + CONST004 = 1.59908344719522 + CONST005 = 6.39633378878088 + CONST006 = 6.39633378878088 + CONST007 = 8.63855507530412 + CONST008 = 9.59450068317133 + CONST009 = 4.35889894354067 + CONST010 = 10.7269778688696 + CONST011 = 10.7269778688696 + CONST012 = 6.39633378878088 + CONST013 = 15.0007324039945 + CONST014 = 13.0937127087774 + CONST016 = 14.4550674370400 + CONST017 = 14.4550674370400 + CONST018 = 13.3827919767794 + CONST019 = 13.5214774630291 + CONST020 = 23.8930627690618 + CONST021 = 27.0429549260581 + CONST022 = 29.2403830344269 + CONST023 = 29.2403830344269 + CONST024 = 30.0014648079890 + CONST025 = -480.023436927823 + CONST026 = -480.023436927823 + CONST029 = 42.9079114754785 + CONST030 = -462.562157985281 + CONST032 = -967.518168434061 + CONST034 = 57.8202697481601 + CONST035 = 58.9217071894985 + CONST036 = 58.9217071894985 + CONST037 = 62.4530292249704 + CONST038 = 1081.71819704233 + CONST039 = 64.3618672132178 + CONST040 = 578.202697481601 + CONST044 = 600.029296159779 + CONST045 = -936.795438374555 + CONST047 = 96.7518168434061 + CONST049 = 115.640539496320 + CONST051 = -392.811381263323 + CONST053 = 137.149553407950 + CONST055 = 150.007324039945 + CONST056 = -343.263291803828 + CONST058 = 11.2632978048796 + CONST061 = -315.372338536630 + CONST062 = -314.249105010659 + CONST063 = 205.957975082297 + CONST065 = -294.608535947493 + CONST066 = 240.011718463912 + CONST068 = 241.879542108515 + CONST069 = 255.853351551235 + CONST070 = 255.853351551235 + CONST071 = -241.879542108515 + CONST072 = -240.011718463912 + CONST073 = -241.879542108515 + CONST074 = 788.430846341574 + CONST075 = 1.72771101506082 + CONST076 = -1.93163963757558 + CONST077 = -1249.06058449941 + 
CONST078 = -223.001919177910 + CONST080 = -216.343639408465 + CONST081 = 300.014648079890 + CONST082 = -204.682681240988 + CONST083 = -204.682681240988 + CONST084 = -204.682681240988 + CONST086 = -196.405690631662 + CONST087 = -191.890013663426 + CONST088 = -191.890013663427 + CONST089 = -187.359087674911 + CONST090 = -693.843236977922 + CONST091 = 334.502878766866 + CONST092 = -176.765121568496 + CONST093 = -150.007324039945 + CONST094 = -144.550674370400 + CONST095 = 374.718175349822 + CONST096 = 374.718175349822 + CONST097 = -649.030918225395 + CONST099 = -630.744677073259 + CONST100 = -115.640539496320 + CONST101 = -114.421097267943 + CONST102 = -115.640539496320 + CONST103 = -104.749701670220 + CONST104 = 411.915950164594 + CONST105 = -95.5722510762473 + CONST106 = -90.1063824390370 + CONST107 = -90.0043944239669 + CONST109 = -80.2967518606762 + CONST110 = -78.4601809837321 + CONST111 = 435.383175795327 + CONST112 = -589.217071894985 + CONST113 = -78.4601809837321 + CONST114 = 435.383175795328 + CONST115 = -68.5747767039748 + CONST116 = -63.9633378878088 + CONST117 = -63.9633378878088 + CONST118 = -62.4530292249704 + CONST119 = -58.9217071894985 + CONST120 = -1081.71819704233 + CONST121 = -57.8202697481601 + CONST122 = -57.8202697481601 + CONST123 = -58.9217071894985 + CONST124 = -54.0859098521163 + CONST125 = 462.562157985281 + CONST127 = -48.3759084217031 + CONST128 = -48.3759084217030 + CONST129 = -38.6327927515116 + CONST130 = -30.9062342012093 + CONST131 = 483.759084217031 + CONST132 = -30.0014648079890 + CONST133 = -30.0014648079890 + CONST134 = -27.0429549260581 + CONST135 = -24.1879542108515 + CONST136 = -24.1879542108515 + CONST137 = -1.63671408859718 + CONST138 = -15.0007324039945 + CONST139 = -13.5214774630291 + CONST140 = -13.8216881204866 + CONST141 = -13.0937127087774 + CONST142 = -13.3827919767794 + CONST143 = -9.82028453158308 + CONST144 = -4.91014226579154 + CONST145 = 511.706703102471 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * 
x + VAR01 = VAR07 * VAR07 * VAR07 + VAR02 = VAR06 * VAR06 + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR10 = VAR16 * VAR16 * VAR16 + VAR11 = VAR15 * VAR15 + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + VAR19 = VAR25 * VAR25 * VAR25 + VAR20 = VAR24 * VAR24 + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + # -------------------- kernel implementations + Y00 = ( + CONST001 * VAR01 + + CONST020 * VAR20 * x + + CONST078 * VAR07 * VAR22 + + CONST091 * VAR05 * VAR24 + + CONST105 * VAR03 * VAR26 + ) + Y01 = y * ( + -CONST099 * VAR05 * VAR25 + + CONST099 * VAR07 * VAR23 + + CONST106 * VAR03 * z + - CONST106 * VAR21 * x + ) + Y02 = ( + CONST000 * VAR01 + + VAR03 * (CONST129 * VAR26 + CONST130 * VAR17) + + VAR05 * (CONST021 * VAR24 - CONST097 * VAR17 * VAR26) + + VAR07 * (CONST120 * VAR17 * VAR24 - CONST124 * VAR22) + + x * (-CONST080 * VAR17 * VAR22 + CONST139 * VAR20) + ) + Y03 = VAR16 * ( + CONST077 * VAR07 * VAR25 + CONST095 * VAR05 * z + CONST096 * VAR23 * x + ) + y * ( + -CONST089 * VAR05 * VAR25 + - CONST089 * VAR07 * VAR23 + + CONST109 * VAR03 * z + + CONST109 * VAR21 * x + ) + Y04 = ( + CONST002 * VAR01 + + CONST007 * VAR20 * x + + CONST135 * VAR05 * VAR24 + + CONST140 * VAR03 * VAR26 + + VAR15 * (CONST032 * VAR07 * VAR26 + CONST047 * VAR05 + CONST131 * VAR24 * x) + + VAR17 + * ( + -CONST071 * VAR07 * VAR24 + + CONST071 * VAR22 * x + + CONST111 * VAR05 * VAR26 + + CONST127 * VAR03 + ) + ) + Y05 = ( + VAR14 * (CONST030 * VAR07 * z - CONST030 * VAR25 * x) + + VAR16 * (CONST030 * VAR23 * x + CONST125 * VAR05 * z) + + y + * ( + CONST034 * VAR07 * VAR23 + + CONST121 * VAR05 * VAR25 + - CONST121 * VAR21 * x + + CONST122 * VAR03 * z + ) + ) + Y06 = ( + CONST119 * VAR03 * VAR17 + - CONST137 * VAR01 + + VAR05 * (CONST035 * VAR17 * VAR26 - CONST086 * VAR15 + CONST143 * VAR24) + 
+ VAR07 + * ( + CONST051 * VAR15 * VAR26 + - CONST065 * VAR17 * VAR24 + + CONST103 * VAR13 + + CONST141 * VAR22 + ) + + x + * ( + -CONST062 * VAR13 * VAR26 + - CONST092 * VAR17 * VAR22 + + CONST112 * VAR15 * VAR24 + + CONST144 * VAR20 + ) + ) + Y07 = ( + CONST132 * VAR03 * y * z + + VAR05 * (CONST081 * VAR16 * z + CONST107 * VAR25 * y) + + VAR07 + * (CONST026 * VAR14 * z + CONST044 * VAR16 * VAR25 + CONST107 * VAR23 * y) + + x + * ( + CONST025 * VAR14 * VAR25 + + CONST053 * VAR12 * z + + CONST081 * VAR16 * VAR23 + + CONST132 * VAR21 * y + ) + ) + Y08 = ( + CONST004 * VAR01 + + VAR03 * (CONST006 * VAR26 + CONST116 * VAR17) + + VAR05 * (CONST008 * VAR24 + CONST069 * VAR15 + CONST087 * VAR17 * VAR26) + + VAR07 + * ( + CONST005 * VAR22 + + CONST083 * VAR13 + + CONST087 * VAR17 * VAR24 + + CONST145 * VAR15 * VAR26 + ) + + x + * ( + CONST004 * VAR20 + + CONST022 * VAR11 + + CONST069 * VAR15 * VAR24 + + CONST082 * VAR13 * VAR26 + + CONST116 * VAR17 * VAR22 + ) + ) + Y09 = ( + CONST009 * VAR10 + + VAR12 * (CONST110 * VAR26 + CONST113 * VAR08) + + VAR14 * (CONST063 * VAR06 + CONST063 * VAR24 + CONST104 * VAR08 * VAR26) + + VAR16 + * ( + CONST056 * VAR06 * VAR26 + + CONST056 * VAR08 * VAR24 + + CONST101 * VAR04 + + CONST101 * VAR22 + ) + + y + * ( + CONST010 * VAR20 + + CONST011 * VAR02 + + CONST029 * VAR04 * VAR26 + + CONST029 * VAR08 * VAR22 + + CONST039 * VAR06 * VAR24 + ) + ) + Y10 = ( + CONST004 * VAR19 + + VAR21 * (CONST005 * VAR08 + CONST117 * VAR17) + + VAR23 * (CONST008 * VAR06 + CONST070 * VAR15 + CONST088 * VAR08 * VAR17) + + VAR25 + * ( + CONST012 * VAR04 + + CONST082 * VAR13 + + CONST087 * VAR06 * VAR17 + + CONST145 * VAR08 * VAR15 + ) + + z + * ( + CONST004 * VAR02 + + CONST023 * VAR11 + + CONST070 * VAR06 * VAR15 + + CONST084 * VAR08 * VAR13 + + CONST117 * VAR04 * VAR17 + ) + ) + Y11 = ( + VAR12 * (CONST115 * VAR08 - CONST115 * VAR26) + + VAR14 * (CONST066 * VAR06 + CONST072 * VAR24) + + VAR16 + * ( + CONST055 * VAR08 * VAR24 + + CONST093 * VAR04 + + CONST093 
* VAR06 * VAR26 + - CONST093 * VAR22 + ) + + y + * ( + CONST013 * VAR02 + + CONST024 * VAR04 * VAR26 + + CONST133 * VAR08 * VAR22 + + CONST138 * VAR20 + ) + ) + Y12 = ( + CONST036 * VAR17 * VAR21 + + CONST137 * VAR19 + + VAR23 * (CONST086 * VAR15 + CONST123 * VAR08 * VAR17 - CONST143 * VAR06) + + VAR25 + * ( + CONST014 * VAR04 + - CONST051 * VAR08 * VAR15 + + CONST065 * VAR06 * VAR17 + - CONST103 * VAR13 + ) + + z + * ( + CONST062 * VAR08 * VAR13 + + CONST092 * VAR04 * VAR17 + - CONST112 * VAR06 * VAR15 + - CONST144 * VAR02 + ) + ) + Y13 = ( + VAR14 * (CONST049 * VAR06 + CONST049 * VAR24 + CONST090 * VAR08 * VAR26) + + VAR16 + * ( + CONST040 * VAR06 * VAR26 + + CONST040 * VAR08 * VAR24 + + CONST100 * VAR22 + + CONST102 * VAR04 + ) + + y + * ( + CONST016 * VAR20 + + CONST017 * VAR02 + + CONST094 * VAR06 * VAR24 + + CONST121 * VAR04 * VAR26 + + CONST122 * VAR08 * VAR22 + ) + ) + Y14 = ( + CONST007 * VAR02 * z + + CONST075 * VAR19 + + CONST136 * VAR06 * VAR23 + + CONST140 * VAR08 * VAR21 + + VAR15 * (CONST032 * VAR08 * VAR25 + CONST047 * VAR23 + CONST131 * VAR06 * z) + + VAR17 + * ( + CONST068 * VAR06 * VAR25 + + CONST073 * VAR04 * z + + CONST114 * VAR08 * VAR23 + + CONST128 * VAR21 + ) + ) + Y15 = VAR16 * ( + CONST037 * VAR22 + - CONST045 * VAR06 * VAR26 + + CONST045 * VAR08 * VAR24 + + CONST118 * VAR04 + ) + y * ( + CONST018 * VAR02 + + CONST089 * VAR04 * VAR26 + - CONST089 * VAR08 * VAR22 + + CONST142 * VAR20 + ) + Y16 = ( + CONST019 * VAR02 * z + + CONST076 * VAR19 + + CONST124 * VAR04 * VAR25 + - CONST129 * VAR08 * VAR21 + + CONST134 * VAR06 * VAR23 + + VAR17 + * ( + CONST038 * VAR06 * VAR25 + + CONST080 * VAR04 * z + + CONST097 * VAR08 * VAR23 + - CONST130 * VAR21 + ) + ) + Y17 = y * ( + CONST058 * VAR02 + + CONST058 * VAR20 + + CONST061 * VAR04 * VAR26 + + CONST061 * VAR08 * VAR22 + + CONST074 * VAR06 * VAR24 + ) + Y18 = ( + CONST001 * VAR19 + + CONST020 * VAR02 * z + + CONST078 * VAR04 * VAR25 + + CONST091 * VAR06 * VAR23 + + CONST105 * VAR08 * VAR21 + ) + # 
not the prettiest way to concatenate, but better than + # messing with the linter + tensors = [ + Y00, + Y01, + Y02, + Y03, + Y04, + Y05, + Y06, + Y07, + Y08, + Y09, + Y10, + Y11, + Y12, + Y13, + Y14, + Y15, + Y16, + Y17, + Y18, + ] + return torch.cat(tensors, dim=-1) + + +@triton.jit +def ninth_order_fwd( + coord_ptr: tl.tensor, + output_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # these are hardcoded because they are predetermined; + coord_stride = 3 + # work out the row offsets + block_id = tl.program_id(0) + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + # -------------------- variable and constant definitions + CONST000 = 1.93163963757558 + CONST001 = 2.65478475211798 + CONST002 = 1.72771101506082 + CONST004 = 1.59908344719522 + CONST005 = 6.39633378878088 + CONST006 = 6.39633378878088 + CONST007 = 8.63855507530412 + CONST008 = 9.59450068317133 + CONST009 = 4.35889894354067 + CONST010 = 10.7269778688696 + CONST011 = 10.7269778688696 + CONST012 = 6.39633378878088 + CONST013 = 15.0007324039945 + CONST014 = 13.0937127087774 + CONST016 = 14.4550674370400 + CONST017 = 14.4550674370400 + CONST018 = 13.3827919767794 + CONST019 = 13.5214774630291 + CONST020 = 23.8930627690618 + CONST021 = 27.0429549260581 + CONST022 = 29.2403830344269 + CONST023 = 29.2403830344269 + CONST024 = 30.0014648079890 + CONST025 = -480.023436927823 + CONST026 = -480.023436927823 + CONST029 = 42.9079114754785 + CONST030 = -462.562157985281 + CONST032 = 
-967.518168434061 + CONST034 = 57.8202697481601 + CONST035 = 58.9217071894985 + CONST036 = 58.9217071894985 + CONST037 = 62.4530292249704 + CONST038 = 1081.71819704233 + CONST039 = 64.3618672132178 + CONST040 = 578.202697481601 + CONST044 = 600.029296159779 + CONST045 = -936.795438374555 + CONST047 = 96.7518168434061 + CONST049 = 115.640539496320 + CONST051 = -392.811381263323 + CONST053 = 137.149553407950 + CONST055 = 150.007324039945 + CONST056 = -343.263291803828 + CONST058 = 11.2632978048796 + CONST061 = -315.372338536630 + CONST062 = -314.249105010659 + CONST063 = 205.957975082297 + CONST065 = -294.608535947493 + CONST066 = 240.011718463912 + CONST068 = 241.879542108515 + CONST069 = 255.853351551235 + CONST070 = 255.853351551235 + CONST071 = -241.879542108515 + CONST072 = -240.011718463912 + CONST073 = -241.879542108515 + CONST074 = 788.430846341574 + CONST075 = 1.72771101506082 + CONST076 = -1.93163963757558 + CONST077 = -1249.06058449941 + CONST078 = -223.001919177910 + CONST080 = -216.343639408465 + CONST081 = 300.014648079890 + CONST082 = -204.682681240988 + CONST083 = -204.682681240988 + CONST084 = -204.682681240988 + CONST086 = -196.405690631662 + CONST087 = -191.890013663426 + CONST088 = -191.890013663427 + CONST089 = -187.359087674911 + CONST090 = -693.843236977922 + CONST091 = 334.502878766866 + CONST092 = -176.765121568496 + CONST093 = -150.007324039945 + CONST094 = -144.550674370400 + CONST095 = 374.718175349822 + CONST096 = 374.718175349822 + CONST097 = -649.030918225395 + CONST099 = -630.744677073259 + CONST100 = -115.640539496320 + CONST101 = -114.421097267943 + CONST102 = -115.640539496320 + CONST103 = -104.749701670220 + CONST104 = 411.915950164594 + CONST105 = -95.5722510762473 + CONST106 = -90.1063824390370 + CONST107 = -90.0043944239669 + CONST109 = -80.2967518606762 + CONST110 = -78.4601809837321 + CONST111 = 435.383175795327 + CONST112 = -589.217071894985 + CONST113 = -78.4601809837321 + CONST114 = 435.383175795328 + CONST115 = 
-68.5747767039748 + CONST116 = -63.9633378878088 + CONST117 = -63.9633378878088 + CONST118 = -62.4530292249704 + CONST119 = -58.9217071894985 + CONST120 = -1081.71819704233 + CONST121 = -57.8202697481601 + CONST122 = -57.8202697481601 + CONST123 = -58.9217071894985 + CONST124 = -54.0859098521163 + CONST125 = 462.562157985281 + CONST127 = -48.3759084217031 + CONST128 = -48.3759084217030 + CONST129 = -38.6327927515116 + CONST130 = -30.9062342012093 + CONST131 = 483.759084217031 + CONST132 = -30.0014648079890 + CONST133 = -30.0014648079890 + CONST134 = -27.0429549260581 + CONST135 = -24.1879542108515 + CONST136 = -24.1879542108515 + CONST137 = -1.63671408859718 + CONST138 = -15.0007324039945 + CONST139 = -13.5214774630291 + CONST140 = -13.8216881204866 + CONST141 = -13.0937127087774 + CONST142 = -13.3827919767794 + CONST143 = -9.82028453158308 + CONST144 = -4.91014226579154 + CONST145 = 511.706703102471 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + VAR01 = VAR07 * VAR07 * VAR07 + VAR02 = VAR06 * VAR06 + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR10 = VAR16 * VAR16 * VAR16 + VAR11 = VAR15 * VAR15 + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + VAR19 = VAR25 * VAR25 * VAR25 + VAR20 = VAR24 * VAR24 + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + # -------------------- kernel implementations + Y00 = ( + CONST001 * VAR01 + + CONST020 * VAR20 * x + + CONST078 * VAR07 * VAR22 + + CONST091 * VAR05 * VAR24 + + CONST105 * VAR03 * VAR26 + ) + Y01 = y * ( + -CONST099 * VAR05 * VAR25 + + CONST099 * VAR07 * VAR23 + + CONST106 * VAR03 * z + - CONST106 * VAR21 * x + ) + Y02 = ( + CONST000 * VAR01 + + VAR03 * (CONST129 * VAR26 + CONST130 * VAR17) + + VAR05 * (CONST021 * VAR24 - CONST097 * VAR17 * VAR26) + + VAR07 * (CONST120 * VAR17 * VAR24 - CONST124 * VAR22) + + x * 
(-CONST080 * VAR17 * VAR22 + CONST139 * VAR20) + ) + Y03 = VAR16 * ( + CONST077 * VAR07 * VAR25 + CONST095 * VAR05 * z + CONST096 * VAR23 * x + ) + y * ( + -CONST089 * VAR05 * VAR25 + - CONST089 * VAR07 * VAR23 + + CONST109 * VAR03 * z + + CONST109 * VAR21 * x + ) + Y04 = ( + CONST002 * VAR01 + + CONST007 * VAR20 * x + + CONST135 * VAR05 * VAR24 + + CONST140 * VAR03 * VAR26 + + VAR15 * (CONST032 * VAR07 * VAR26 + CONST047 * VAR05 + CONST131 * VAR24 * x) + + VAR17 + * ( + -CONST071 * VAR07 * VAR24 + + CONST071 * VAR22 * x + + CONST111 * VAR05 * VAR26 + + CONST127 * VAR03 + ) + ) + Y05 = ( + VAR14 * (CONST030 * VAR07 * z - CONST030 * VAR25 * x) + + VAR16 * (CONST030 * VAR23 * x + CONST125 * VAR05 * z) + + y + * ( + CONST034 * VAR07 * VAR23 + + CONST121 * VAR05 * VAR25 + - CONST121 * VAR21 * x + + CONST122 * VAR03 * z + ) + ) + Y06 = ( + CONST119 * VAR03 * VAR17 + - CONST137 * VAR01 + + VAR05 * (CONST035 * VAR17 * VAR26 - CONST086 * VAR15 + CONST143 * VAR24) + + VAR07 + * ( + CONST051 * VAR15 * VAR26 + - CONST065 * VAR17 * VAR24 + + CONST103 * VAR13 + + CONST141 * VAR22 + ) + + x + * ( + -CONST062 * VAR13 * VAR26 + - CONST092 * VAR17 * VAR22 + + CONST112 * VAR15 * VAR24 + + CONST144 * VAR20 + ) + ) + Y07 = ( + CONST132 * VAR03 * y * z + + VAR05 * (CONST081 * VAR16 * z + CONST107 * VAR25 * y) + + VAR07 + * (CONST026 * VAR14 * z + CONST044 * VAR16 * VAR25 + CONST107 * VAR23 * y) + + x + * ( + CONST025 * VAR14 * VAR25 + + CONST053 * VAR12 * z + + CONST081 * VAR16 * VAR23 + + CONST132 * VAR21 * y + ) + ) + Y08 = ( + CONST004 * VAR01 + + VAR03 * (CONST006 * VAR26 + CONST116 * VAR17) + + VAR05 * (CONST008 * VAR24 + CONST069 * VAR15 + CONST087 * VAR17 * VAR26) + + VAR07 + * ( + CONST005 * VAR22 + + CONST083 * VAR13 + + CONST087 * VAR17 * VAR24 + + CONST145 * VAR15 * VAR26 + ) + + x + * ( + CONST004 * VAR20 + + CONST022 * VAR11 + + CONST069 * VAR15 * VAR24 + + CONST082 * VAR13 * VAR26 + + CONST116 * VAR17 * VAR22 + ) + ) + Y09 = ( + CONST009 * VAR10 + + VAR12 * (CONST110 * 
VAR26 + CONST113 * VAR08) + + VAR14 * (CONST063 * VAR06 + CONST063 * VAR24 + CONST104 * VAR08 * VAR26) + + VAR16 + * ( + CONST056 * VAR06 * VAR26 + + CONST056 * VAR08 * VAR24 + + CONST101 * VAR04 + + CONST101 * VAR22 + ) + + y + * ( + CONST010 * VAR20 + + CONST011 * VAR02 + + CONST029 * VAR04 * VAR26 + + CONST029 * VAR08 * VAR22 + + CONST039 * VAR06 * VAR24 + ) + ) + Y10 = ( + CONST004 * VAR19 + + VAR21 * (CONST005 * VAR08 + CONST117 * VAR17) + + VAR23 * (CONST008 * VAR06 + CONST070 * VAR15 + CONST088 * VAR08 * VAR17) + + VAR25 + * ( + CONST012 * VAR04 + + CONST082 * VAR13 + + CONST087 * VAR06 * VAR17 + + CONST145 * VAR08 * VAR15 + ) + + z + * ( + CONST004 * VAR02 + + CONST023 * VAR11 + + CONST070 * VAR06 * VAR15 + + CONST084 * VAR08 * VAR13 + + CONST117 * VAR04 * VAR17 + ) + ) + Y11 = ( + VAR12 * (CONST115 * VAR08 - CONST115 * VAR26) + + VAR14 * (CONST066 * VAR06 + CONST072 * VAR24) + + VAR16 + * ( + CONST055 * VAR08 * VAR24 + + CONST093 * VAR04 + + CONST093 * VAR06 * VAR26 + - CONST093 * VAR22 + ) + + y + * ( + CONST013 * VAR02 + + CONST024 * VAR04 * VAR26 + + CONST133 * VAR08 * VAR22 + + CONST138 * VAR20 + ) + ) + Y12 = ( + CONST036 * VAR17 * VAR21 + + CONST137 * VAR19 + + VAR23 * (CONST086 * VAR15 + CONST123 * VAR08 * VAR17 - CONST143 * VAR06) + + VAR25 + * ( + CONST014 * VAR04 + - CONST051 * VAR08 * VAR15 + + CONST065 * VAR06 * VAR17 + - CONST103 * VAR13 + ) + + z + * ( + CONST062 * VAR08 * VAR13 + + CONST092 * VAR04 * VAR17 + - CONST112 * VAR06 * VAR15 + - CONST144 * VAR02 + ) + ) + Y13 = ( + VAR14 * (CONST049 * VAR06 + CONST049 * VAR24 + CONST090 * VAR08 * VAR26) + + VAR16 + * ( + CONST040 * VAR06 * VAR26 + + CONST040 * VAR08 * VAR24 + + CONST100 * VAR22 + + CONST102 * VAR04 + ) + + y + * ( + CONST016 * VAR20 + + CONST017 * VAR02 + + CONST094 * VAR06 * VAR24 + + CONST121 * VAR04 * VAR26 + + CONST122 * VAR08 * VAR22 + ) + ) + Y14 = ( + CONST007 * VAR02 * z + + CONST075 * VAR19 + + CONST136 * VAR06 * VAR23 + + CONST140 * VAR08 * VAR21 + + VAR15 * (CONST032 * 
VAR08 * VAR25 + CONST047 * VAR23 + CONST131 * VAR06 * z) + + VAR17 + * ( + CONST068 * VAR06 * VAR25 + + CONST073 * VAR04 * z + + CONST114 * VAR08 * VAR23 + + CONST128 * VAR21 + ) + ) + Y15 = VAR16 * ( + CONST037 * VAR22 + - CONST045 * VAR06 * VAR26 + + CONST045 * VAR08 * VAR24 + + CONST118 * VAR04 + ) + y * ( + CONST018 * VAR02 + + CONST089 * VAR04 * VAR26 + - CONST089 * VAR08 * VAR22 + + CONST142 * VAR20 + ) + Y16 = ( + CONST019 * VAR02 * z + + CONST076 * VAR19 + + CONST124 * VAR04 * VAR25 + - CONST129 * VAR08 * VAR21 + + CONST134 * VAR06 * VAR23 + + VAR17 + * ( + CONST038 * VAR06 * VAR25 + + CONST080 * VAR04 * z + + CONST097 * VAR08 * VAR23 + - CONST130 * VAR21 + ) + ) + Y17 = y * ( + CONST058 * VAR02 + + CONST058 * VAR20 + + CONST061 * VAR04 * VAR26 + + CONST061 * VAR08 * VAR22 + + CONST074 * VAR06 * VAR24 + ) + Y18 = ( + CONST001 * VAR19 + + CONST020 * VAR02 * z + + CONST078 * VAR04 * VAR25 + + CONST091 * VAR06 * VAR23 + + CONST105 * VAR08 * VAR21 + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel) + tl.store( + output_ptr + output_row_offset + 1, + Y01, + mask=output_row_offset + 1 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 2, + Y02, + mask=output_row_offset + 2 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 3, + Y03, + mask=output_row_offset + 3 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 4, + Y04, + mask=output_row_offset + 4 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 5, + Y05, + mask=output_row_offset + 5 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 6, + Y06, + mask=output_row_offset + 6 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 7, + Y07, + mask=output_row_offset + 7 < output_numel, + ) + tl.store( + output_ptr + 
output_row_offset + 8, + Y08, + mask=output_row_offset + 8 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 9, + Y09, + mask=output_row_offset + 9 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 10, + Y10, + mask=output_row_offset + 10 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 11, + Y11, + mask=output_row_offset + 11 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 12, + Y12, + mask=output_row_offset + 12 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 13, + Y13, + mask=output_row_offset + 13 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 14, + Y14, + mask=output_row_offset + 14 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 15, + Y15, + mask=output_row_offset + 15 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 16, + Y16, + mask=output_row_offset + 16 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 17, + Y17, + mask=output_row_offset + 17 < output_numel, + ) + tl.store( + output_ptr + output_row_offset + 18, + Y18, + mask=output_row_offset + 18 < output_numel, + ) + + +@triton.jit +def ninth_order_bwd( + coord_ptr: tl.tensor, + coord_grad_ptr: tl.tensor, + sph_grad_ptr: tl.tensor, + block_size: tl.constexpr, + coord_numel: tl.constexpr, + output_numel: tl.constexpr, + col_offset: tl.constexpr, + output_stride: tl.constexpr, +): + # work out the row offsets + block_id = tl.program_id(0) + # these are hardcoded because they are predetermined; + coord_stride = 3 + coord_striding = tl.arange(0, block_size) * coord_stride + # as the name suggests, this is effectively every node/atom + coord_row_offset = coord_striding + (block_size * coord_stride * block_id) + x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel) + y = tl.load( + coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + z = tl.load( + coord_ptr + coord_row_offset + 2, 
mask=coord_row_offset + 2 < coord_numel + ) + output_striding = tl.arange(0, block_size) * output_stride + output_row_offset = ( + output_striding + (block_size * output_stride * block_id) + col_offset + ) + # load in gradients w.r.t. spherical harmonic projections + g_0 = tl.load( + sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel + ) + g_1 = tl.load( + sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel + ) + g_2 = tl.load( + sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel + ) + g_3 = tl.load( + sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel + ) + g_4 = tl.load( + sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel + ) + g_5 = tl.load( + sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel + ) + g_6 = tl.load( + sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel + ) + g_7 = tl.load( + sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel + ) + g_8 = tl.load( + sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel + ) + g_9 = tl.load( + sph_grad_ptr + output_row_offset + 9, mask=output_row_offset + 9 < output_numel + ) + g_10 = tl.load( + sph_grad_ptr + output_row_offset + 10, + mask=output_row_offset + 10 < output_numel, + ) + g_11 = tl.load( + sph_grad_ptr + output_row_offset + 11, + mask=output_row_offset + 11 < output_numel, + ) + g_12 = tl.load( + sph_grad_ptr + output_row_offset + 12, + mask=output_row_offset + 12 < output_numel, + ) + g_13 = tl.load( + sph_grad_ptr + output_row_offset + 13, + mask=output_row_offset + 13 < output_numel, + ) + g_14 = tl.load( + sph_grad_ptr + output_row_offset + 14, + mask=output_row_offset + 14 < output_numel, + ) + g_15 = tl.load( + sph_grad_ptr + output_row_offset + 15, + mask=output_row_offset + 15 < output_numel, + ) + g_16 = tl.load( + sph_grad_ptr + output_row_offset + 
16, + mask=output_row_offset + 16 < output_numel, + ) + g_17 = tl.load( + sph_grad_ptr + output_row_offset + 17, + mask=output_row_offset + 17 < output_numel, + ) + g_18 = tl.load( + sph_grad_ptr + output_row_offset + 18, + mask=output_row_offset + 18 < output_numel, + ) + # -------------------- variable and constant definitions + CONST000 = 1.59908344719522 + CONST001 = 2.00000000000000 + CONST002 = 3.00000000000000 + CONST003 = 4.00000000000000 + CONST004 = 5.00000000000000 + CONST005 = 6.39633378878088 + CONST006 = 7.00000000000000 + CONST007 = 8.63855507530412 + CONST008 = 9.59450068317133 + CONST009 = 6.39633378878088 + CONST011 = 12.7926675775618 + CONST012 = 12.7926675775618 + CONST014 = 15.5493991355474 + CONST015 = 14.3917510247570 + CONST017 = 15.0007324039945 + CONST018 = 14.4550674370400 + CONST019 = 14.4550674370400 + CONST020 = 13.3827919767794 + CONST021 = 23.8930627690618 + CONST022 = 23.8930627690618 + CONST023 = 27.0429549260581 + CONST024 = 29.2403830344269 + CONST025 = 30.0014648079890 + CONST027 = 29.2403830344269 + CONST028 = 38.3780027326853 + CONST031 = 39.2300904918661 + CONST032 = 42.9079114754785 + CONST033 = 10.7269778688696 + CONST034 = 54.0859098521163 + CONST036 = 58.9217071894985 + CONST037 = 57.8202697481601 + CONST038 = 60.0029296159779 + CONST039 = 62.4530292249704 + CONST040 = 64.3618672132178 + CONST042 = 69.1084406024329 + CONST044 = 78.5622762526647 + CONST045 = 85.8158229509570 + CONST046 = 85.8158229509570 + CONST050 = 107.062335814235 + CONST052 = 108.171819704233 + CONST053 = -1935.03633686812 + CONST055 = 115.640539496320 + CONST056 = 117.843414378997 + CONST057 = 117.843414378997 + CONST059 = 120.005859231956 + CONST060 = 2176.91587897664 + CONST061 = 2176.91587897664 + CONST064 = 150.007324039945 + CONST065 = -1892.23403121978 + CONST066 = -1885.49463006395 + CONST067 = 173.460809244480 + CONST068 = -1873.59087674911 + CONST070 = 10.7269778688696 + CONST071 = 180.008788847934 + CONST074 = 13.5214774630291 + CONST076 = 
205.957975082297 + CONST078 = 216.343639408465 + CONST079 = 4326.87278816930 + CONST080 = 233.923064275415 + CONST081 = 233.923064275415 + CONST082 = 240.011718463912 + CONST083 = 241.879542108515 + CONST085 = 255.853351551235 + CONST086 = 255.853351551235 + CONST087 = 257.447468852871 + CONST088 = 257.447468852871 + CONST090 = 270.429549260581 + CONST091 = 289.101348740801 + CONST093 = 300.014648079890 + CONST097 = 13.0937127087774 + CONST099 = -3747.18175349822 + CONST100 = 6.39633378878088 + CONST103 = 374.718175349822 + CONST105 = 404.741888237121 + CONST106 = 411.915950164594 + CONST107 = 412.451950326490 + CONST108 = 432.687278816930 + CONST109 = 435.383175795328 + CONST110 = 435.383175795327 + CONST112 = 462.562157985281 + CONST113 = -1571.24552505329 + CONST114 = 483.759084217031 + CONST115 = 511.706703102471 + CONST116 = 562.077263024733 + CONST117 = 578.202697481601 + CONST119 = -1451.27725265109 + CONST121 = -1451.27725265109 + CONST123 = 600.029296159779 + CONST124 = -1440.07031078347 + CONST129 = -1387.68647395584 + CONST130 = -1387.68647395584 + CONST131 = -1373.05316721531 + CONST132 = -1338.01151506746 + CONST133 = 725.638626325546 + CONST134 = -1298.06183645079 + CONST137 = 788.430846341574 + CONST138 = -1249.06058449941 + CONST139 = -1228.09608744593 + CONST140 = -1228.09608744593 + CONST141 = 823.831900329187 + CONST142 = -3245.15459112698 + CONST143 = -1178.43414378997 + CONST144 = 870.766351590655 + CONST145 = 870.766351590655 + CONST147 = -1124.15452604947 + CONST149 = -3153.72338536630 + CONST150 = 960.046873855647 + CONST151 = 960.046873855647 + CONST152 = 967.518168434061 + CONST153 = -1081.71819704233 + CONST154 = 967.518168434061 + CONST155 = -1060.59072941097 + CONST156 = 1023.41340620494 + CONST157 = 1023.41340620494 + CONST159 = -967.518168434061 + CONST160 = 1081.71819704233 + CONST161 = -960.046873855647 + CONST163 = -936.795438374555 + CONST165 = -900.043944239669 + CONST166 = 1156.40539496320 + CONST168 = -2902.55450530218 + 
CONST170 = 11.2632978048796 + CONST171 = -785.622762526647 + CONST172 = -785.622762526647 + CONST173 = -767.560054653706 + CONST175 = 1338.01151506746 + CONST176 = -693.843236977922 + CONST177 = -693.843236977921 + CONST178 = -686.526583607656 + CONST179 = -669.005757533731 + CONST180 = -669.005757533731 + CONST182 = -649.030918225395 + CONST183 = -630.744677073259 + CONST184 = -628.498210021318 + CONST185 = -628.498210021317 + CONST186 = -600.029296159779 + CONST187 = -589.217071894985 + CONST188 = -578.202697481601 + CONST189 = 15.5493991355474 + CONST190 = -562.077263024733 + CONST191 = 1500.07324039945 + CONST192 = -480.023436927823 + CONST193 = -480.023436927823 + CONST195 = -462.562157985281 + CONST196 = -450.021972119834 + CONST197 = -412.451950326490 + CONST198 = -409.365362481977 + CONST199 = -409.365362481976 + CONST200 = -404.741888237121 + CONST201 = -392.811381263323 + CONST202 = -383.780027326853 + CONST203 = -383.780027326853 + CONST204 = 1672.51439383433 + CONST205 = -374.718175349822 + CONST206 = -353.530243136991 + CONST207 = -2400.11718463912 + CONST209 = -346.921618488961 + CONST210 = -346.921618488961 + CONST211 = -343.263291803828 + CONST212 = -338.631358951921 + CONST213 = -338.631358951921 + CONST214 = -324.515459112698 + CONST215 = -315.372338536630 + CONST216 = -314.249105010659 + CONST217 = -2356.86828757994 + CONST218 = -300.014648079890 + CONST219 = -294.608535947493 + CONST220 = -289.101348740801 + CONST221 = -270.013183271901 + CONST222 = -2312.81078992641 + CONST223 = 1800.08788847934 + CONST224 = -241.879542108515 + CONST225 = -240.011718463912 + CONST226 = -241.879542108515 + CONST227 = -4326.87278816930 + CONST228 = -216.343639408465 + CONST229 = -210.010253655923 + CONST230 = -204.682681240988 + CONST231 = -204.682681240988 + CONST232 = -204.682681240988 + CONST233 = -196.405690631662 + CONST234 = -191.144502152495 + CONST235 = -191.890013663426 + CONST236 = -191.890013663427 + CONST237 = -187.359087674911 + CONST238 = 
-180.008788847934 + CONST239 = -176.765121568496 + CONST241 = 1873.59087674911 + CONST242 = -173.460809244480 + CONST244 = -162.257729556349 + CONST245 = -156.920361967464 + CONST246 = -156.920361967464 + CONST248 = -150.007324039945 + CONST249 = -144.550674370400 + CONST250 = -137.149553407950 + CONST251 = -135.214774630291 + CONST252 = -127.926675775618 + CONST253 = -127.926675775618 + CONST254 = -120.939771054258 + CONST255 = -120.005859231956 + CONST256 = -120.939771054258 + CONST257 = -117.843414378997 + CONST258 = -117.843414378997 + CONST259 = -115.640539496320 + CONST260 = -115.640539496320 + CONST261 = 1935.03633686812 + CONST262 = -2163.43639408465 + CONST263 = -114.421097267943 + CONST264 = -108.171819704233 + CONST265 = -107.062335814235 + CONST266 = -108.171819704233 + CONST267 = -104.749701670220 + CONST268 = -96.7518168434061 + CONST269 = -96.7518168434061 + CONST270 = -90.0043944239669 + CONST271 = -90.1063824390370 + CONST272 = -80.2967518606762 + CONST273 = -78.4601809837321 + CONST274 = -78.4601809837321 + CONST275 = -77.2655855030233 + CONST276 = -78.5622762526647 + CONST277 = -68.5747767039748 + CONST278 = -63.9633378878088 + CONST279 = -62.4530292249704 + CONST280 = -61.8124684024186 + CONST281 = -60.0029296159779 + CONST282 = -63.9633378878088 + CONST283 = -58.9217071894985 + CONST284 = -57.8202697481601 + CONST285 = -57.8202697481601 + CONST286 = -48.3759084217030 + CONST287 = -48.3759084217031 + CONST288 = -39.2811381263323 + CONST289 = -38.6327927515116 + CONST290 = -39.2811381263323 + CONST291 = -30.9062342012093 + CONST292 = -30.0014648079890 + CONST293 = -30.0014648079890 + CONST294 = -27.6433762409732 + CONST295 = -17.3847567381802 + CONST296 = -15.0007324039945 + CONST297 = -14.7304267973746 + CONST298 = -13.5214774630291 + CONST299 = -13.0937127087774 + CONST300 = -13.3827919767794 + CONST301 = -9.82028453158308 + CONST302 = -4.91014226579154 + CONST303 = 2046.82681240988 + VAR06 = x * x * x * x + VAR07 = x * x * x + VAR08 = x * x + 
VAR02 = VAR06 * VAR06 + VAR03 = VAR06 * VAR07 + VAR04 = VAR07 * VAR07 + VAR05 = VAR07 * VAR08 + VAR15 = y * y * y * y + VAR16 = y * y * y + VAR17 = y * y + VAR11 = VAR15 * VAR15 + VAR12 = VAR15 * VAR16 + VAR13 = VAR16 * VAR16 + VAR14 = VAR16 * VAR17 + VAR24 = z * z * z * z + VAR25 = z * z * z + VAR26 = z * z + VAR20 = VAR24 * VAR24 + VAR21 = VAR24 * VAR25 + VAR22 = VAR25 * VAR25 + VAR23 = VAR25 * VAR26 + # -------------------- kernel implementations + g_x = tl.load( + coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel + ) + g_y = tl.load( + coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel + ) + g_z = tl.load( + coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel + ) + g_x += ( + g_0 + * ( + CONST021 * VAR20 + + CONST022 * VAR02 + + CONST179 * VAR04 * VAR26 + + CONST180 * VAR08 * VAR22 + + CONST204 * VAR06 * VAR24 + ) + + g_1 + * y + * ( + CONST065 * VAR08 * VAR23 + - CONST149 * VAR06 * VAR25 + + CONST183 * VAR04 * z + - CONST271 * VAR21 + ) + + g_10 + * ( + CONST012 * VAR21 * x + + VAR23 * (CONST028 * VAR07 + CONST203 * VAR17 * x) + + VAR25 + * (CONST028 * VAR05 + CONST157 * VAR15 * x + CONST173 * VAR07 * VAR17) + + z + * ( + CONST011 * VAR03 + + CONST157 * VAR07 * VAR15 + + CONST198 * VAR13 * x + + CONST202 * VAR05 * VAR17 + ) + ) + + g_11 + * ( + CONST150 * VAR07 * VAR14 + + CONST250 * VAR12 * x + + VAR16 + * (CONST093 * VAR24 * x + CONST165 * VAR05 + CONST186 * VAR07 * VAR26) + + y * (CONST059 * VAR03 + CONST071 * VAR05 * VAR26 + CONST281 * VAR22 * x) + ) + + g_12 + * ( + VAR23 * (CONST257 * VAR17 * x - CONST290 * VAR07) + + VAR25 + * (CONST044 * VAR05 + CONST143 * VAR07 * VAR17 - CONST172 * VAR15 * x) + + z + * ( + CONST155 * VAR05 * VAR17 + + CONST184 * VAR13 * x + - CONST217 * VAR07 * VAR15 + - CONST288 * VAR03 + ) + ) + + g_13 + * ( + VAR14 * (CONST129 * VAR26 * x - CONST195 * VAR07) + + VAR16 + * (CONST166 * VAR24 * x + CONST176 * VAR05 - CONST222 * VAR07 * VAR26) + + y + * ( + 
CONST188 * VAR07 * VAR24 + + CONST209 * VAR05 * VAR26 + - CONST259 * VAR03 + + CONST259 * VAR22 * x + ) + ) + + g_14 + * ( + CONST042 * VAR03 * z + + CONST268 * VAR07 * VAR23 + + CONST294 * VAR21 * x + + VAR15 * (CONST053 * VAR25 * x + CONST261 * VAR07 * z) + + VAR17 + * (CONST119 * VAR05 * z + CONST144 * VAR23 * x + CONST152 * VAR07 * VAR25) + ) + + g_15 + * ( + VAR16 * (CONST068 * VAR24 * x - CONST099 * VAR07 * VAR26 + CONST205 * VAR05) + + y * (CONST050 * VAR03 + CONST147 * VAR05 * VAR26 - CONST205 * VAR22 * x) + ) + + g_16 + * ( + CONST214 * VAR05 * VAR25 + - CONST264 * VAR03 * z + + CONST264 * VAR07 * VAR23 + - CONST275 * VAR21 * x + + VAR17 + * (CONST079 * VAR07 * VAR25 + CONST134 * VAR05 * z + CONST134 * VAR23 * x) + ) + + g_17 + * y + * ( + CONST065 * VAR05 * VAR26 + - CONST149 * VAR07 * VAR24 + + CONST183 * VAR22 * x + - CONST271 * VAR03 + ) + + g_18 + * ( + CONST132 * VAR05 * VAR25 + + CONST175 * VAR07 * VAR23 + - CONST234 * VAR03 * z + + CONST234 * VAR21 * x + ) + + g_2 + * ( + CONST002 * VAR08 * (CONST034 * VAR22 + CONST153 * VAR17 * VAR24) + + CONST004 * VAR06 * (CONST023 * VAR24 - CONST182 * VAR17 * VAR26) + + CONST006 * VAR04 * (CONST289 * VAR26 + CONST291 * VAR17) + - CONST228 * VAR17 * VAR22 + - CONST295 * VAR02 + + CONST298 * VAR20 + ) + + g_3 + * ( + VAR16 + * (-CONST068 * VAR06 * z + CONST099 * VAR08 * VAR25 + CONST103 * VAR23) + + y + * ( + CONST116 * VAR08 * VAR23 + - CONST163 * VAR06 * VAR25 + + CONST190 * VAR04 * z + + CONST272 * VAR21 + ) + ) + + g_4 + * ( + CONST007 * VAR20 + + CONST014 * VAR02 + + CONST254 * VAR06 * VAR24 + + CONST269 * VAR04 * VAR26 + + VAR15 * (CONST114 * VAR06 + CONST114 * VAR24 + CONST168 * VAR08 * VAR26) + + VAR17 + * ( + CONST060 * VAR06 * VAR26 + + CONST133 * VAR08 * VAR24 + + CONST212 * VAR04 + + CONST224 * VAR22 + ) + ) + + g_5 + * ( + VAR14 * (CONST130 * VAR08 * z - CONST195 * VAR25) + + VAR16 * (CONST195 * VAR23 - CONST222 * VAR06 * z) + + y + * ( + CONST067 * VAR08 * VAR23 + + CONST200 * VAR04 * z + + CONST220 
* VAR06 * VAR25 + - CONST284 * VAR21 + ) + ) + + g_6 + * ( + CONST002 + * VAR08 + * ( + CONST201 * VAR15 * VAR26 + - CONST219 * VAR17 * VAR24 + + CONST267 * VAR13 + + CONST299 * VAR22 + ) + + CONST004 + * VAR06 + * (CONST036 * VAR17 * VAR26 - CONST233 * VAR15 + CONST301 * VAR24) + + CONST187 * VAR15 * VAR24 + + CONST197 * VAR04 * VAR17 + - CONST216 * VAR13 * VAR26 + - CONST239 * VAR17 * VAR22 + - CONST297 * VAR02 + + CONST302 * VAR20 + ) + + g_7 + * ( + CONST002 + * VAR08 + * (-CONST186 * VAR16 * VAR25 + CONST192 * VAR14 * z + CONST270 * VAR23 * y) + + CONST004 * VAR06 * (-CONST218 * VAR16 * z + CONST270 * VAR25 * y) + + CONST193 * VAR14 * VAR25 + - CONST218 * VAR16 * VAR23 + + CONST229 * VAR04 * y * z + - CONST250 * VAR12 * z + + CONST292 * VAR21 * y + ) + + g_8 + * ( + CONST000 * VAR20 + + CONST002 + * VAR08 + * ( + CONST005 * VAR22 + + CONST115 * VAR15 * VAR26 + + CONST230 * VAR13 + + CONST235 * VAR17 * VAR24 + ) + + CONST004 + * VAR06 + * (CONST008 * VAR24 + CONST085 * VAR15 + CONST235 * VAR17 * VAR26) + + CONST006 * VAR04 * (CONST009 * VAR26 + CONST278 * VAR17) + + CONST015 * VAR02 + + CONST024 * VAR11 + + CONST085 * VAR15 * VAR24 + + CONST231 * VAR13 * VAR26 + + CONST278 * VAR17 * VAR22 + ) + + g_9 + * ( + CONST245 * VAR12 * x + + VAR14 * (CONST141 * VAR07 + CONST141 * VAR26 * x) + + VAR16 + * (CONST131 * VAR07 * VAR26 + CONST178 * VAR05 + CONST178 * VAR24 * x) + + y + * ( + CONST045 * VAR03 + + CONST046 * VAR22 * x + + CONST087 * VAR05 * VAR26 + + CONST088 * VAR07 * VAR24 + ) + ) + ) + g_y += ( + CONST001 + * g_16 + * y + * ( + CONST160 * VAR06 * VAR25 + + CONST182 * VAR08 * VAR23 + + CONST228 * VAR04 * z + - CONST291 * VAR21 + ) + + g_1 + * ( + -CONST183 * VAR05 * VAR25 + + CONST183 * VAR07 * VAR23 + + CONST271 * VAR03 * z + - CONST271 * VAR21 * x + ) + + g_10 + * ( + CONST252 * VAR21 * y + + VAR23 * (CONST157 * VAR16 + CONST203 * VAR08 * y) + + VAR25 + * (CONST140 * VAR14 + CONST202 * VAR06 * y + CONST303 * VAR08 * VAR16) + + z + * ( + CONST080 * VAR12 + + 
CONST139 * VAR08 * VAR14 + + CONST157 * VAR06 * VAR16 + + CONST252 * VAR04 * y + ) + ) + + g_11 + * ( + CONST002 + * VAR17 + * ( + CONST064 * VAR08 * VAR24 + + CONST248 * VAR04 + + CONST248 * VAR06 * VAR26 + - CONST248 * VAR22 + ) + + CONST004 * VAR15 * (CONST082 * VAR06 + CONST225 * VAR24) + + CONST006 * VAR13 * (CONST277 * VAR08 - CONST277 * VAR26) + + CONST017 * VAR02 + + CONST025 * VAR04 * VAR26 + + CONST293 * VAR08 * VAR22 + + CONST296 * VAR20 + ) + + g_12 + * ( + CONST056 * VAR21 * y + + VAR23 * (CONST171 * VAR16 + CONST257 * VAR08 * y) + + VAR25 + * (-CONST113 * VAR08 * VAR16 - CONST185 * VAR14 + CONST187 * VAR06 * y) + + z + * ( + CONST066 * VAR08 * VAR14 + + CONST206 * VAR04 * y + - CONST217 * VAR06 * VAR16 + ) + ) + + g_13 + * ( + CONST002 + * VAR17 + * ( + CONST117 * VAR06 * VAR26 + + CONST117 * VAR08 * VAR24 + + CONST259 * VAR04 + + CONST260 * VAR22 + ) + + CONST004 + * VAR15 + * (CONST055 * VAR06 + CONST055 * VAR24 + CONST176 * VAR08 * VAR26) + + CONST018 * VAR20 + + CONST019 * VAR02 + + CONST249 * VAR06 * VAR24 + + CONST284 * VAR04 * VAR26 + + CONST285 * VAR08 * VAR22 + ) + + g_14 + * ( + CONST001 + * y + * ( + CONST083 * VAR06 * VAR25 + + CONST109 * VAR08 * VAR23 + + CONST226 * VAR04 * z + + CONST286 * VAR21 + ) + + CONST003 + * VAR16 + * (CONST114 * VAR06 * z + CONST159 * VAR08 * VAR25 - CONST269 * VAR23) + ) + + g_15 + * ( + CONST002 + * VAR17 + * ( + CONST039 * VAR22 + - CONST163 * VAR06 * VAR26 + + CONST163 * VAR08 * VAR24 + + CONST279 * VAR04 + ) + + CONST020 * VAR02 + + CONST237 * VAR04 * VAR26 + - CONST237 * VAR08 * VAR22 + + CONST300 * VAR20 + ) + + g_17 + * ( + CONST137 * VAR06 * VAR24 + + CONST170 * VAR02 + + CONST170 * VAR20 + + CONST215 * VAR04 * VAR26 + + CONST215 * VAR08 * VAR22 + ) + + g_2 + * ( + CONST108 * VAR22 * x * y + - CONST134 * VAR05 * VAR26 * y + + CONST262 * VAR07 * VAR24 * y + + CONST280 * VAR03 * y + ) + + g_3 + * ( + CONST002 + * VAR17 + * (CONST103 * VAR23 * x + CONST138 * VAR07 * VAR25 - CONST205 * VAR05 * z) + - 
CONST237 * VAR05 * VAR25 + - CONST237 * VAR07 * VAR23 + + CONST272 * VAR03 * z + + CONST272 * VAR21 * x + ) + + g_4 + * ( + CONST001 + * y + * ( + CONST110 * VAR05 * VAR26 + - CONST224 * VAR07 * VAR24 + + CONST224 * VAR22 * x + + CONST287 * VAR03 + ) + + CONST003 + * VAR16 + * (CONST114 * VAR24 * x + CONST159 * VAR07 * VAR26 - CONST269 * VAR05) + ) + + g_5 + * ( + CONST002 * VAR17 * (CONST112 * VAR05 * z + CONST195 * VAR23 * x) + + CONST004 * VAR15 * (CONST195 * VAR07 * z - CONST195 * VAR25 * x) + + CONST037 * VAR07 * VAR23 + + CONST284 * VAR05 * VAR25 + - CONST284 * VAR21 * x + + CONST285 * VAR03 * z + ) + + g_6 + * ( + CONST258 * VAR03 * y + + VAR05 * (CONST057 * VAR26 * y - CONST171 * VAR16) + + VAR07 + * (CONST113 * VAR16 * VAR26 + CONST185 * VAR14 - CONST187 * VAR24 * y) + + x + * ( + -CONST066 * VAR14 * VAR26 + - CONST206 * VAR22 * y + + CONST217 * VAR16 * VAR24 + ) + ) + + g_7 + * ( + CONST292 * VAR03 * z + + VAR05 * (-CONST165 * VAR17 * z + CONST270 * VAR25) + + VAR07 + * (CONST207 * VAR15 * z + CONST223 * VAR17 * VAR25 + CONST270 * VAR23) + + x + * ( + CONST151 * VAR13 * z + - CONST165 * VAR17 * VAR23 + + CONST207 * VAR15 * VAR25 + + CONST292 * VAR21 + ) + ) + + g_8 + * ( + CONST253 * VAR03 * y + + VAR05 * (CONST156 * VAR16 + CONST202 * VAR26 * y) + + VAR07 + * (CONST139 * VAR14 + CONST202 * VAR24 * y + CONST303 * VAR16 * VAR26) + + x + * ( + CONST081 * VAR12 + + CONST140 * VAR14 * VAR26 + + CONST156 * VAR16 * VAR24 + + CONST253 * VAR22 * y + ) + ) + + g_9 + * ( + CONST002 + * VAR17 + * ( + CONST211 * VAR06 * VAR26 + + CONST211 * VAR08 * VAR24 + + CONST263 * VAR04 + + CONST263 * VAR22 + ) + + CONST004 + * VAR15 + * (CONST076 * VAR06 + CONST076 * VAR24 + CONST106 * VAR08 * VAR26) + + CONST006 * VAR13 * (CONST273 * VAR26 + CONST274 * VAR08) + + CONST031 * VAR11 + + CONST032 * VAR04 * VAR26 + + CONST032 * VAR08 * VAR22 + + CONST033 * VAR20 + + CONST040 * VAR06 * VAR24 + + CONST070 * VAR02 + ) + ) + g_z += ( + g_0 + * ( + CONST132 * VAR07 * VAR23 + + CONST175 
* VAR05 * VAR25 + + CONST234 * VAR03 * z + - CONST234 * VAR21 * x + ) + + g_1 + * y + * ( + -CONST065 * VAR05 * VAR26 + + CONST149 * VAR07 * VAR24 + - CONST183 * VAR22 * x + + CONST271 * VAR03 + ) + + g_10 + * ( + CONST000 * VAR02 + + CONST002 + * VAR26 + * ( + CONST100 * VAR04 + + CONST115 * VAR08 * VAR15 + + CONST231 * VAR13 + + CONST235 * VAR06 * VAR17 + ) + + CONST004 + * VAR24 + * (CONST008 * VAR06 + CONST086 * VAR15 + CONST236 * VAR08 * VAR17) + + CONST006 * VAR22 * (CONST005 * VAR08 + CONST282 * VAR17) + + CONST015 * VAR20 + + CONST027 * VAR11 + + CONST086 * VAR06 * VAR15 + + CONST232 * VAR08 * VAR13 + + CONST282 * VAR04 * VAR17 + ) + + g_11 + * ( + CONST161 * VAR14 * VAR25 + - CONST250 * VAR12 * z + + VAR16 + * (CONST123 * VAR08 * VAR25 - CONST165 * VAR23 + CONST218 * VAR06 * z) + + y * (CONST038 * VAR04 * z + CONST238 * VAR08 * VAR23 + CONST255 * VAR21) + ) + + g_12 + * ( + CONST002 + * VAR26 + * ( + CONST097 * VAR04 + - CONST201 * VAR08 * VAR15 + + CONST219 * VAR06 * VAR17 + - CONST267 * VAR13 + ) + + CONST004 + * VAR24 + * (CONST233 * VAR15 + CONST283 * VAR08 * VAR17 - CONST301 * VAR06) + + CONST107 * VAR17 * VAR22 + - CONST187 * VAR06 * VAR15 + + CONST216 * VAR08 * VAR13 + + CONST239 * VAR04 * VAR17 + + CONST297 * VAR20 + - CONST302 * VAR02 + ) + + g_13 + * ( + VAR14 * (CONST129 * VAR08 * z - CONST195 * VAR25) + + VAR16 + * (CONST166 * VAR06 * z + CONST177 * VAR23 - CONST222 * VAR08 * VAR25) + + y + * ( + CONST188 * VAR06 * VAR25 + + CONST210 * VAR08 * VAR23 + + CONST260 * VAR04 * z + - CONST260 * VAR21 + ) + ) + + g_14 + * ( + CONST007 * VAR02 + + CONST189 * VAR20 + + CONST256 * VAR06 * VAR24 + + CONST269 * VAR08 * VAR22 + + VAR15 * (CONST114 * VAR06 + CONST114 * VAR24 + CONST168 * VAR08 * VAR26) + + VAR17 + * ( + CONST061 * VAR08 * VAR24 + + CONST133 * VAR06 * VAR26 + + CONST213 * VAR22 + + CONST226 * VAR04 + ) + ) + + g_15 + * ( + VAR16 + * (-CONST068 * VAR06 * z + CONST099 * VAR08 * VAR25 + CONST103 * VAR23) + + y * (-CONST147 * VAR08 * VAR23 + 
CONST205 * VAR04 * z + CONST265 * VAR21) + ) + + g_16 + * ( + CONST074 * VAR02 + + CONST090 * VAR08 * VAR22 + + CONST244 * VAR04 * VAR26 + + CONST251 * VAR06 * VAR24 + + CONST295 * VAR20 + + VAR17 + * ( + CONST078 * VAR22 + - CONST142 * VAR06 * VAR26 + + CONST142 * VAR08 * VAR24 + + CONST228 * VAR04 + ) + ) + + g_17 + * y + * ( + CONST065 * VAR08 * VAR23 + - CONST149 * VAR06 * VAR25 + + CONST183 * VAR04 * z + - CONST271 * VAR21 + ) + + g_18 + * ( + CONST021 * VAR02 + + CONST022 * VAR20 + + CONST179 * VAR08 * VAR22 + + CONST180 * VAR04 * VAR26 + + CONST204 * VAR06 * VAR24 + ) + + g_2 + * ( + CONST275 * VAR03 * z + + VAR05 * (CONST052 * VAR25 - CONST134 * VAR17 * z) + + VAR07 * (-CONST214 * VAR23 + CONST227 * VAR17 * VAR25) + + x * (-CONST134 * VAR17 * VAR23 + CONST266 * VAR21) + ) + + g_3 + * ( + VAR16 * (CONST099 * VAR07 * VAR26 - CONST205 * VAR05 + CONST241 * VAR24 * x) + + y + * ( + CONST116 * VAR05 * VAR26 + - CONST163 * VAR07 * VAR24 + + CONST190 * VAR22 * x + + CONST272 * VAR03 + ) + ) + + g_4 + * ( + CONST042 * VAR21 * x + + CONST269 * VAR05 * VAR25 + + CONST294 * VAR03 * z + + VAR15 * (CONST053 * VAR07 * z + CONST261 * VAR25 * x) + + VAR17 + * (CONST121 * VAR23 * x + CONST145 * VAR05 * z + CONST154 * VAR07 * VAR25) + ) + + g_5 + * ( + VAR14 * (-CONST130 * VAR26 * x + CONST195 * VAR07) + + VAR16 * (CONST112 * VAR05 + CONST222 * VAR24 * x) + + y + * ( + CONST091 * VAR07 * VAR24 + + CONST105 * VAR22 * x + + CONST242 * VAR05 * VAR26 + + CONST285 * VAR03 + ) + ) + + g_6 + * ( + VAR05 * (CONST057 * VAR17 * z + CONST290 * VAR25) + + VAR07 + * (-CONST143 * VAR17 * VAR25 + CONST172 * VAR15 * z + CONST276 * VAR23) + + x + * ( + -CONST155 * VAR17 * VAR23 + - CONST184 * VAR13 * z + + CONST217 * VAR15 * VAR25 + + CONST288 * VAR21 + ) + ) + + g_7 + * ( + CONST292 * VAR03 * y + + VAR05 * (-CONST218 * VAR16 + CONST221 * VAR26 * y) + + VAR07 + * (CONST192 * VAR14 + CONST196 * VAR24 * y + CONST223 * VAR16 * VAR26) + + x + * ( + CONST124 * VAR14 * VAR26 + + CONST191 * VAR16 * 
VAR24 + + CONST229 * VAR22 * y + - CONST250 * VAR12 + ) + ) + + g_8 + * ( + CONST011 * VAR03 * z + + VAR05 * (CONST028 * VAR25 + CONST202 * VAR17 * z) + + VAR07 + * (CONST028 * VAR23 + CONST157 * VAR15 * z + CONST173 * VAR17 * VAR25) + + x + * ( + CONST011 * VAR21 + + CONST156 * VAR15 * VAR25 + + CONST199 * VAR13 * z + + CONST202 * VAR17 * VAR23 + ) + ) + + g_9 + * ( + CONST246 * VAR12 * z + + VAR14 * (CONST141 * VAR08 * z + CONST141 * VAR25) + + VAR16 + * (CONST131 * VAR08 * VAR25 + CONST178 * VAR06 * z + CONST178 * VAR23) + + y + * ( + CONST046 * VAR04 * z + + CONST046 * VAR21 + + CONST087 * VAR08 * VAR23 + + CONST088 * VAR06 * VAR25 + ) + ) + ) + # write out gradients + tl.store( + coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel + ) + tl.store( + coord_grad_ptr + coord_row_offset + 1, + g_y, + mask=coord_row_offset + 1 < coord_numel, + ) + tl.store( + coord_grad_ptr + coord_row_offset + 2, + g_z, + mask=coord_row_offset + 2 < coord_numel, + ) diff --git a/src/equitriton/utils.py b/src/equitriton/utils.py index 7760350..9eed927 100644 --- a/src/equitriton/utils.py +++ b/src/equitriton/utils.py @@ -1,9 +1,38 @@ from __future__ import annotations +from collections import Counter +import math + import torch import triton +import numpy as np +from e3nn import o3 + +__all__ = [ + "pad_tensor_to_power", + "calculate_lastdim_num_blocks", + "spherical_harmonics_irreps", + "num_irreps_projections", + "separate_embedding_irreps", +] -__all__ = ["pad_tensor_to_power"] + +def num_irreps_projections(l: int) -> int: + """ + Calculate the number of projections for a given order + of spherical harmonic. + + Parameters + ---------- + l : int + Order of spherical harmonic. + + Returns + ------- + int + Number of projections, i.e. 
2l + 1 + """ + return 2 * l + 1 def pad_tensor_to_power( @@ -43,3 +72,141 @@ def pad_tensor_to_power( mask = torch.ones(pad_size, device=joint_tensor.device, dtype=torch.bool) mask[num_nodes:] = False return (joint_tensor, mask) + + +def calculate_lastdim_num_blocks(input_tensor: torch.Tensor, block_size: int) -> int: + """ + Calculate the number of blocks for a tensor, assuming we + stride along the last dimension, and a given block size. + + The corresponding pointer arithmetic looks like this: + + ```python + block_id = tl.program_id(0) + striding = tl.arange(0, block_size) * stride + offset = (striding + (block_size * stride * block_id)) + ``` + + This function is used to work out the amount of parallel + work that needs to be done, given as the total number of + elements divided by the row stride (the size of the last + dimension for a contiguous tensor), and a specified + block size that will then divvy up the work. + + Parameters + ---------- + input_tensor : torch.Tensor + Torch N-d tensor to operate over. + block_size : int + Block size used by the kernel; not currently used in + this calculation. + + Returns + ------- + int + Number of blocks of work, given a block size. + """ + # stride of the second-to-last dimension, i.e. the number of + # elements per row for a contiguous tensor + stride = input_tensor.stride(-2) + numel = input_tensor.numel() + total_blocks = math.ceil(numel / stride) + return total_blocks + + +def unravel_index(tensor: torch.Tensor, index: int) -> tuple[int, ...]: + """ + For a given N-d tensor and a 1D index, work out the corresponding + index tuple for the N-d tensor. + + This is equivalent to the `torch.unravel_index` function, but + makes it a bit friendlier in terms of Python types. + + Parameters + ---------- + tensor : torch.Tensor + Torch N-D tensor to index. + index : int + 1D index value to map onto an N-tuple, where N + is the dimensionality of the tensor. Must be + greater than or equal to zero, and smaller than the + total number of elements. + + Returns + ------- + tuple[int, ...] + An N-tuple of integers corresponding to the + N-d index of the provided index.
+    """
+    # make sure that the index is within bounds
+    assert 0 <= index < tensor.numel()
+    indices = []
+    for size in reversed(tensor.shape):
+        indices.append(index % size)
+        index //= size
+    return tuple(reversed(indices))
+
+
+def spherical_harmonics_irreps(l_values: list[int], num_feat: int = 1) -> o3.Irreps:
+    """
+    Generate the set of irreducible representations given a list of
+    arbitrary l values; i.e. they need not be contiguous.
+
+    While ``l_values`` does not need to be contiguous, this function
+    sorts it in ascending order of ``l``, so that the returned
+    representations are in order. This makes them more straightforward
+    to build on.
+
+    Parameters
+    ----------
+    l_values : list[int]
+        List of l values to generate representations for.
+    num_feat : int
+        Number of features for the associated representations.
+        Defaults to 1, which can be used for specifying a spherical
+        harmonic basis, but values greater than one can be used to
+        specify weights.
+
+    Returns
+    -------
+    o3.Irreps
+        Irreducible representations for the set of spherical harmonics.
+    """
+    assert num_feat > 0, "Number of features must be positive!"
+    joint = []
+    for l in sorted(l_values):
+        parity = "e" if (-1) ** l > 0 else "o"
+        joint.append(f"{num_feat}x{l}{parity}")
+    return o3.Irreps("+".join(joint))
+
+
+def separate_embedding_irreps(
+    embeddings: torch.Tensor | np.ndarray, irreps: o3.Irreps, return_numpy: bool = True
+) -> dict[int, torch.Tensor | np.ndarray]:
+    """
+    Utility function that will split a joint embedding tensor
+    into embeddings for individual orders.
+
+    Parameters
+    ----------
+    embeddings : torch.Tensor | np.ndarray
+        N-d tensor or array containing embeddings for all irreps.
+    irreps : o3.Irreps
+        Object describing which orders of representations
+        are present, and how many of each.
+
+    Returns
+    -------
+    dict[int, torch.Tensor | np.ndarray]
+        Dictionary mapping each order to its chunk; NumPy if ``return_numpy``.
+    """
+    # detach and move to CPU; each order l spans count * (2l + 1) columns
+    if isinstance(embeddings, torch.Tensor):
+        embeddings = embeddings.detach().cpu()
+    if isinstance(embeddings, np.ndarray):
+        embeddings = torch.from_numpy(embeddings)
+    irrep_dims = dict(Counter(irreps.ls))
+    splits = np.cumsum([n * (2 * l + 1) for l, n in irrep_dims.items()]).tolist()
+    return_dict = {}
+    chunks = torch.tensor_split(embeddings, splits, dim=-1)
+    # should be an extra empty chunk but zip should skip it
+    for key, chunk in zip(irrep_dims.keys(), chunks):
+        if return_numpy:
+            chunk = chunk.numpy()
+        return_dict[key] = chunk
+    return return_dict
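
Note on the per-order layout: the sketch below is not part of the patch. It is a minimal, dependency-free illustration of how the last dimension of a joint embedding could be partitioned by order. It assumes the layout implied by `spherical_harmonics_irreps`, with orders sorted ascending and each order `l` occupying `num_feat * (2l + 1)` contiguous columns. The names `num_projections` and `order_slices` are hypothetical and do not exist in `equitriton`.

```python
from itertools import accumulate


def num_projections(l: int) -> int:
    """Number of components for order l, i.e. 2l + 1."""
    return 2 * l + 1


def order_slices(l_values: list[int], num_feat: int = 1) -> dict[int, slice]:
    """
    Hypothetical helper: column range occupied by each order, assuming
    ascending l and num_feat * (2l + 1) contiguous columns per order.
    Duplicate l values are not handled.
    """
    orders = sorted(l_values)
    widths = [num_feat * num_projections(l) for l in orders]
    stops = list(accumulate(widths))  # cumulative end offsets
    starts = [0] + stops[:-1]  # each chunk starts where the previous one ends
    return {l: slice(a, b) for l, a, b in zip(orders, starts, stops)}


slices = order_slices([2, 0, 1])
row = list(range(9))  # stand-in for one row of a joint embedding
chunks = {l: row[s] for l, s in slices.items()}
print(chunks)  # {0: [0], 1: [1, 2, 3], 2: [4, 5, 6, 7, 8]}
```

The cumulative end offsets computed here play the same role as the split points passed to `torch.tensor_split`: the final offset equals the total width, so no columns are left over.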