TensorNET with CUDA Graphs and LangevinMiddleIntegrator fails #220

Closed
raimis opened this issue Sep 21, 2023 · 9 comments

@raimis (Collaborator) commented Sep 21, 2023

TensorNET with CUDA Graphs and LangevinMiddleIntegrator fails, while it works with LangevinIntegrator.

Versions:

Reproduction script:

import numpy as np
from openmm import Context, LangevinIntegrator, LangevinMiddleIntegrator, Platform, System
from openmmtorch import TorchForce
import torch as pt
from torchmdnet.models.model import create_model

args = {"model": "tensornet",
        "embedding_dimension": 128,
        "num_layers": 2,
        "num_rbf": 32,
        "rbf_type": "expnorm",
        "trainable_rbf": False,
        "activation": "silu",
        "cutoff_lower": 0.0,
        "cutoff_upper": 5.0,
        "max_z": 100,
        "max_num_neighbors": 128,
        "equivariance_invariance_group": "O(3)",
        "prior_model": None,
        "atom_filter": -1,
        "derivative": False,
        "output_model": "Scalar",
        "reduce_op": "sum",
        "precision": 32 }
model = create_model(args)

n_atoms = 5
positions = np.random.randn(n_atoms, 3)
temperature = 300

class NNPForce(pt.nn.Module):

    def __init__(self, model, n_atoms):
        super().__init__()
        self.model = model
        self.register_buffer("zs", pt.ones(n_atoms, dtype=pt.long))

    def forward(self, positions):
        return self.model(self.zs, positions)[0]

for useCUDAgraphs in [False, True]:
    for integratorType in [LangevinIntegrator, LangevinMiddleIntegrator]:
        for set_velocities in [False, True]:

            system = System()
            for _ in range(n_atoms):
                system.addParticle(1)

            force = NNPForce(model, n_atoms)
            force = TorchForce(pt.jit.script(force))
            force.setProperty("useCUDAGraphs", "true" if useCUDAgraphs else "false")
            system.addForce(force)

            platform = Platform.getPlatformByName("CUDA")

            integrator = integratorType(temperature, 1, 0.001)

            context = Context(system, integrator, platform)
            context.setPositions(positions)
            if set_velocities:
                context.setVelocitiesToTemperature(temperature)

            energy = context.getState(getEnergy=True).getPotentialEnergy()
            print(useCUDAgraphs, integratorType.__name__, set_velocities, energy)
Output:

False LangevinIntegrator False 1.1890453100204468 kJ/mol
False LangevinIntegrator True 1.189045786857605 kJ/mol
False LangevinMiddleIntegrator False 1.1890454292297363 kJ/mol
False LangevinMiddleIntegrator True 1.1890456676483154 kJ/mol
[W ___torch_mangle_13.py:45] Warning: CUDA graph capture will lock the batch to the current number of samples (1). Changing this will result in a crash (function )
True LangevinIntegrator False 1.189045786857605 kJ/mol
[W ___torch_mangle_13.py:45] Warning: CUDA graph capture will lock the batch to the current number of samples (1). Changing this will result in a crash (function )
True LangevinIntegrator True 1.1890455484390259 kJ/mol
[W ___torch_mangle_13.py:45] Warning: CUDA graph capture will lock the batch to the current number of samples (1). Changing this will result in a crash (function )
True LangevinMiddleIntegrator False 0.0 kJ/mol
[W ___torch_mangle_13.py:45] Warning: CUDA graph capture will lock the batch to the current number of samples (1). Changing this will result in a crash (function )
[W manager.cpp:335] Warning: FALLBACK path has been taken inside: runCudaFusionGroup. This is an indication that codegen Failed for some reason.
To debug try disable codegen fallback path via setting the env variable `export PYTORCH_NVFUSER_DISABLE=fallback`
 (function runCudaFusionGroup)

OpenMMException: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /home/conda/feedstock_root/build_artifacts/pytorch-recipe_1693188735407/work/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string, std::allocator >) + 0x68 (0x7f2a12b62bd8 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string, std::allocator > const&) + 0xf3 (0x7f2a12b288c9 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3e4 (0x7f2a12beb464 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
frame #3: at::cuda::CUDAGraph::capture_end() + 0xc2 (0x7f29c19735e2 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #4:  + 0x9011 (0x7f2b2000a011 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/lib/plugins/libOpenMMTorchCUDA.so)
frame #5: OpenMM::ContextImpl::calcForcesAndEnergy(bool, bool, int) + 0xc9 (0x7f2a14cc7789 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/lib/python3.11/site-packages/openmm/../../../libOpenMM.so.8.0)
frame #6: OpenMM::Context::getState(int, bool, int) const + 0x15e (0x7f2a14cc4e5e in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/lib/python3.11/site-packages/openmm/../../../libOpenMM.so.8.0)
frame #7:  + 0x12d9d4 (0x7f2a150d99d4 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/lib/python3.11/site-packages/openmm/_openmm.cpython-311-x86_64-linux-gnu.so)
frame #8:  + 0x12de69 (0x7f2a150d9e69 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/lib/python3.11/site-packages/openmm/_openmm.cpython-311-x86_64-linux-gnu.so)
frame #9:  + 0x1ffde8 (0x563234441de8 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #10: _PyObject_MakeTpCall + 0x25b (0x56323442207b in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #11: _PyEval_EvalFrameDefault + 0x7d2 (0x56323442e992 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #12:  + 0x2a4d36 (0x5632344e6d36 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #13: PyEval_EvalCode + 0x9f (0x5632344e63ef in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #14:  + 0x2bbbae (0x5632344fdbae in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #15: _PyEval_EvalFrameDefault + 0x3d32 (0x563234431ef2 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #16:  + 0x2b8a79 (0x5632344faa79 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #17: _PyEval_EvalFrameDefault + 0x33c0 (0x563234431580 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #18:  + 0x2b8a79 (0x5632344faa79 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #19: _PyEval_EvalFrameDefault + 0x33c0 (0x563234431580 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #20:  + 0x2b8a79 (0x5632344faa79 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #21:  + 0x2b9647 (0x5632344fb647 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #22:  + 0x217839 (0x563234459839 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #23: PyObject_Vectorcall + 0x2c (0x56323443b5bc in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x7d2 (0x56323442e992 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #25:  + 0x22fc74 (0x563234471c74 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #26:  + 0x22f673 (0x563234471673 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #27: PyObject_Call + 0xa1 (0x56323445ad71 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #28: _PyEval_EvalFrameDefault + 0x4327 (0x5632344324e7 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #29:  + 0x2b8a79 (0x5632344faa79 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #30: _PyEval_EvalFrameDefault + 0x33c0 (0x563234431580 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #31:  + 0x2b8a79 (0x5632344faa79 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #32: _PyEval_EvalFrameDefault + 0x33c0 (0x563234431580 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #33:  + 0x2b8a79 (0x5632344faa79 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #34: _PyEval_EvalFrameDefault + 0x33c0 (0x563234431580 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #35:  + 0x2b8a79 (0x5632344faa79 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #36: _PyEval_EvalFrameDefault + 0x33c0 (0x563234431580 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #37:  + 0x2b8a79 (0x5632344faa79 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #38:  + 0x78a2 (0x7f2b46eaa8a2 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/lib/python3.11/lib-dynload/_asyncio.cpython-311-x86_64-linux-gnu.so)
frame #39:  + 0x1fe7ea (0x5632344407ea in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #40:  + 0x199b2f (0x5632343dbb2f in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #41:  + 0x19baa4 (0x5632343ddaa4 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #42:  + 0x1f96cf (0x56323443b6cf in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #43: _PyEval_EvalFrameDefault + 0x8fa1 (0x563234437161 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #44:  + 0x2a4d36 (0x5632344e6d36 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #45: PyEval_EvalCode + 0x9f (0x5632344e63ef in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #46:  + 0x2bbbae (0x5632344fdbae in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #47:  + 0x1f96cf (0x56323443b6cf in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #48: PyObject_Vectorcall + 0x2c (0x56323443b5bc in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #49: _PyEval_EvalFrameDefault + 0x7d2 (0x56323442e992 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #50: _PyFunction_Vectorcall + 0x181 (0x563234451121 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #51:  + 0x2ce0c8 (0x5632345100c8 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #52: Py_RunMain + 0x139 (0x56323450fa39 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #53: Py_BytesMain + 0x37 (0x5632344d4f97 in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
frame #54:  + 0x29d90 (0x7f2b48c29d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #55: __libc_start_main + 0x80 (0x7f2b48c29e40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #56:  + 0x292e3d (0x5632344d4e3d in /scratch/users/raimis/opt/mambaforge/envs/tmp-atom-2/bin/python)
@raimis (Collaborator, Author) commented Sep 21, 2023

Most likely, it is some memory corruption in TorchForce, LangevinMiddleIntegrator, or somewhere else.

@raimis (Collaborator, Author) commented Sep 21, 2023

@RaulPPelaez @peastman any ideas?

@peastman (Collaborator) commented
I agree, it sounds like memory corruption. The choice of integrator shouldn't matter at all. Probably it's just perturbing the memory layout or something similar, which changes whether the corruption leads to a crash.

Did you try the suggestions in the error message?

> For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
> Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

You could also try the CUDA compute sanitizer tool to check for errors.
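For reference, a minimal sketch (not from this thread) of how those environment-variable suggestions could be wired up from Python; it assumes the variables are set before any CUDA context is created, and the reproduction script itself is only indicated by a placeholder comment:

import os

# Sketch only: these variables are read when the CUDA runtime / fuser starts,
# so set them before importing torch or openmm and before any CUDA call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"            # synchronous kernel launches
os.environ["PYTORCH_NVFUSER_DISABLE"] = "fallback"  # from the warning in the log above

import torch as pt  # imported after the environment is prepared

# ... the reproduction script from the top of this issue would follow here ...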

@RaulPPelaez (Collaborator) commented
Can you try another integrator? I have a feeling it's completely unrelated to LangevinMiddleIntegrator.
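For concreteness, a hedged sketch (not from the thread) of how the loop in the reproduction script could be extended with other OpenMM integrators; the constructor signatures differ per integrator, and the rest of the loop body is elided:

from openmm import BrownianIntegrator, LangevinIntegrator, LangevinMiddleIntegrator, VerletIntegrator

temperature = 300  # K, as in the reproduction script

# Illustrative factories only: VerletIntegrator takes just a step size, while the
# stochastic integrators take temperature, friction coefficient and step size.
integrator_factories = {
    "Verlet": lambda: VerletIntegrator(0.001),
    "Brownian": lambda: BrownianIntegrator(temperature, 1, 0.001),
    "Langevin": lambda: LangevinIntegrator(temperature, 1, 0.001),
    "LangevinMiddle": lambda: LangevinMiddleIntegrator(temperature, 1, 0.001),
}

for name, make_integrator in integrator_factories.items():
    integrator = make_integrator()
    # ... build the System, TorchForce and Context exactly as in the script above ...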

@raimis (Collaborator, Author) commented Sep 22, 2023

> Did you try the suggestions in the error message?

CUDA_LAUNCH_BLOCKING=1 doesn't change anything, and I cannot recompile PyTorch with TORCH_USE_CUDA_DSA.

> You could also try the CUDA compute sanitizer tool to check for errors.

I will.

> Can you try another integrator? I have a feeling it's completely unrelated to LangevinMiddleIntegrator.

As shown above, LangevinIntegrator works as expected.

@RaulPPelaez (Collaborator) commented
I tried changing the order in your loop, but it seems to consistently fail only with LangevinMiddleIntegrator.
Both the error and the warning suggest that different code is being captured in the last case. I do not see how that can happen here...
The previous test returns 0 kJ/mol; maybe a CUDA error is set there and the next case is just picking it up.

@raimis (Collaborator, Author) commented Sep 22, 2023

It happens even if you run it alone.

@RaulPPelaez (Collaborator) commented
This is fixed on my machine by adding a synchronization after graph replay in OpenMM-Torch, here:
https://github.com/openmm/openmm-torch/blob/2270256fed21c5ed045cc21b560931b171f5c863/platforms/cuda/src/CudaTorchKernels.cpp#L239-L241

I am guessing the particular integrator tries to run something that uses the forces/energies on another stream too soon, which causes a race condition.
Let me see how I can reduce the impact of this sync, and I will open a PR.
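As an illustration of the pattern being described (the actual fix is in the C++ CudaTorchKernels.cpp linked above), here is a hedged PyTorch-level sketch of capturing, replaying and then synchronizing a CUDA graph; the toy linear model and tensor shapes are illustrative only:

import torch as pt

model = pt.nn.Linear(3, 1).cuda()
static_in = pt.randn(5, 3, device="cuda")
static_out = pt.empty(5, 1, device="cuda")

# Warm up on a side stream before capture, as recommended by PyTorch.
side = pt.cuda.Stream()
side.wait_stream(pt.cuda.current_stream())
with pt.cuda.stream(side):
    for _ in range(3):
        static_out.copy_(model(static_in))
pt.cuda.current_stream().wait_stream(side)

# Capture the forward pass into a CUDA graph using static input/output buffers.
graph = pt.cuda.CUDAGraph()
with pt.cuda.graph(graph):
    static_out.copy_(model(static_in))

# Replay with fresh data; the synchronize() plays the role of the sync added
# after graph replay in OpenMM-Torch, so nothing running on another stream can
# read the outputs before the replayed work has finished.
static_in.copy_(pt.randn(5, 3, device="cuda"))
graph.replay()
pt.cuda.synchronize()
print(static_out)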

@raimis (Collaborator, Author) commented Sep 27, 2023

@RaulPPelaez thank you for your effort! I confirm the issue is fixed. The fix will be released with OpenMM-Torch 1.4.

raimis closed this as completed on Sep 27, 2023.