Making TorchForce CUDA-graph aware #103

Merged
merged 47 commits into from
Apr 21, 2023
Changes from 28 commits
Commits
732e36d
Add CUDA graph draft
RaulPPelaez Mar 28, 2023
b1106e8
Initialize energy and force tensors in the GPU.
RaulPPelaez Mar 29, 2023
b972b5b
Add comment on graph capture
RaulPPelaez Mar 29, 2023
2f9238e
Catch torch exception if the model fails to capture.
RaulPPelaez Mar 29, 2023
6ff7f1a
Replay graph just after construction
RaulPPelaez Mar 29, 2023
84a5460
Add python-side test script for CUDA graphs
RaulPPelaez Mar 29, 2023
6e8c873
Implement properties
Feb 25, 2022
a82e77d
Update the Python bindings
Feb 25, 2022
91cf545
Unify the API for properties
Apr 12, 2022
e81dad6
Pass the propery map to the constructor
Apr 12, 2022
efc2589
Skip graph tests if no GPU is present
RaulPPelaez Mar 29, 2023
389862d
Guard CUDA graph behavior with the CUDA_GRAPH_ENABLE macro
RaulPPelaez Mar 29, 2023
f200e43
Check validity of the useCUDAGraphs property
RaulPPelaez Mar 29, 2023
cd6abc6
Add missing bracket to openmmtorch.i
RaulPPelaez Mar 29, 2023
af6b7b8
Fix bug in useCUDAgraph selection
RaulPPelaez Mar 29, 2023
7f65b04
Update tests
RaulPPelaez Mar 29, 2023
3c31110
Add test for get/setProperty
RaulPPelaez Mar 29, 2023
2c7309c
Update documentation with new functionality
RaulPPelaez Mar 29, 2023
87640b1
Add a CUDA graph test for a model that returns only energy
RaulPPelaez Mar 29, 2023
c4746dc
Add contributors
RaulPPelaez Mar 30, 2023
f80b797
Reset pos grads after graph capture. Make energy and force tensors pe…
RaulPPelaez Mar 30, 2023
ea6bc4a
Add tests that execute the model many times to catch bugs related with
RaulPPelaez Mar 30, 2023
331ba31
Run formatter
RaulPPelaez Mar 30, 2023
1378d1a
Warmup model for several steps
RaulPPelaez Mar 30, 2023
eabc60e
Include gradient reset into the graph
RaulPPelaez Mar 30, 2023
afef58d
Do not reset energy and force tensors before graph capture
RaulPPelaez Mar 30, 2023
c73ee4d
Remove unnecessary line
RaulPPelaez Mar 30, 2023
eab94a8
Add tests for larger number of particles
RaulPPelaez Mar 30, 2023
c94711f
Remove unnecessary compilation guard now that Pytorch 1.10 is not sup…
RaulPPelaez Apr 11, 2023
edd37a9
Simplify getTensorPointer now that Pytorch 1.7 is not supported
RaulPPelaez Apr 11, 2023
6de2c5b
Change addForcesToOpenMM to addForces
RaulPPelaez Apr 11, 2023
19594e3
Change execute_graph to executeGraph
RaulPPelaez Apr 11, 2023
2f6e88a
Wrap graph warming up in a try/catch block
RaulPPelaez Apr 11, 2023
d20f4bf
Add correctness test for modules that only provide energy
RaulPPelaez Apr 11, 2023
7ccd5a3
Revert "Add correctness test for modules that only provide energy"
RaulPPelaez Apr 11, 2023
71af9be
Merge remote-tracking branch 'origin/master' into cuda_graphs_raul
RaulPPelaez Apr 14, 2023
3d88c26
Explicit conversion to correct type in getTensorPointer
RaulPPelaez Apr 14, 2023
e2bb3a0
Added a new property for TorchForce, CUDAGraphWarmupSteps.
RaulPPelaez Apr 14, 2023
0e04b63
Clarify docs
RaulPPelaez Apr 17, 2023
54b20fa
Document properties
RaulPPelaez Apr 17, 2023
f800784
Throw if requested property does not exist
RaulPPelaez Apr 17, 2023
0cc936a
Change getProperty(string) to getProperties()
RaulPPelaez Apr 17, 2023
40c8fef
Add getProperties to python wrappers
RaulPPelaez Apr 17, 2023
d82eb53
Fix formatting
RaulPPelaez Apr 17, 2023
34a555a
Set default properties
RaulPPelaez Apr 17, 2023
8e6949c
Update tests
RaulPPelaez Apr 17, 2023
2be7fd9
Update some comments
RaulPPelaez Apr 20, 2023
14 changes: 14 additions & 0 deletions README.md
@@ -252,6 +252,20 @@ to return forces.
torch_force.setOutputsForces(True)
```

Recording the model into a CUDA graph
-------------------------------------

You can ask `TorchForce` to run the model using [CUDA graphs](https://pytorch.org/docs/stable/notes/cuda.html#cuda-graphs). Not every model is compatible with this feature, but it can give a significant performance boost to those that are. To enable it, use the CUDA platform and set a special property on `TorchForce`:

```python
torch_force.setProperty("useCUDAGraphs", "true")
# The property can also be set at construction
torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'})
```

The first time the model is run, it will be compiled into a CUDA graph. Subsequent runs will use the compiled graph, which can be significantly faster. It is possible that compilation fails, in which case an `OpenMMException` will be raised. If that happens, you can disable CUDA graphs and try again.
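
The fallback described above could be sketched roughly as follows. This is only an illustration: `run_with_graph_fallback` and `build_and_run` are hypothetical names that are not part of the plugin, and the sketch assumes that properties are read when a `Context` is created, so `build_and_run` is expected to build a fresh `Context` on the CUDA platform each time it is called. In a real script `error_type` would be the `OpenMMException` mentioned above.

```python
def run_with_graph_fallback(force, build_and_run, error_type=Exception):
    """Evaluate once with CUDA graphs requested, retrying without them on failure.

    force         -- object exposing setProperty(name, value), e.g. a TorchForce
    build_and_run -- hypothetical callable that creates a new Context containing
                     `force`, evaluates it, and returns the result
    error_type    -- exception class treated as a graph capture failure

    Returns a (result, used_graphs) tuple.
    """
    force.setProperty("useCUDAGraphs", "true")
    try:
        return build_and_run(), True
    except error_type:
        # Capture failed: switch the feature off and evaluate again without graphs.
        force.setProperty("useCUDAGraphs", "false")
        return build_and_run(), False
```

The second element of the returned tuple tells you whether the graph path was actually used, which is handy when benchmarking.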


License
=======

28 changes: 24 additions & 4 deletions openmmapi/include/TorchForce.h
@@ -11,7 +11,7 @@
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
@@ -34,6 +34,7 @@

#include "openmm/Context.h"
#include "openmm/Force.h"
#include <map>
#include <string>
#include <torch/torch.h>
#include "internal/windowsExportTorch.h"
@@ -54,17 +55,20 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* Create a TorchForce. The network is defined by a PyTorch ScriptModule saved
* to a file.
*
* @param file the path to the file containing the network
* @param file the path to the file containing the network
* @param properties the property map
*/
TorchForce(const std::string& file);
TorchForce(const std::string& file,
const std::map<std::string, std::string>& properties = {});
/**
* Create a TorchForce. The network is defined by a PyTorch ScriptModule
* Note that this constructor makes a copy of the provided module.
* Any changes to the module after calling this constructor will be ignored by TorchForce.
*
* @param module an instance of the torch module
* @param properties the property map
*/
TorchForce(const torch::jit::Module &module);
TorchForce(const torch::jit::Module &module, const std::map<std::string, std::string>& properties = {});
/**
* Get the path to the file containing the network.
* If the TorchForce instance was constructed with a module, instead of a filename,
@@ -140,6 +144,20 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @param defaultValue the default value of the parameter
*/
void setGlobalParameterDefaultValue(int index, double defaultValue);
/**
* Set a value of a property.
*
* @param name the name of the property
* @param value the value of the property
*/
void setProperty(const std::string& name, const std::string& value);
/**
* Get a value of a property.
*
* @param name the name of the property
* @return the value of the property
*/
const std::string& getProperty(const std::string& name) const;
protected:
OpenMM::ForceImpl* createImpl() const;
private:
@@ -148,6 +166,8 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
bool usePeriodic, outputsForces;
std::vector<GlobalParameterInfo> globalParameters;
torch::jit::Module module;
std::map<std::string, std::string> properties;
std::string emptyProperty;
};

/**
16 changes: 13 additions & 3 deletions openmmapi/src/TorchForce.cpp
@@ -8,7 +8,7 @@
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
@@ -41,10 +41,10 @@ using namespace TorchPlugin;
using namespace OpenMM;
using namespace std;

TorchForce::TorchForce(const torch::jit::Module& module) : file(), usePeriodic(false), outputsForces(false), module(module) {
TorchForce::TorchForce(const torch::jit::Module& module, const map<string, string>& properties) : file(), usePeriodic(false), outputsForces(false), module(module), properties(properties) {
}

TorchForce::TorchForce(const std::string& file) : TorchForce(torch::jit::load(file)) {
TorchForce::TorchForce(const std::string& file, const map<string, string>& properties) : TorchForce(torch::jit::load(file), properties) {
this->file = file;
}

@@ -104,3 +104,13 @@ void TorchForce::setGlobalParameterDefaultValue(int index, double defaultValue)
ASSERT_VALID_INDEX(index, globalParameters);
globalParameters[index].defaultValue = defaultValue;
}

void TorchForce::setProperty(const std::string& name, const std::string& value) {
properties[name] = value;
}

const std::string& TorchForce::getProperty(const std::string& name) const {
if (properties.find(name) != properties.end())
return properties.at(name);
return emptyProperty;
}