Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a constructor to TorchForce that takes a torch::jit::Module #97

Merged
merged 34 commits into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f344f82
Add version number as a member to TorchForceProxy
RaulPPelaez Jan 20, 2023
73b5222
Encode the model file contents when serializing TorchForce
RaulPPelaez Jan 20, 2023
66b2530
Add tests for new TorchForce serialization
RaulPPelaez Jan 20, 2023
74bf087
Fix test not finding Python executable
RaulPPelaez Jan 23, 2023
85a6acd
Format include directives correctly
RaulPPelaez Jan 25, 2023
68cc189
Hardcode TorchForceProxy version number
RaulPPelaez Jan 25, 2023
49d5d6e
Fix formatting issues
RaulPPelaez Jan 25, 2023
20f4a7e
Move Python serialization test to the correct place
RaulPPelaez Jan 25, 2023
58dbbba
Make function encodeFromFileName static
RaulPPelaez Jan 25, 2023
612d383
Update serialization python test to correctly remove temporary files …
RaulPPelaez Jan 26, 2023
8fe7906
Use the base64 encoding capabilities of openssl to serialize model file
RaulPPelaez Jan 26, 2023
f996da9
Update TorchForce serializer
RaulPPelaez Jan 26, 2023
05d7764
Add a constructor to TorchForce that takes a torch::jit::Module.
RaulPPelaez Jan 30, 2023
f4e2b76
Remove unnecessary include
RaulPPelaez Jan 30, 2023
60a1ebe
Change i_file to file in TorchForce constructor
RaulPPelaez Jan 31, 2023
7c7b068
Add swig typemaps to new TorchForce constructor
RaulPPelaez Feb 1, 2023
8fae5eb
Add setup.py as a dependency for the PythonInstall CMake rule
RaulPPelaez Feb 3, 2023
ab0ef40
Fix swig out typemap for torch::jit::Module
RaulPPelaez Feb 3, 2023
a23672a
Remove commented line in CMakeLists.txt
RaulPPelaez Feb 6, 2023
1e4cae6
Remove unnecessary dependency in setup.py
RaulPPelaez Feb 6, 2023
103be5c
Add more tests for new constructor
RaulPPelaez Feb 6, 2023
068779b
Add some comments for the new constructor
RaulPPelaez Feb 6, 2023
5d54cbb
Merge branch 'serialization' into module_constructor
RaulPPelaez Feb 6, 2023
0ed5cee
Updates to TorchForce serialization
RaulPPelaez Feb 6, 2023
fecb6d0
Use hex encoding instead of base64 for serialization.
RaulPPelaez Feb 6, 2023
57191be
Remove unnecessary header
RaulPPelaez Feb 6, 2023
25bdfac
Update Python serialization test
RaulPPelaez Feb 6, 2023
8597c34
Merge remote-tracking branch 'origin/master' into module_constructor
RaulPPelaez Feb 7, 2023
ba554e8
Minor changes
RaulPPelaez Feb 7, 2023
51baa27
Improve temporary path handling in python serialization tests
RaulPPelaez Feb 8, 2023
abf43ff
More informative exception when failing to serialize TorchForce
RaulPPelaez Feb 8, 2023
6210b2b
Remove unnecessary check in TorchForce serialization
RaulPPelaez Feb 8, 2023
7c4cf66
Changes to C++ serialization tests
RaulPPelaez Feb 8, 2023
19749f0
Changes to C++ serialization tests
RaulPPelaez Feb 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ LINK_DIRECTORIES("${OPENMM_DIR}/lib" "${OPENMM_DIR}/lib/plugins")
SET(PYTORCH_DIR "" CACHE PATH "Where the PyTorch C++ API is installed")
SET(CMAKE_PREFIX_PATH "${PYTORCH_DIR}")
FIND_PACKAGE(Torch REQUIRED)
#LINK_DIRECTORIES("${TENSORFLOW_DIR}/lib")

# Specify the C++ version we are building for.
SET (CMAKE_CXX_STANDARD 14)
Expand Down
22 changes: 21 additions & 1 deletion openmmapi/include/TorchForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,18 @@
#include "openmm/Context.h"
#include "openmm/Force.h"
#include <string>
#include <torch/torch.h>
#include "internal/windowsExportTorch.h"

