
EquiformerV2Backbone has surprisingly large equivariance error #1068

Closed
humeniuka opened this issue Mar 19, 2025 · 2 comments
What would you like to report?

Hello,

It is my understanding that an EquiformerV2 network should be equivariant, that is to say, if one rotates the atomic positions of the input (x), the node features (y) should transform according to their irreps. If R is the rotation operator and M the equivariant network, the following should hold:
M(R(x)) = R(M(x))
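For the simplest nontrivial case, l = 1, the Wigner D matrix is just the rotation matrix R itself, so the condition can be checked directly. A minimal sketch, with the sum of atomic positions standing in for an equivariant network (plain NumPy, no fairchem needed):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy equivariant "model" for l = 1 features: the sum of atomic positions.
def M(positions):
    return positions.sum(axis=0)

# Random orthogonal matrix via QR decomposition.
R, _ = np.linalg.qr(rng.standard_normal((3, 3)))
if np.linalg.det(R) < 0:  # ensure a proper rotation (det = +1)
    R = -R

positions = rng.standard_normal((5, 3))  # five atoms in 3D

lhs = M(positions @ R.T)  # M(R(x)): rotate the inputs first
rhs = R @ M(positions)    # R(M(x)): rotate the output
print(np.allclose(lhs, rhs))  # True (up to floating-point precision)
```

Because summation commutes with any linear map, this toy model is exactly equivariant; a network built from equivariant layers should satisfy the same identity up to numerical precision.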

I tested this for the water molecule and a minimal EquiformerV2Backbone model:

#!/usr/bin/env python
import torch
import ase.build
from e3nn import o3
import numpy

from fairchem.core.datasets import data_list_collater
from fairchem.core.models.equiformer_v2.equiformer_v2 import EquiformerV2Backbone
from fairchem.core.preprocessing import AtomsToGraphs

def test_equivariance():
    lmax = 2
    # Define model
    torch.manual_seed(100)  # fix network initialization
    backbone = EquiformerV2Backbone(
        use_pbc=False,
        num_layers=2,
        sphere_channels=2,
        attn_hidden_channels=2,
        num_sphere_samples=64,
        edge_channels=2,
        lmax_list=[lmax],
    )
    # Disable dropout (evaluation mode)
    backbone.eval()
    assert not backbone.training

    # convert ASE atoms object to graph
    a2g = AtomsToGraphs(r_pbc=False)

    def model(atoms):
        inp = data_list_collater([a2g.convert(atoms)])        
        out = backbone(inp)
        node_features = out["node_embedding"].embedding
        # Aggregate over atoms (dim 0) and channels (dim 2); the remaining axis
        # holds the (lmax+1)^2 spherical-harmonic components.
        node_features_aggr = torch.sum(node_features, dim=[0, 2])
        return node_features_aggr
        
    # angles of rotation in 3D
    alpha, beta, gamma = torch.tensor(0.1), torch.tensor(0.5), torch.tensor(0.8)
    # rotation matrix
    rotation = o3.angles_to_matrix(alpha, beta, gamma)
    # Wigner D matrices for rotating node features
    irreps_strings = []
    for l in range(0, lmax+1):
        irreps_strings.append(f"1x{l}e")
    irreps = o3.Irreps("+".join(irreps_strings))
    wigner_D = irreps.D_from_angles(alpha, beta, gamma)

    # Water molecule in very large unit cell
    atoms = ase.build.molecule('H2O')
    atoms.translate(numpy.array([10.0, 10.0, 10.0]))
    atoms.set_cell(3*[1000.0])
    atoms.pbc = True

    # Evaluate model on atoms in the original orientation, y = M(x)
    prediction_Mx = model(atoms)
    
    # Rotate atom coordinates x' = R.x and lattice vectors of unit cell.
    atoms.positions = numpy.einsum('ij,aj->ai', rotation, atoms.positions)
    atoms.cell = numpy.einsum('ij,ja->ia', rotation, atoms.cell)

    # Evaluate model on rotated atoms, y' = M(x')
    prediction_MRx = model(atoms)

    # Rotate the output tensor.
    # y'' = R(y) = R(M(x)) = M(R(x)) = M(x') = y'
    prediction_RMx = torch.einsum('...ij,...j->...i', wigner_D, prediction_Mx)

    print("M(x)")
    print(prediction_Mx)
    print("M(R(x))")
    print(prediction_MRx)
    print("R(M(x))")
    print(prediction_RMx)

    error = torch.linalg.norm(prediction_MRx - prediction_RMx)
    print(f"equivariance error |M(R(x)) - R(M(x))|= {error}")
    
if __name__ == "__main__":
    test_equivariance()

Equivariance is only approximately satisfied:

M(R(x)) = tensor(
[0.0,  0.1534, -0.06606,  0.1087, -0.1761, -0.02575,  3.6217,  5.0306, 1.6298], ...)
R(M(x)) = tensor(
[ 0.0,  0.1526, -0.0658,  0.1057, -0.1751, -0.0223,  3.6105,  5.0067, 1.6283], ...)

and the error is surprisingly large

equivariance error |M(R(x)) - R(M(x))|= 0.0269

If I increase the number of layers, channels, lmax, etc., the error gets much larger. Is there something wrong with my settings? How can I ensure that the network is equivariant to numerical precision?

Thank you for your time and help.
Best regards,
Alexander

humeniuka commented Mar 20, 2025

Increasing the number of grid points for SO3_Grid (e.g. grid_resolution = 1024) reduces the error,

equivariance error |M(R(x)) - R(M(x))|= 0.000225

And increasing the floating point precision with torch.set_default_dtype(torch.double) reduces the error further,

equivariance error |M(R(x)) - R(M(x))|= 2.6e-05

but it seems that the computational cost of these settings will be very large.
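The precision part of this can be isolated from the model: a single rotate/unrotate round trip already leaves a residual of order 1e-7 in single precision versus roughly 1e-16 in double, and such errors accumulate layer by layer. A minimal sketch (plain NumPy, no fairchem needed):

```python
import numpy as np

def roundtrip_error(dtype):
    """Rotate a vector and rotate it back; return the residual norm."""
    rng = np.random.default_rng(0)
    R, _ = np.linalg.qr(rng.standard_normal((3, 3)))  # orthogonal matrix
    R = R.astype(dtype)
    v = rng.standard_normal(3).astype(dtype)
    return float(np.linalg.norm(R.T @ (R @ v) - v))

err32 = roundtrip_error(np.float32)
err64 = roundtrip_error(np.float64)
print(err32, err64)  # the float64 residual is orders of magnitude smaller
```

This only accounts for the floating-point contribution; the grid discretization error addressed by grid_resolution is a separate, usually larger, term.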

kyonofx (Collaborator) commented Mar 20, 2025

Hi @humeniuka, EquiformerV2 breaks exact equivariance due to (1) the max-neighbor limit and (2) grid discretization. For (1) you should be able to further reduce the error by setting enforce_max_neighbors_strictly=False. More details on this in #823.

Our recent preprint discusses design aspects of achieving exact equivariance and energy conservation: https://arxiv.org/abs/2502.12147
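Under that explanation, the two error sources map onto two constructor arguments (a sketch only; the argument names are taken from this thread, and exact behavior may depend on the fairchem version):

```python
backbone = EquiformerV2Backbone(
    ...,
    enforce_max_neighbors_strictly=False,  # (1) avoid hard truncation of the neighbor list
    grid_resolution=1024,                  # (2) finer SO3_Grid discretization
)
```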

kyonofx closed this as completed Mar 27, 2025