Making TorchForce CUDA-graph aware #103
Conversation
This is a MWE:

```python
import openmmtorch as ot
import torch
import openmm as mm
import numpy as np

class ForceModule(torch.nn.Module):
    def forward(self, positions):
        # return (torch.sum(torch.norm(positions, dim=1)), -2*positions)
        return (torch.sum(positions**2), -2*positions)

module = torch.jit.script(ForceModule())
torch_force = ot.TorchForce(module)
torch_force.setOutputsForces(True)
numParticles = 10
system = mm.System()
positions = np.random.rand(numParticles, 3)
for _ in range(numParticles):
    system.addParticle(1.0)
system.addForce(torch_force)
integ = mm.VerletIntegrator(1.0)
platform = mm.Platform.getPlatformByName('CUDA')
context = mm.Context(system, integ, platform)
context.setPositions(positions)
state = context.getState(getEnergy=True, getForces=True)
```

This is the exception printed if one tries to capture this module:

```
[W manager.cpp:329] Warning: FALLBACK path has been taken inside: runCudaFusionGroup. This is an indication that codegen Failed for some reason.
To debug try disable codegen fallback path via setting the env variable `export PYTORCH_NVFUSER_DISABLE=fallback`
(function runCudaFusionGroup)
Traceback (most recent call last):
File "/shared/raul/openmm-torch/python/tests/graph.py", line 48, in <module>
state = context.getState(getEnergy=True, getForces=True)
File "/shared/raul/mambaforge/envs/openmmtorchtest/lib/python3.9/site-packages/openmm/openmm.py", line 10009, in getState
state = _openmm.Context_getState(self, types, enforcePeriodicBox, groups_mask)
openmm.OpenMMException: TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:
The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
File "code/__torch__.py", line 8, in fallback_cuda_fuser
def forward(self: __torch__.ForceModule,
positions: Tensor) -> Tuple[Tensor, Tensor]:
_0 = (torch.sum(torch.pow(positions, 2)), torch.mul(positions, -2))
~~~~~~~~~ <--- HERE
return _0
Traceback of TorchScript, original code (most recent call last):
File "/shared/raul/openmm-torch/python/tests/graph.py", line 24, in fallback_cuda_fuser
"""
#return (torch.sum(torch.norm(positions,dim=1)), -2*positions)
return (torch.sum(positions**2), -2*positions)
~~~~~~~~~~~~ <--- HERE
RuntimeError: status != cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated INTERNAL ASSERT FAILED at "/home/conda/feedstock_root/build_artifacts/pytorch-recipe_1664405705473/work/c10/cuda/CUDACachingAllocator.cpp":1082, please report a bug to PyTorch.
```

Using the commented line instead results in no error.
Finish capturing before rethrowing if an exception occurred during capture
Same as my comments in #101 (comment), we should consider the alternatives for how to enable CUDA graphs. Could you comment on what you see as the advantages and disadvantages of the possibilities listed there, and why you chose this approach?
Pros of the current implementation:
Cons:
I believe adding a Property system will be beneficial in the future. There are surely other functionalities we could implement for TorchForce that we would want to turn on and off. As for other ways to use graphs, I can think of the following:
As for the mechanism to enable the CUDA graph functionality, I can think of some alternatives to the Property system:
For instance, whether this model can be captured depends on the value of some_flag at runtime:

```python
class ForceModule(torch.nn.Module):
    def forward(self, positions, some_flag):
        factor = 1
        if some_flag is not None:
            factor = 10
            torch.cuda.synchronize()  # synchronization is not allowed during capture
        return (factor * torch.sum(torch.norm(positions, dim=1)), -2*positions)
```

But if I know some_flag is None during the lifetime of TorchForce, I can enable CUDA graphs on it.
This reverts commit d20f4bf.
@peastman any more comments?
This may be reviewed again, @peastman.
This is looking great! Just a few more very minor comments, and then it should be ready to merge.
Looks good to me! @raimis do you have any more comments before we merge it?
This PR is a continuation of the work started by @raimis in #68.
CUDA graphs provide a way to group several operations into a single "graph". The benefit is that, by giving CUDA a set of guarantees (mainly static shapes and memory addresses, no synchronization, and no cuda* calls), it can avoid some overhead (mainly the time spent preparing kernel launches).
CUDA graphs shine with workloads that consist of many small kernel launches put together.
In its most basic form, a CUDA graph is constructed by "capturing" a stream: you do a dry run of the workload, which must happen entirely in that stream, and CUDA records everything that happens in it into a graph. The graph can then be replayed as many times as needed.
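To make the capture/replay workflow concrete, here is a minimal standalone sketch using PyTorch's torch.cuda.CUDAGraph API (the tensor names and the toy workload are illustrative only, not code from this PR):

```python
import torch

# Static input: during replay the graph always reads from this same address.
static_positions = torch.randn(10, 3, device="cuda")

# Warm up on a side stream so one-time initializations are not recorded.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    static_energy = torch.sum(static_positions**2)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture: the kernels are recorded into a graph instead of running normally.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_energy = torch.sum(static_positions**2)

# Replay: the recorded kernels run again on whatever data currently sits at the
# same addresses, so new inputs are copied into the captured tensors in place.
static_positions.copy_(torch.randn(10, 3, device="cuda"))
graph.replay()
print(static_energy)  # updated by the replay
```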
For TorchForce, the most evident use of this is to capture the forward and backward calls of the model into a graph.
For that, this PR aims to introduce the following changes:
For this, the solution proposed by @raimis in Support CUDA Graphs (#68), adding a property system to TorchForce, seems the most sensible.
Main changes in this PR go into this function:
`openmm-torch/platforms/cuda/src/CudaTorchKernels.cpp`, line 97 (at commit 769302a),
which performs three main operations:
Following from #101, only step 2 is introduced into a graph. The other two operations are essentially two additional kernel launches; in principle, they could also be included in the graph.
Right now, there are several synchronization barriers between each step. Also, I do not know what kind of guarantees `ContextSelector` in OpenMM provides (i.e. does it involve synchronization? cudaSetDevice?); I would need guidance on this.

Apart from this, I added the cherry-picked commits from #68 that implement the functionality to provide "Properties" to TorchForce. This includes modifying the constructors of TorchForce to accept an optional dictionary of properties, and the addition of `setProperty` and `getProperty` members.
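As a sketch only of how this could look from Python (the property name "useCUDAGraphs" and the string values are assumptions for illustration, not necessarily the final names in this PR):

```python
import openmmtorch as ot
import torch

module = torch.jit.script(ForceModule())  # ForceModule as in the MWE above

# Pass properties at construction time via the optional dictionary...
torch_force = ot.TorchForce(module, {"useCUDAGraphs": "true"})

# ...or toggle them afterwards with the new members.
torch_force.setProperty("useCUDAGraphs", "true")
print(torch_force.getProperty("useCUDAGraphs"))
```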
Caveats
Ungraphable operations
Many apparently innocuous operations are not graphable. For instance, this model is OK:
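A sketch consistent with the MWE earlier in this thread, where the norm-based model captured without error (the exact snippet intended here is an assumption):

```python
import torch

class GraphableModule(torch.nn.Module):
    def forward(self, positions):
        # Norm-based energy, as in the commented-out line of the MWE above;
        # that variant captured without error.
        return (torch.sum(torch.norm(positions, dim=1)), -2*positions)
```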
While this one is not:
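Correspondingly, a sketch of the failing case from the same MWE, where the power-based energy triggered the nvfuser fallback and invalidated the capture:

```python
import torch

class UngraphableModule(torch.nn.Module):
    def forward(self, positions):
        # This variant failed inside runCudaFusionGroup during stream capture
        # in the MWE above.
        return (torch.sum(positions**2), -2*positions)
```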
Luckily, CUDA and Torch are quite informative about which line is the offending one. Torch throws an exception that is easy to catch and handle.