
Making TorchForce CUDA-graph aware #103

Merged Apr 21, 2023 · 47 commits · Changes from 38 commits
732e36d
Add CUDA graph draft
RaulPPelaez Mar 28, 2023
b1106e8
Initialize energy and force tensors in the GPU.
RaulPPelaez Mar 29, 2023
b972b5b
Add comment on graph capture
RaulPPelaez Mar 29, 2023
2f9238e
Catch torch exception if the model fails to capture.
RaulPPelaez Mar 29, 2023
6ff7f1a
Replay graph just after construction
RaulPPelaez Mar 29, 2023
84a5460
Add python-side test script for CUDA graphs
RaulPPelaez Mar 29, 2023
6e8c873
Implement properties
Feb 25, 2022
a82e77d
Update the Python bindings
Feb 25, 2022
91cf545
Unify the API for properties
Apr 12, 2022
e81dad6
Pass the propery map to the constructor
Apr 12, 2022
efc2589
Skip graph tests if no GPU is present
RaulPPelaez Mar 29, 2023
389862d
Guard CUDA graph behavior with the CUDA_GRAPH_ENABLE macro
RaulPPelaez Mar 29, 2023
f200e43
Check validity of the useCUDAGraphs property
RaulPPelaez Mar 29, 2023
cd6abc6
Add missing bracket to openmmtorch.i
RaulPPelaez Mar 29, 2023
af6b7b8
Fix bug in useCUDAgraph selection
RaulPPelaez Mar 29, 2023
7f65b04
Update tests
RaulPPelaez Mar 29, 2023
3c31110
Add test for get/setProperty
RaulPPelaez Mar 29, 2023
2c7309c
Update documentation with new functionality
RaulPPelaez Mar 29, 2023
87640b1
Add a CUDA graph test for a model that returns only energy
RaulPPelaez Mar 29, 2023
c4746dc
Add contributors
RaulPPelaez Mar 30, 2023
f80b797
Reset pos grads after graph capture. Make energy and force tensors pe…
RaulPPelaez Mar 30, 2023
ea6bc4a
Add tests that execute the model many times to catch bugs related with
RaulPPelaez Mar 30, 2023
331ba31
Run formatter
RaulPPelaez Mar 30, 2023
1378d1a
Warmup model for several steps
RaulPPelaez Mar 30, 2023
eabc60e
Include gradient reset into the graph
RaulPPelaez Mar 30, 2023
afef58d
Do not reset energy and force tensors before graph capture
RaulPPelaez Mar 30, 2023
c73ee4d
Remove unnecessary line
RaulPPelaez Mar 30, 2023
eab94a8
Add tests for larger number of particles
RaulPPelaez Mar 30, 2023
c94711f
Remove unnecessary compilation guard now that Pytorch 1.10 is not sup…
RaulPPelaez Apr 11, 2023
edd37a9
Simplify getTensorPointer now that Pytorch 1.7 is not supported
RaulPPelaez Apr 11, 2023
6de2c5b
Change addForcesToOpenMM to addForces
RaulPPelaez Apr 11, 2023
19594e3
Change execute_graph to executeGraph
RaulPPelaez Apr 11, 2023
2f6e88a
Wrap graph warming up in a try/catch block
RaulPPelaez Apr 11, 2023
d20f4bf
Add correctness test for modules that only provide energy
RaulPPelaez Apr 11, 2023
7ccd5a3
Revert "Add correctness test for modules that only provide energy"
RaulPPelaez Apr 11, 2023
71af9be
Merge remote-tracking branch 'origin/master' into cuda_graphs_raul
RaulPPelaez Apr 14, 2023
3d88c26
Explicit conversion to correct type in getTensorPointer
RaulPPelaez Apr 14, 2023
e2bb3a0
Added a new property for TorchForce, CUDAGraphWarmupSteps.
RaulPPelaez Apr 14, 2023
0e04b63
Clarify docs
RaulPPelaez Apr 17, 2023
54b20fa
Document properties
RaulPPelaez Apr 17, 2023
f800784
Throw if requested property does not exist
RaulPPelaez Apr 17, 2023
0cc936a
Change getProperty(string) to getProperties()
RaulPPelaez Apr 17, 2023
40c8fef
Add getProperties to python wrappers
RaulPPelaez Apr 17, 2023
d82eb53
Fix formatting
RaulPPelaez Apr 17, 2023
34a555a
Set default properties
RaulPPelaez Apr 17, 2023
8e6949c
Update tests
RaulPPelaez Apr 17, 2023
2be7fd9
Update some comments
RaulPPelaez Apr 20, 2023
14 changes: 14 additions & 0 deletions README.md
@@ -252,6 +252,20 @@ to return forces.
torch_force.setOutputsForces(True)
```

Recording the model into a CUDA graph
-------------------------------------

You can ask `TorchForce` to run the model using [CUDA graphs](https://pytorch.org/docs/stable/notes/cuda.html#cuda-graphs). Not every model is compatible with this feature, but it can provide a significant performance boost for some models. To enable it, the CUDA platform must be used and a special property must be provided to `TorchForce`:

```python
torch_force.setProperty("useCUDAGraphs", "true")
# The property can also be set at construction time
torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'})
```

The first time the model is run, it is captured into a CUDA graph. Subsequent runs replay the captured graph, which can be significantly faster. Capture may fail, in which case an `OpenMMException` is raised; if that happens, you can disable CUDA graphs and try again.
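The disable-and-retry fallback can be wrapped in a small helper. A minimal sketch, where `OpenMMException` and `run_with_properties` are hypothetical stand-ins for `openmm.OpenMMException` and for building and running a `TorchForce`-backed simulation with a given property map:

```python
class OpenMMException(Exception):
    """Stand-in for openmm.OpenMMException."""

def run_with_properties(properties):
    # Hypothetical model: pretend graph capture always fails here.
    if properties.get("useCUDAGraphs") == "true":
        raise OpenMMException("Failed to capture the model into a CUDA graph")
    return "ran eagerly"

def run_with_graph_fallback():
    # Try the CUDA-graph fast path first; fall back to eager execution.
    try:
        return run_with_properties({"useCUDAGraphs": "true"})
    except OpenMMException:
        return run_with_properties({"useCUDAGraphs": "false"})

print(run_with_graph_fallback())  # -> ran eagerly
```

In a real script the `except` branch would rebuild the `TorchForce` with `{'useCUDAGraphs': 'false'}` and rerun the simulation step.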


License
=======

28 changes: 24 additions & 4 deletions openmmapi/include/TorchForce.h
@@ -11,7 +11,7 @@
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
@@ -34,6 +34,7 @@

#include "openmm/Context.h"
#include "openmm/Force.h"
#include <map>
#include <string>
#include <torch/torch.h>
#include "internal/windowsExportTorch.h"
@@ -54,17 +55,20 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* Create a TorchForce. The network is defined by a PyTorch ScriptModule saved
* to a file.
*
* @param file the path to the file containing the network
* @param file the path to the file containing the network
* @param properties the property map
*/
TorchForce(const std::string& file);
TorchForce(const std::string& file,
const std::map<std::string, std::string>& properties = {});
/**
* Create a TorchForce. The network is defined by a PyTorch ScriptModule
* Note that this constructor makes a copy of the provided module.
* Any changes to the module after calling this constructor will be ignored by TorchForce.
*
* @param module an instance of the torch module
* @param properties the property map
*/
TorchForce(const torch::jit::Module &module);
TorchForce(const torch::jit::Module &module, const std::map<std::string, std::string>& properties = {});
/**
* Get the path to the file containing the network.
* If the TorchForce instance was constructed with a module, instead of a filename,
@@ -140,6 +144,20 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @param defaultValue the default value of the parameter
*/
void setGlobalParameterDefaultValue(int index, double defaultValue);
/**
* Set a value of a property.
*
* @param name the name of the property
* @param value the value of the property
*/
void setProperty(const std::string& name, const std::string& value);
/**
* Get a value of a property.
*
* @param name the name of the property
* @return the value of the property
*/
const std::string& getProperty(const std::string& name) const;
protected:
OpenMM::ForceImpl* createImpl() const;
private:
@@ -148,6 +166,8 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
bool usePeriodic, outputsForces;
std::vector<GlobalParameterInfo> globalParameters;
torch::jit::Module module;
std::map<std::string, std::string> properties;
std::string emptyProperty;
};

/**
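The property-map semantics added to `TorchForce` in this revision can be sketched in a few lines of Python: `setProperty` stores a string value, and `getProperty` returns the stored value or an empty string when the property was never set (the `emptyProperty` fallback). `PropertyMap` here is a hypothetical illustration, not the plugin's API:

```python
class PropertyMap:
    def __init__(self, properties=None):
        self._props = dict(properties or {})  # copied, like the C++ constructor

    def set_property(self, name, value):
        self._props[name] = value

    def get_property(self, name):
        return self._props.get(name, "")  # empty string for unknown names

props = PropertyMap({"useCUDAGraphs": "true"})
props.set_property("CUDAGraphWarmupSteps", "10")
print(props.get_property("useCUDAGraphs"))  # -> true
print(props.get_property("doesNotExist"))   # -> (empty string)
```

Note that a later commit in this PR changes this behavior to throw when a requested property does not exist.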
16 changes: 13 additions & 3 deletions openmmapi/src/TorchForce.cpp
@@ -8,7 +8,7 @@
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
@@ -41,10 +41,10 @@ using namespace TorchPlugin;
using namespace OpenMM;
using namespace std;

TorchForce::TorchForce(const torch::jit::Module& module) : file(), usePeriodic(false), outputsForces(false), module(module) {
TorchForce::TorchForce(const torch::jit::Module& module, const map<string, string>& properties) : file(), usePeriodic(false), outputsForces(false), module(module), properties(properties) {
}

TorchForce::TorchForce(const std::string& file) : TorchForce(torch::jit::load(file)) {
TorchForce::TorchForce(const std::string& file, const map<string, string>& properties) : TorchForce(torch::jit::load(file), properties) {
this->file = file;
}

@@ -104,3 +104,13 @@ void TorchForce::setGlobalParameterDefaultValue(int index, double defaultValue)
ASSERT_VALID_INDEX(index, globalParameters);
globalParameters[index].defaultValue = defaultValue;
}

void TorchForce::setProperty(const std::string& name, const std::string& value) {
properties[name] = value;
}

const std::string& TorchForce::getProperty(const std::string& name) const {
if (properties.find(name) != properties.end())
return properties.at(name);
return emptyProperty;
}
202 changes: 131 additions & 71 deletions platforms/cuda/src/CudaTorchKernels.cpp
@@ -8,7 +8,7 @@
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
@@ -35,22 +35,23 @@
#include "openmm/internal/ContextImpl.h"
#include <map>
#include <cuda_runtime_api.h>

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
using namespace TorchPlugin;
using namespace OpenMM;
using namespace std;

// macro for checking the result of synchronization operation on CUDA
// copied from `openmm/platforms/cuda/src/CudaParallelKernels.cpp`
#define CHECK_RESULT(result, prefix) \
if (result != CUDA_SUCCESS) { \
std::stringstream m; \
m<<prefix<<": "<<cu.getErrorString(result)<<" ("<<result<<")"<<" at "<<__FILE__<<":"<<__LINE__; \
throw OpenMMException(m.str());\
}
#define CHECK_RESULT(result, prefix) \
if (result != CUDA_SUCCESS) { \
std::stringstream m; \
m << prefix << ": " << cu.getErrorString(result) << " (" << result << ")" \
<< " at " << __FILE__ << ":" << __LINE__; \
throw OpenMMException(m.str()); \
}

CudaCalcTorchForceKernel::CudaCalcTorchForceKernel(string name, const Platform& platform, CudaContext& cu) :
CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) {
CudaCalcTorchForceKernel::CudaCalcTorchForceKernel(string name, const Platform& platform, CudaContext& cu) : CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) {
// Explicitly activate the primary context
CHECK_RESULT(cuDevicePrimaryCtxRetain(&primaryContext, cu.getDevice()), "Failed to retain the primary context");
}
@@ -75,12 +76,11 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
// Initialize CUDA objects for PyTorch
const torch::Device device(torch::kCUDA, cu.getDeviceIndex()); // This implicitly initialize PyTorch
module.to(device);
torch::TensorOptions options = torch::TensorOptions()
.device(device)
.dtype(cu.getUseDoublePrecision() ? torch::kFloat64 : torch::kFloat32);
torch::TensorOptions options = torch::TensorOptions().device(device).dtype(cu.getUseDoublePrecision() ? torch::kFloat64 : torch::kFloat32);
posTensor = torch::empty({numParticles, 3}, options.requires_grad(!outputsForces));
boxTensor = torch::empty({3, 3}, options);

energyTensor = torch::empty({0}, options);
forceTensor = torch::empty({0}, options);
// Pop the PyToch context
CUcontext ctx;
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
@@ -92,96 +92,156 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
CUmodule program = cu.createModule(CudaTorchKernelSources::torchForce, defines);
copyInputsKernel = cu.getKernel(program, "copyInputs");
addForcesKernel = cu.getKernel(program, "addForces");
const std::string useCUDAGraphsString = force.getProperty("useCUDAGraphs");
if (useCUDAGraphsString == "true")
useGraphs = true;
else if (useCUDAGraphsString == "false" || useCUDAGraphsString == "")
useGraphs = false;
else
throw OpenMMException("TorchForce: invalid value of \"useCUDAGraphs\"");
this->warmupSteps = 1;
if (useGraphs) {
const std::string warmupStepsString = force.getProperty("CUDAGraphWarmupSteps");
if (!warmupStepsString.empty())
this->warmupSteps = std::stoi(warmupStepsString);
}
}

double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
int numParticles = cu.getNumAtoms();

// Push to the PyTorch context
CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");

// Get pointers to the atomic positions and simulation box
void* posData;
void* boxData;
/**
* Get a pointer to the data in a PyTorch tensor.
* The tensor is converted to the correct data type if necessary.
*/
static void* getTensorPointer(OpenMM::CudaContext& cu, torch::Tensor& tensor) {
void* data;
if (cu.getUseDoublePrecision()) {
posData = posTensor.data_ptr<double>();
boxData = boxTensor.data_ptr<double>();
}
else {
posData = posTensor.data_ptr<float>();
boxData = boxTensor.data_ptr<float>();
data = tensor.to(torch::kFloat64).data_ptr<double>();
} else {
data = tensor.to(torch::kFloat32).data_ptr<float>();
}
return data;
}

/**
* Prepare the inputs for the PyTorch model, copying positions from the OpenMM context.
*/
std::vector<torch::jit::IValue> CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context) {
int numParticles = cu.getNumAtoms();
// Get pointers to the atomic positions and simulation box
void* posData = getTensorPointer(cu, posTensor);
void* boxData = getTensorPointer(cu, boxTensor);
// Copy the atomic positions and simulation box to PyTorch tensors
{
ContextSelector selector(cu); // Switch to the OpenMM context
void* inputArgs[] = {&posData, &boxData, &cu.getPosq().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(),
&numParticles, cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer()};
void* inputArgs[] = {&posData,
&boxData,
&cu.getPosq().getDevicePointer(),
&cu.getAtomIndexArray().getDevicePointer(),
&numParticles,
cu.getPeriodicBoxVecXPointer(),
cu.getPeriodicBoxVecYPointer(),
cu.getPeriodicBoxVecZPointer()};
cu.executeKernel(copyInputsKernel, inputArgs, numParticles);
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
}

// Prepare the input of the PyTorch model
vector<torch::jit::IValue> inputs = {posTensor};
if (usePeriodic)
inputs.push_back(boxTensor);
for (const string& name : globalNames)
inputs.push_back(torch::tensor(context.getParameter(name)));
return inputs;
}

// Execute the PyTorch model
torch::Tensor energyTensor, forceTensor;
/**
* Add the computed forces to the total atomic forces.
*/
void CudaCalcTorchForceKernel::addForces(torch::Tensor& forceTensor) {
int numParticles = cu.getNumAtoms();
// Get a pointer to the computed forces
void* forceData = getTensorPointer(cu, forceTensor);
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the OpenMM context
// Add the computed forces to the total atomic forces
{
ContextSelector selector(cu); // Switch to the OpenMM context
int paddedNumAtoms = cu.getPaddedNumAtoms();
int forceSign = (outputsForces ? 1 : -1);
void* forceArgs[] = {&forceData, &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, &paddedNumAtoms, &forceSign};
cu.executeKernel(addForcesKernel, forceArgs, numParticles);
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
}
}
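The `forceSign` convention in `addForces` can be illustrated in isolation: when the module outputs forces directly they are added as-is (sign +1), but when forces are obtained by backpropagating the energy, the gradient dE/dx is the *negative* of the force, hence sign -1. A pure-Python sketch with a hypothetical `add_forces` helper and an analytic harmonic-energy gradient standing in for the CUDA kernel:

```python
def add_forces(total, model_output, outputs_forces):
    # Mirror of the forceSign logic: +1 for direct forces, -1 for gradients.
    sign = 1 if outputs_forces else -1
    return [t + sign * f for t, f in zip(total, model_output)]

# Harmonic energy E(x) = 0.5 * sum(x**2), so dE/dx = x and force = -x.
positions = [1.0, -2.0, 0.5]
grad = positions[:]  # analytic dE/dx for this energy
forces = add_forces([0.0, 0.0, 0.0], grad, outputs_forces=False)
print(forces)  # -> [-1.0, 2.0, -0.5]
```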

/**
* This function launches the workload in a way compatible with CUDA
* graphs as far as OpenMM-Torch goes. Capturing this function when
* the model is not itself graph compatible (due to, for instance,
* implicit synchronizations) will result in a CUDA error.
*/
static void executeGraph(bool outputsForces, bool includeForces, torch::jit::script::Module& module, vector<torch::jit::IValue>& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor,
torch::Tensor& forceTensor) {
if (outputsForces) {
auto outputs = module.forward(inputs).toTuple();
energyTensor = outputs->elements()[0].toTensor();
forceTensor = outputs->elements()[1].toTensor();
}
else
} else {
energyTensor = module.forward(inputs).toTensor();

if (includeForces) {

// Compute force by backprogating the PyTorch model
if (!outputsForces) {
// Compute force by backpropagating the PyTorch model
if (includeForces) {
energyTensor.backward();
forceTensor = posTensor.grad();
}

// Get a pointer to the computed forces
void* forceData;
if (cu.getUseDoublePrecision()) {
if (!(forceTensor.dtype() == torch::kFloat64)) // TODO: simplify the logic when support for PyTorch 1.7 is dropped
forceTensor = forceTensor.to(torch::kFloat64);
forceData = forceTensor.data_ptr<double>();
}
else {
if (!(forceTensor.dtype() == torch::kFloat32)) // TODO: simplify the logic when support for PyTorch 1.7 is dropped
forceTensor = forceTensor.to(torch::kFloat32);
forceData = forceTensor.data_ptr<float>();
forceTensor = posTensor.grad().clone();
// Zero the gradient to avoid accumulating it
posTensor.grad().zero_();
}
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the OpenMM context
}
}

// Add the computed forces to the total atomic forces
{
ContextSelector selector(cu); // Switch to the OpenMM context
int paddedNumAtoms = cu.getPaddedNumAtoms();
int forceSign = (outputsForces ? 1 : -1);
void* forceArgs[] = {&forceData, &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, &paddedNumAtoms, &forceSign};
cu.executeKernel(addForcesKernel, forceArgs, numParticles);
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
// Push to the PyTorch context
CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");
auto inputs = prepareTorchInputs(context);
if (!useGraphs) {
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
} else {
const auto stream = c10::cuda::getStreamFromPool(false, posTensor.get_device());
const c10::cuda::CUDAStreamGuard guard(stream);
// Record graph if not already done
bool is_graph_captured = false;
if (graphs.find(includeForces) == graphs.end()) {
// Warmup the graph workload before capturing. This first
// run sets up allocations so that none are needed during
// capture. PyTorch's allocator is stream-capture aware and,
// after warmup, will record static pointers and shapes
// during capture.
try {
for (int i = 0; i < this->warmupSteps; i++)
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
}
catch (std::exception& e) {
throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. Torch reported the following error:\n") + e.what());
}
graphs[includeForces].capture_begin();
try {
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
is_graph_captured = true;
graphs[includeForces].capture_end();
}
catch (std::exception& e) {
if (!is_graph_captured) {
graphs[includeForces].capture_end();
}
throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
}
}

// Reset the forces
if (!outputsForces)
posTensor.grad().zero_();
graphs[includeForces].replay();
}
if (includeForces) {
addForces(forceTensor);
}

// Get energy
const double energy = energyTensor.item<double>(); // This implicitly synchronizes the PyTorch context

// Pop to the PyTorch context
CUcontext ctx;
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
assert(primaryContext == ctx); // Check that the correct context was popped

return energy;
}
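The capture-or-replay control flow in `CudaCalcTorchForceKernel::execute` can be modeled without a GPU: one graph is cached per value of `includeForces`, the model is warmed up for `warmupSteps` eager runs before the first capture, and every call (including the capturing one) ends with a replay. A pure-Python sketch with counters standing in for the real CUDA work (`GraphCache` is hypothetical):

```python
class GraphCache:
    def __init__(self, warmup_steps=1):
        self.warmup_steps = warmup_steps
        self.graphs = {}      # one graph per value of includeForces, like the C++ map
        self.eager_runs = 0   # eager model executions (warmup + capture)
        self.replays = 0      # graph replays

    def execute(self, include_forces):
        if include_forces not in self.graphs:
            # Warm up before capture so all allocations are already in place.
            for _ in range(self.warmup_steps):
                self.eager_runs += 1
            # Capture: the model runs once more between capture_begin/capture_end.
            self.eager_runs += 1
            self.graphs[include_forces] = object()  # stand-in for a CUDAGraph
        # Every call, including the first, replays the captured graph.
        self.replays += 1

kernel = GraphCache(warmup_steps=10)
kernel.execute(True)   # warmup + capture, then replay
kernel.execute(True)   # replay only
kernel.execute(False)  # a second graph for the energy-only path
print(kernel.eager_runs, kernel.replays)  # -> 22 3
```

This mirrors why the PR keys the graph map on `includeForces`: the energy-only and energy-plus-forces workloads record different graphs.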