Add a constructor to TorchForce that takes a torch::jit::Module (#97)
* Add version number as a member to TorchForceProxy

* Encode the model file contents when serializing TorchForce

* Add tests for new TorchForce serialization

* Fix test not finding Python executable

* Format include directives correctly

* Hardcode TorchForceProxy version number

* Fix formatting issues

* Move Python serialization test to the correct place

* Make function encodeFromFileName static

* Update the Python serialization test to correctly remove temporary files after execution

* Use the base64 encoding capabilities of openssl to serialize model file

* Update TorchForce serializer

* Add a constructor to TorchForce that takes a torch::jit::Module.
 TorchForce(string fileName) is implemented by delegating to the new
 constructor.
 Update serialization test accordingly to compare the module file name
 and the module itself.

* Remove unnecessary include

* Change i_file to file in TorchForce constructor

* Add swig typemaps to new TorchForce constructor

* Add setup.py as a dependency for the PythonInstall CMake rule

* Fix swig out typemap for torch::jit::Module
 Now it is possible to call getModule() on a TorchForce object from
 Python; it returns a module of the same type as the one produced by, for instance, torch.jit.load().

* Remove commented line in CMakeLists.txt

* Remove unnecessary dependency in setup.py

* Add more tests for new constructor

* Add some comments for the new constructor

* Updates to TorchForce serialization

* Use hex encoding instead of base64 for serialization, so OpenSSL is no longer a direct dependency (see the sketch after this list).

* Remove unnecessary header

* Update Python serialization test

* Minor changes

* Improve temporary path handling in python serialization tests

* More informative exception when failing to serialize TorchForce

* Remove unnecessary check in TorchForce serialization

* Changes to C++ serialization tests

* Changes to C++ serialization tests
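
A rough sketch of the encoding scheme described in the commits above (an illustration only, not the actual TorchForceProxy code; the helper names are hypothetical): the serialized XML stores the scripted model's file contents as a hex string, which is why OpenSSL's base64 routines are no longer needed.

def encode_model(path):
    # Read the TorchScript file and return its bytes as a hex string for the XML node.
    with open(path, "rb") as f:
        return f.read().hex()

def decode_model(hex_string, path):
    # Recreate the TorchScript file from the stored hex string when deserializing.
    with open(path, "wb") as f:
        f.write(bytes.fromhex(hex_string))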
RaulPPelaez authored Feb 10, 2023
1 parent c76684f commit 769302a
Showing 13 changed files with 221 additions and 35 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -13,7 +13,6 @@ LINK_DIRECTORIES("${OPENMM_DIR}/lib" "${OPENMM_DIR}/lib/plugins")
SET(PYTORCH_DIR "" CACHE PATH "Where the PyTorch C++ API is installed")
SET(CMAKE_PREFIX_PATH "${PYTORCH_DIR}")
FIND_PACKAGE(Torch REQUIRED)
#LINK_DIRECTORIES("${TENSORFLOW_DIR}/lib")

# Specify the C++ version we are building for.
SET (CMAKE_CXX_STANDARD 14)
22 changes: 21 additions & 1 deletion openmmapi/include/TorchForce.h
@@ -35,13 +35,18 @@
#include "openmm/Context.h"
#include "openmm/Force.h"
#include <string>
#include <torch/torch.h>
#include "internal/windowsExportTorch.h"

namespace TorchPlugin {

/**
* This class implements forces that are defined by user-supplied neural networks.
* It uses the PyTorch library to perform the computations. */
* It uses the PyTorch library to perform the computations.
* The PyTorch module can either be passed directly as an argument to
* the constructor or loaded from a file. In either case, the
* constructor makes a copy of the module in memory. Later changes to
* the original module or to the file do not affect it.*/

class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
public:
@@ -52,10 +57,24 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @param file the path to the file containing the network
*/
TorchForce(const std::string& file);
/**
* Create a TorchForce. The network is defined by a PyTorch ScriptModule.
* Note that this constructor makes a copy of the provided module.
* Any changes to the module after calling this constructor will be ignored by TorchForce.
*
* @param module an instance of the torch module
*/
TorchForce(const torch::jit::Module &module);
/**
* Get the path to the file containing the network.
* If the TorchForce instance was constructed with a module, instead of a filename,
* this function returns an empty string.
*/
const std::string& getFile() const;
/**
* Get the torch module currently in use.
*/
const torch::jit::Module & getModule() const;
/**
* Set whether this force makes use of periodic boundary conditions. If this is set
* to true, the network must take a 3x3 tensor as its second input, which
@@ -128,6 +147,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
std::string file;
bool usePeriodic, outputsForces;
std::vector<GlobalParameterInfo> globalParameters;
torch::jit::Module module;
};

/**
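
A minimal usage sketch of the API documented above, written from Python via the bindings updated later in this diff; the CentralPotential class and the model.pt path are illustrative assumptions, not part of the change:

import torch
from openmmtorch import TorchForce

class CentralPotential(torch.nn.Module):
    # Toy potential E = sum(r^2), a stand-in for the repository's tests/central.pt model.
    def forward(self, positions):
        return torch.sum(positions ** 2)

module = torch.jit.script(CentralPotential())

# Construct directly from the ScriptModule. TorchForce stores its own copy, so
# later changes to `module` do not affect the force, and getFile() is empty.
force = TorchForce(module)
assert force.getFile() == ""

# The file-based constructor still works and now delegates to the module-based one.
module.save("model.pt")
force_from_file = TorchForce("model.pt")
assert force_from_file.getFile() == "model.pt"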
1 change: 0 additions & 1 deletion openmmapi/include/internal/TorchForceImpl.h
@@ -65,7 +65,6 @@ class OPENMM_EXPORT_NN TorchForceImpl : public OpenMM::ForceImpl {
private:
const TorchForce& owner;
OpenMM::Kernel kernel;
torch::jit::script::Module module;
};

} // namespace TorchPlugin
12 changes: 11 additions & 1 deletion openmmapi/src/TorchForce.cpp
@@ -34,18 +34,28 @@
#include "openmm/OpenMMException.h"
#include "openmm/internal/AssertionUtilities.h"
#include <fstream>
#include <torch/torch.h>
#include <torch/csrc/jit/serialization/import.h>

using namespace TorchPlugin;
using namespace OpenMM;
using namespace std;

TorchForce::TorchForce(const std::string& file) : file(file), usePeriodic(false), outputsForces(false) {
TorchForce::TorchForce(const torch::jit::Module& module) : file(), usePeriodic(false), outputsForces(false), module(module) {
}

TorchForce::TorchForce(const std::string& file) : TorchForce(torch::jit::load(file)) {
this->file = file;
}

const string& TorchForce::getFile() const {
return file;
}

const torch::jit::Module& TorchForce::getModule() const {
return this->module;
}

ForceImpl* TorchForce::createImpl() const {
return new TorchForceImpl(*this);
}
6 changes: 1 addition & 5 deletions openmmapi/src/TorchForceImpl.cpp
@@ -46,12 +46,8 @@ TorchForceImpl::~TorchForceImpl() {
}

void TorchForceImpl::initialize(ContextImpl& context) {
// Load the module from the file.

module = torch::jit::load(owner.getFile());

auto module = owner.getModule();
// Create the kernel.

kernel = context.getPlatform().createKernel(CalcTorchForceKernel::Name(), context);
kernel.getAs<CalcTorchForceKernel>().initialize(context.getSystem(), owner, module);
}
13 changes: 8 additions & 5 deletions platforms/cuda/tests/TestCudaTorchForce.cpp
@@ -42,6 +42,8 @@
#include "sfmt/SFMT.h"
#include <cmath>
#include <iostream>
#include <torch/torch.h>
#include <torch/csrc/jit/serialization/import.h>
#include <vector>

using namespace TorchPlugin;
@@ -52,7 +54,7 @@ extern "C" OPENMM_EXPORT void registerTorchCudaKernelFactories();

void testForce(bool outputsForces) {
// Create a random cloud of particles.

const int numParticles = 10;
System system;
vector<Vec3> positions(numParticles);
@@ -62,20 +64,21 @@ void testForce(bool outputsForces) {
system.addParticle(1.0);
positions[i] = Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt))*10;
}
TorchForce* force = new TorchForce(outputsForces ? "tests/forces.pt" : "tests/central.pt");
auto model = torch::jit::load(outputsForces ? "tests/forces.pt" : "tests/central.pt");
TorchForce* force = new TorchForce(model);
force->setOutputsForces(outputsForces);
system.addForce(force);

// Compute the forces and energy.

VerletIntegrator integ(1.0);
Platform& platform = Platform::getPlatformByName("CUDA");
Context context(system, integ, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);

// See if the energy is correct. The network defines a potential of the form E(r) = |r|^2

double expectedEnergy = 0;
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
6 changes: 5 additions & 1 deletion python/CMakeLists.txt
@@ -2,21 +2,25 @@ set(WRAP_FILE TorchPluginWrapper.cpp)
set(MODULE_NAME openmmtorch)

# Execute SWIG to generate source code for the Python module.
foreach(dir ${TORCH_INCLUDE_DIRS})
set(torchincs "${torchincs}" "-I${dir}")
endforeach()

add_custom_command(
OUTPUT "${WRAP_FILE}"
COMMAND "${SWIG_EXECUTABLE}"
-python -c++
-o "${WRAP_FILE}"
"-I${OPENMM_DIR}/include"
${torchincs}
"${CMAKE_CURRENT_SOURCE_DIR}/openmmtorch.i"
DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/openmmtorch.i"
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
)

# Compile the Python module.

add_custom_target(PythonInstall DEPENDS "${WRAP_FILE}")
add_custom_target(PythonInstall DEPENDS "${WRAP_FILE}" "${CMAKE_CURRENT_SOURCE_DIR}/setup.py")
set(NN_PLUGIN_HEADER_DIR "${CMAKE_SOURCE_DIR}/openmmapi/include")
set(NN_PLUGIN_LIBRARY_DIR "${CMAKE_BINARY_DIR}")
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py ${CMAKE_CURRENT_BINARY_DIR}/setup.py)
22 changes: 22 additions & 0 deletions python/openmmtorch.i
@@ -12,6 +12,7 @@
#include "OpenMMDrude.h"
#include "openmm/RPMDIntegrator.h"
#include "openmm/RPMDMonteCarloBarostat.h"
#include <torch/csrc/jit/python/module_python.h>
%}

/*
@@ -26,12 +27,33 @@
}
}

%typemap(in) const torch::jit::Module&(torch::jit::Module module) {
py::object o = py::reinterpret_borrow<py::object>($input);
module = torch::jit::as_module(o).value();
$1 = &module;
}

%typemap(out) const torch::jit::Module& {
auto fileName = std::tmpnam(nullptr);
$1->save(fileName);
$result = py::module::import("torch.jit").attr("load")(fileName).release().ptr();
//This typemap assumes that torch does not require the file to exist after construction
std::remove(fileName);
}

%typecheck(SWIG_TYPECHECK_POINTER) const torch::jit::Module& {
py::object o = py::reinterpret_borrow<py::object>($input);
$1 = torch::jit::as_module(o).has_value() ? 1 : 0;
}

namespace TorchPlugin {

class TorchForce : public OpenMM::Force {
public:
TorchForce(const std::string& file);
TorchForce(const torch::jit::Module& module);
const std::string& getFile() const;
const torch::jit::Module& getModule() const;
void setUsesPeriodicBoundaryConditions(bool periodic);
bool usesPeriodicBoundaryConditions() const;
void setOutputsForces(bool);
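
For reference, a short sketch of what these typemaps enable on the Python side, mirroring the testConstructors test further down; the tests/central.pt path is only an example:

import torch
from openmmtorch import TorchForce

force = TorchForce("tests/central.pt")

# getModule() goes through the `out` typemap above: the C++ module is saved to a
# temporary file and reloaded with torch.jit.load, so the caller receives an
# ordinary ScriptModule.
module = force.getModule()
assert isinstance(module, torch.jit.ScriptModule)

# The `in` typemap accepts a ScriptModule, so the module can be passed straight
# back into the module-based constructor.
force2 = TorchForce(module)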
8 changes: 4 additions & 4 deletions python/setup.py
@@ -5,12 +5,13 @@
import platform

openmm_dir = '@OPENMM_DIR@'
torch_include_dirs = '@TORCH_INCLUDE_DIRS@'.split(';')
nn_plugin_header_dir = '@NN_PLUGIN_HEADER_DIR@'
nn_plugin_library_dir = '@NN_PLUGIN_LIBRARY_DIR@'
torch_dir, _ = os.path.split('@TORCH_LIBRARY@')

# setup extra compile and link arguments on Mac
extra_compile_args = ['-std=c++11']
extra_compile_args = ['-std=c++14']
extra_link_args = []

if platform.system() == 'Darwin':
@@ -20,9 +21,9 @@
extension = Extension(name='_openmmtorch',
sources=['TorchPluginWrapper.cpp'],
libraries=['OpenMM', 'OpenMMTorch'],
include_dirs=[os.path.join(openmm_dir, 'include'), nn_plugin_header_dir],
include_dirs=[os.path.join(openmm_dir, 'include'), nn_plugin_header_dir] + torch_include_dirs,
library_dirs=[os.path.join(openmm_dir, 'lib'), nn_plugin_library_dir],
runtime_library_dirs=[os.path.join(openmm_dir, 'lib'), torch_dir],
runtime_library_dirs=[os.path.join(openmm_dir, 'lib')],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args
)
@@ -32,4 +33,3 @@
py_modules=['openmmtorch'],
ext_modules=[extension],
)

48 changes: 48 additions & 0 deletions python/tests/TestSerializeTorchForce.py
@@ -0,0 +1,48 @@
import torch
import shutil
import pytest
from openmm import XmlSerializer, OpenMMException
from openmmtorch import TorchForce
import os
import tempfile

class ForceModule(torch.nn.Module):
"""A simple module that can be serialized"""
def forward(self, positions):
return torch.sum(positions**2)


class ForceModule2(torch.nn.Module):
"""A dummy module distict from ForceModule"""
def forward(self, positions):
return torch.sum(positions**3)


def createAndSerialize(model_filename, serialized_filename):
module = torch.jit.script(ForceModule())
module.save(model_filename)
torch_force = TorchForce(model_filename)
stored = XmlSerializer.serialize(torch_force)
with open(serialized_filename, 'w') as f:
f.write(stored)

def readXML(filename):
with open(filename, 'r') as f:
fileContents = f.read()
return fileContents

def deserialize(filename):
other_force = XmlSerializer.deserialize(readXML(filename))

def test_serialize():
with tempfile.TemporaryDirectory() as tempdir:
model_filename = os.path.join(tempdir, 'model.pt')
serialized_filename = os.path.join(tempdir, 'stored.xml')
createAndSerialize(model_filename, serialized_filename)

def test_deserialize():
with tempfile.TemporaryDirectory() as tempdir:
model_filename = os.path.join(tempdir, 'model.pt')
serialized_filename = os.path.join(tempdir, 'stored.xml')
createAndSerialize(model_filename, serialized_filename)
deserialize(serialized_filename)
25 changes: 20 additions & 5 deletions python/tests/TestTorchForce.py
@@ -6,12 +6,23 @@
import torch as pt
from tempfile import NamedTemporaryFile

@pytest.mark.parametrize('model_file, output_forces,',
[('../../tests/central.pt', False),
('../../tests/forces.pt', True)])
@pytest.mark.parametrize('model_file,',
['../../tests/central.pt',
'../../tests/forces.pt'])
def testConstructors(model_file):
force = ot.TorchForce(model_file)
model = pt.jit.load(model_file)
force = ot.TorchForce(pt.jit.load(model_file))
model = force.getModule()
force = ot.TorchForce(model)

@pytest.mark.parametrize('model_file, output_forces, use_module_constructor',
[('../../tests/central.pt', False, False,),
('../../tests/forces.pt', True, False),
('../../tests/forces.pt', True, True)])
@pytest.mark.parametrize('use_cv_force', [True, False])
@pytest.mark.parametrize('platform', ['Reference', 'CPU', 'CUDA', 'OpenCL'])
def testForce(model_file, output_forces, use_cv_force, platform):
def testForce(model_file, output_forces, use_module_constructor, use_cv_force, platform):

if pt.cuda.device_count() < 1 and platform == 'CUDA':
pytest.skip('A CUDA device is not available')
@@ -24,7 +35,11 @@ def testForce(model_file, output_forces, use_cv_force, platform):
system.addParticle(1.0)

# Create a force
force = ot.TorchForce(model_file)
if use_module_constructor:
model = pt.jit.load(model_file)
force = ot.TorchForce(model)
else:
force = ot.TorchForce(model_file)
assert not force.getOutputsForces() # Check the default
force.setOutputsForces(output_forces)
assert force.getOutputsForces() == output_forces