namespace TorchPlugin {

/**
* This class implements forces that are defined by user-supplied neural networks.
* It uses the PyTorch library to perform the computations. */
* It uses the PyTorch library to perform the computations.
* The PyTorch module can either be passed directly as an argument to
* the constructor or loaded from a file. In either case, the
* constructor makes a copy of the module in memory. Later changes to
* the original module or to the file do not affect it.*/

class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
public:
Expand All @@ -52,10 +57,24 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @param file the path to the file containing the network
*/
TorchForce(const std::string& file);
/**
* Create a TorchForce. The network is defined by a PyTorch ScriptModule
* Note that this constructor makes a copy of the provided module.
* Any changes to the module after calling this constructor will be ignored by TorchForce.
*
* @param module an instance of the torch module
*/
TorchForce(const torch::jit::Module &module);
/**
* Get the path to the file containing the network.
* If the TorchForce instance was constructed with a module, instead of a filename,
* this function returns an empty string.
*/
const std::string& getFile() const;
/**
* Get the torch module currently in use.
*/
const torch::jit::Module & getModule() const;
/**
* Set whether this force makes use of periodic boundary conditions. If this is set
* to true, the network must take a 3x3 tensor as its second input, which
Expand Down Expand Up @@ -128,6 +147,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
std::string file;
bool usePeriodic, outputsForces;
std::vector<GlobalParameterInfo> globalParameters;
torch::jit::Module module;
};

/**
Expand Down
1 change: 0 additions & 1 deletion openmmapi/include/internal/TorchForceImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class OPENMM_EXPORT_NN TorchForceImpl : public OpenMM::ForceImpl {
private:
const TorchForce& owner;
OpenMM::Kernel kernel;
torch::jit::script::Module module;
};

} // namespace TorchPlugin
Expand Down
12 changes: 11 additions & 1 deletion openmmapi/src/TorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,28 @@
#include "openmm/OpenMMException.h"
#include "openmm/internal/AssertionUtilities.h"
#include <fstream>
#include <torch/torch.h>
#include <torch/csrc/jit/serialization/import.h>

using namespace TorchPlugin;
using namespace OpenMM;
using namespace std;

TorchForce::TorchForce(const std::string& file) : file(file), usePeriodic(false), outputsForces(false) {
TorchForce::TorchForce(const torch::jit::Module& module) : file(), usePeriodic(false), outputsForces(false), module(module) {
}

TorchForce::TorchForce(const std::string& file) : TorchForce(torch::jit::load(file)) {
this->file = file;
}

const string& TorchForce::getFile() const {
    // Empty string when this force was constructed directly from a
    // torch::jit::Module rather than loaded from a file.
    return file;
}

const torch::jit::Module& TorchForce::getModule() const {
    // Returns the in-memory copy of the module held by this force;
    // later changes to the original module or file do not affect it.
    return this->module;
}

ForceImpl* TorchForce::createImpl() const {
    // Factory for the implementation object; presumably the OpenMM Context
    // takes ownership of the returned pointer (standard Force convention) --
    // confirm against the OpenMM ForceImpl documentation.
    return new TorchForceImpl(*this);
}
Expand Down
6 changes: 1 addition & 5 deletions openmmapi/src/TorchForceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,8 @@ TorchForceImpl::~TorchForceImpl() {
}

void TorchForceImpl::initialize(ContextImpl& context) {
// Load the module from the file.

module = torch::jit::load(owner.getFile());

auto module = owner.getModule();
// Create the kernel.

kernel = context.getPlatform().createKernel(CalcTorchForceKernel::Name(), context);
kernel.getAs<CalcTorchForceKernel>().initialize(context.getSystem(), owner, module);
}
Expand Down
13 changes: 8 additions & 5 deletions platforms/cuda/tests/TestCudaTorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
#include "sfmt/SFMT.h"
#include <cmath>
#include <iostream>
#include <torch/torch.h>
#include <torch/csrc/jit/serialization/import.h>
#include <vector>

using namespace TorchPlugin;
Expand All @@ -52,7 +54,7 @@ extern "C" OPENMM_EXPORT void registerTorchCudaKernelFactories();

void testForce(bool outputsForces) {
// Create a random cloud of particles.

const int numParticles = 10;
System system;
vector<Vec3> positions(numParticles);
Expand All @@ -62,20 +64,21 @@ void testForce(bool outputsForces) {
system.addParticle(1.0);
positions[i] = Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt))*10;
}
TorchForce* force = new TorchForce(outputsForces ? "tests/forces.pt" : "tests/central.pt");
auto model = torch::jit::load(outputsForces ? "tests/forces.pt" : "tests/central.pt");
TorchForce* force = new TorchForce(model);
force->setOutputsForces(outputsForces);
system.addForce(force);

// Compute the forces and energy.

VerletIntegrator integ(1.0);
Platform& platform = Platform::getPlatformByName("CUDA");
Context context(system, integ, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);

// See if the energy is correct. The network defines a potential of the form E(r) = |r|^2

double expectedEnergy = 0;
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
Expand Down
6 changes: 5 additions & 1 deletion python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,25 @@ set(WRAP_FILE TorchPluginWrapper.cpp)
set(MODULE_NAME openmmtorch)

# Execute SWIG to generate source code for the Python module.
foreach(dir ${TORCH_INCLUDE_DIRS})
set(torchincs "${torchincs}" "-I${dir}")
endforeach()

add_custom_command(
OUTPUT "${WRAP_FILE}"
COMMAND "${SWIG_EXECUTABLE}"
-python -c++
-o "${WRAP_FILE}"
"-I${OPENMM_DIR}/include"
${torchincs}
"${CMAKE_CURRENT_SOURCE_DIR}/openmmtorch.i"
DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/openmmtorch.i"
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
)

# Compile the Python module.

add_custom_target(PythonInstall DEPENDS "${WRAP_FILE}")
add_custom_target(PythonInstall DEPENDS "${WRAP_FILE}" "${CMAKE_CURRENT_SOURCE_DIR}/setup.py")
set(NN_PLUGIN_HEADER_DIR "${CMAKE_SOURCE_DIR}/openmmapi/include")
set(NN_PLUGIN_LIBRARY_DIR "${CMAKE_BINARY_DIR}")
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py ${CMAKE_CURRENT_BINARY_DIR}/setup.py)
Expand Down
22 changes: 22 additions & 0 deletions python/openmmtorch.i
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "OpenMMDrude.h"
#include "openmm/RPMDIntegrator.h"
#include "openmm/RPMDMonteCarloBarostat.h"
#include <torch/csrc/jit/python/module_python.h>
%}

/*
Expand All @@ -26,12 +27,33 @@
}
}

// Convert an incoming Python torch.jit.ScriptModule into a C++
// torch::jit::Module& argument for the wrapped API.
%typemap(in) const torch::jit::Module&(torch::jit::Module module) {
    py::object o = py::reinterpret_borrow<py::object>($input);
    // as_module() yields an optional; value() throws if o is not a script
    // module, but the typecheck typemap below filters those during overload
    // resolution.
    module = torch::jit::as_module(o).value();
    $1 = &module;
}

// Convert a returned torch::jit::Module back into a Python ScriptModule by
// round-tripping it through a temporary file on disk.
%typemap(out) const torch::jit::Module& {
    // NOTE(review): std::tmpnam is deprecated and racy (TOCTOU between name
    // generation and save); consider a safer temporary-file mechanism.
    auto fileName = std::tmpnam(nullptr);
    $1->save(fileName);
    $result = py::module::import("torch.jit").attr("load")(fileName).release().ptr();
    //This typemap assumes that torch does not require the file to exist after construction
    std::remove(fileName);
}

// During overload resolution, accept only objects torch recognizes as
// script modules for parameters of type const torch::jit::Module&.
%typecheck(SWIG_TYPECHECK_POINTER) const torch::jit::Module& {
    py::object o = py::reinterpret_borrow<py::object>($input);
    $1 = torch::jit::as_module(o).has_value() ? 1 : 0;
}

namespace TorchPlugin {

class TorchForce : public OpenMM::Force {
public:
TorchForce(const std::string& file);
TorchForce(const torch::jit::Module& module);
const std::string& getFile() const;
const torch::jit::Module& getModule() const;
void setUsesPeriodicBoundaryConditions(bool periodic);
bool usesPeriodicBoundaryConditions() const;
void setOutputsForces(bool);
Expand Down
8 changes: 4 additions & 4 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import platform

openmm_dir = '@OPENMM_DIR@'
torch_include_dirs = '@TORCH_INCLUDE_DIRS@'.split(';')
nn_plugin_header_dir = '@NN_PLUGIN_HEADER_DIR@'
nn_plugin_library_dir = '@NN_PLUGIN_LIBRARY_DIR@'
torch_dir, _ = os.path.split('@TORCH_LIBRARY@')

# setup extra compile and link arguments on Mac
extra_compile_args = ['-std=c++11']
extra_compile_args = ['-std=c++14']
extra_link_args = []

if platform.system() == 'Darwin':
Expand All @@ -20,9 +21,9 @@
extension = Extension(name='_openmmtorch',
sources=['TorchPluginWrapper.cpp'],
libraries=['OpenMM', 'OpenMMTorch'],
include_dirs=[os.path.join(openmm_dir, 'include'), nn_plugin_header_dir],
include_dirs=[os.path.join(openmm_dir, 'include'), nn_plugin_header_dir] + torch_include_dirs,
library_dirs=[os.path.join(openmm_dir, 'lib'), nn_plugin_library_dir],
runtime_library_dirs=[os.path.join(openmm_dir, 'lib'), torch_dir],
runtime_library_dirs=[os.path.join(openmm_dir, 'lib')],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args
)
Expand All @@ -32,4 +33,3 @@
py_modules=['openmmtorch'],
ext_modules=[extension],
)

50 changes: 50 additions & 0 deletions python/tests/TestSerializeTorchForce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
import shutil
import pytest
from openmm import XmlSerializer, OpenMMException
from openmmtorch import TorchForce

@pytest.fixture
def temporal_path(tmp_path_factory):
    # Provide each test with a fresh temporary directory (as a string) and
    # remove it explicitly once the test has finished.
    temporal = tmp_path_factory.mktemp("data")
    yield str(temporal)
    shutil.rmtree(str(temporal))

class ForceModule(torch.nn.Module):
    """A minimal scriptable module used by the serialization tests.

    Implements the potential E(r) = sum over components of r_i^2.
    """

    def forward(self, positions):
        """Return the scalar sum of the squared position components."""
        squared = positions * positions
        return squared.sum()


class ForceModule2(torch.nn.Module):
    """A dummy module distinct from ForceModule (cubic rather than quadratic)."""
    def forward(self, positions):
        return torch.sum(positions**3)


def createAndSerialize(model_filename, serialized_filename):
    """Script ForceModule, save it to model_filename, wrap it in a TorchForce
    and write the XML-serialized force to serialized_filename."""
    scripted = torch.jit.script(ForceModule())
    scripted.save(model_filename)
    xml = XmlSerializer.serialize(TorchForce(model_filename))
    with open(serialized_filename, 'w') as handle:
        handle.write(xml)

def readXML(filename):
    """Return the entire contents of the given text file as a string."""
    with open(filename, 'r') as handle:
        return handle.read()

def deserialize(filename):
    """Deserialize a force from the XML file at ``filename`` and return it.

    The original implementation assigned the result to an unused local and
    discarded it, so callers could not inspect the deserialized force;
    returning it is backward compatible (previous callers ignored the
    ``None`` return value).
    """
    return XmlSerializer.deserialize(readXML(filename))

def test_serialize(temporal_path):
    # Serializing a TorchForce built from a saved module should succeed
    # without raising.
    model_filename = temporal_path + "/model.pt"
    serialized_filename = temporal_path+ "/stored.xml"
    createAndSerialize(model_filename, serialized_filename)
RaulPPelaez marked this conversation as resolved.
Show resolved Hide resolved

def test_deserialize(temporal_path):
    # Round trip: serialize a TorchForce to XML, then deserialize it back
    # without raising.
    model_filename = temporal_path+ "/model.pt"
    serialized_filename = temporal_path+ "/stored.xml"
    createAndSerialize(model_filename, serialized_filename)
    deserialize(serialized_filename)
25 changes: 20 additions & 5 deletions python/tests/TestTorchForce.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,23 @@
import torch as pt
from tempfile import NamedTemporaryFile

@pytest.mark.parametrize('model_file, output_forces,',
[('../../tests/central.pt', False),
('../../tests/forces.pt', True)])
@pytest.mark.parametrize('model_file,',
['../../tests/central.pt',
'../../tests/forces.pt'])
def testConstructors(model_file):
force = ot.TorchForce(model_file)
model = pt.jit.load(model_file)
force = ot.TorchForce(pt.jit.load(model_file))
model = force.getModule()
force = ot.TorchForce(model)

@pytest.mark.parametrize('model_file, output_forces, use_module_constructor',
[('../../tests/central.pt', False, False,),
('../../tests/forces.pt', True, False),
('../../tests/forces.pt', True, True)])
@pytest.mark.parametrize('use_cv_force', [True, False])
@pytest.mark.parametrize('platform', ['Reference', 'CPU', 'CUDA', 'OpenCL'])
def testForce(model_file, output_forces, use_cv_force, platform):
def testForce(model_file, output_forces, use_module_constructor, use_cv_force, platform):

if pt.cuda.device_count() < 1 and platform == 'CUDA':
pytest.skip('A CUDA device is not available')
Expand All @@ -24,7 +35,11 @@ def testForce(model_file, output_forces, use_cv_force, platform):
system.addParticle(1.0)

# Create a force
force = ot.TorchForce(model_file)
if use_module_constructor:
model = pt.jit.load(model_file)
force = ot.TorchForce(model)
else:
force = ot.TorchForce(model_file)
assert not force.getOutputsForces() # Check the default
force.setOutputsForces(output_forces)
assert force.getOutputsForces() == output_forces
Expand Down
Loading