diff --git a/CMakeLists.txt b/CMakeLists.txt index b27faf78..8573ec30 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,6 @@ LINK_DIRECTORIES("${OPENMM_DIR}/lib" "${OPENMM_DIR}/lib/plugins") SET(PYTORCH_DIR "" CACHE PATH "Where the PyTorch C++ API is installed") SET(CMAKE_PREFIX_PATH "${PYTORCH_DIR}") FIND_PACKAGE(Torch REQUIRED) -#LINK_DIRECTORIES("${TENSORFLOW_DIR}/lib") # Specify the C++ version we are building for. SET (CMAKE_CXX_STANDARD 14) diff --git a/openmmapi/include/TorchForce.h b/openmmapi/include/TorchForce.h index 8cba9482..46e28193 100644 --- a/openmmapi/include/TorchForce.h +++ b/openmmapi/include/TorchForce.h @@ -35,13 +35,18 @@ #include "openmm/Context.h" #include "openmm/Force.h" #include +#include #include "internal/windowsExportTorch.h" namespace TorchPlugin { /** * This class implements forces that are defined by user-supplied neural networks. - * It uses the PyTorch library to perform the computations. */ + * It uses the PyTorch library to perform the computations. + * The PyTorch module can either be passed directly as an argument to + * the constructor or loaded from a file. In either case, the + * constructor makes a copy of the module in memory. Later changes to + * the original module or to the file do not affect it.*/ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { public: @@ -52,10 +57,24 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @param file the path to the file containing the network */ TorchForce(const std::string& file); + /** + * Create a TorchForce. The network is defined by a PyTorch ScriptModule + * Note that this constructor makes a copy of the provided module. + * Any changes to the module after calling this constructor will be ignored by TorchForce. + * + * @param module an instance of the torch module + */ + TorchForce(const torch::jit::Module &module); /** * Get the path to the file containing the network. + * If the TorchForce instance was constructed with a module, instead of a filename, + * this function returns an empty string. */ const std::string& getFile() const; + /** + * Get the torch module currently in use. + */ + const torch::jit::Module & getModule() const; /** * Set whether this force makes use of periodic boundary conditions. If this is set * to true, the network must take a 3x3 tensor as its second input, which @@ -128,6 +147,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { std::string file; bool usePeriodic, outputsForces; std::vector globalParameters; + torch::jit::Module module; }; /** diff --git a/openmmapi/include/internal/TorchForceImpl.h b/openmmapi/include/internal/TorchForceImpl.h index f3489c2e..0b09182f 100644 --- a/openmmapi/include/internal/TorchForceImpl.h +++ b/openmmapi/include/internal/TorchForceImpl.h @@ -65,7 +65,6 @@ class OPENMM_EXPORT_NN TorchForceImpl : public OpenMM::ForceImpl { private: const TorchForce& owner; OpenMM::Kernel kernel; - torch::jit::script::Module module; }; } // namespace TorchPlugin diff --git a/openmmapi/src/TorchForce.cpp b/openmmapi/src/TorchForce.cpp index f9df35a2..b9aa64d6 100644 --- a/openmmapi/src/TorchForce.cpp +++ b/openmmapi/src/TorchForce.cpp @@ -34,18 +34,28 @@ #include "openmm/OpenMMException.h" #include "openmm/internal/AssertionUtilities.h" #include +#include +#include using namespace TorchPlugin; using namespace OpenMM; using namespace std; -TorchForce::TorchForce(const std::string& file) : file(file), usePeriodic(false), outputsForces(false) { +TorchForce::TorchForce(const torch::jit::Module& module) : file(), usePeriodic(false), outputsForces(false), module(module) { +} + +TorchForce::TorchForce(const std::string& file) : TorchForce(torch::jit::load(file)) { + this->file = file; } const string& TorchForce::getFile() const { return file; } +const torch::jit::Module& TorchForce::getModule() const { + return this->module; +} + ForceImpl* TorchForce::createImpl() const { return new TorchForceImpl(*this); } diff --git a/openmmapi/src/TorchForceImpl.cpp b/openmmapi/src/TorchForceImpl.cpp index 3a67782c..054bc846 100644 --- a/openmmapi/src/TorchForceImpl.cpp +++ b/openmmapi/src/TorchForceImpl.cpp @@ -46,12 +46,8 @@ TorchForceImpl::~TorchForceImpl() { } void TorchForceImpl::initialize(ContextImpl& context) { - // Load the module from the file. - - module = torch::jit::load(owner.getFile()); - + auto module = owner.getModule(); // Create the kernel. - kernel = context.getPlatform().createKernel(CalcTorchForceKernel::Name(), context); kernel.getAs().initialize(context.getSystem(), owner, module); } diff --git a/platforms/cuda/tests/TestCudaTorchForce.cpp b/platforms/cuda/tests/TestCudaTorchForce.cpp index 3a96eec5..8e8c56b9 100644 --- a/platforms/cuda/tests/TestCudaTorchForce.cpp +++ b/platforms/cuda/tests/TestCudaTorchForce.cpp @@ -42,6 +42,8 @@ #include "sfmt/SFMT.h" #include #include +#include +#include #include using namespace TorchPlugin; @@ -52,7 +54,7 @@ extern "C" OPENMM_EXPORT void registerTorchCudaKernelFactories(); void testForce(bool outputsForces) { // Create a random cloud of particles. - + const int numParticles = 10; System system; vector positions(numParticles); @@ -62,10 +64,11 @@ void testForce(bool outputsForces) { system.addParticle(1.0); positions[i] = Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt))*10; } - TorchForce* force = new TorchForce(outputsForces ? "tests/forces.pt" : "tests/central.pt"); + auto model = torch::jit::load(outputsForces ? "tests/forces.pt" : "tests/central.pt"); + TorchForce* force = new TorchForce(model); force->setOutputsForces(outputsForces); system.addForce(force); - + // Compute the forces and energy. VerletIntegrator integ(1.0); @@ -73,9 +76,9 @@ void testForce(bool outputsForces) { Context context(system, integ, platform); context.setPositions(positions); State state = context.getState(State::Energy | State::Forces); - + // See if the energy is correct. The network defines a potential of the form E(r) = |r|^2 - + double expectedEnergy = 0; for (int i = 0; i < numParticles; i++) { Vec3 pos = positions[i]; diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index b9dfd5d7..ee0c19dd 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -2,6 +2,9 @@ set(WRAP_FILE TorchPluginWrapper.cpp) set(MODULE_NAME openmmtorch) # Execute SWIG to generate source code for the Python module. +foreach(dir ${TORCH_INCLUDE_DIRS}) + set(torchincs "${torchincs}" "-I${dir}") +endforeach() add_custom_command( OUTPUT "${WRAP_FILE}" @@ -9,6 +12,7 @@ add_custom_command( -python -c++ -o "${WRAP_FILE}" "-I${OPENMM_DIR}/include" + ${torchincs} "${CMAKE_CURRENT_SOURCE_DIR}/openmmtorch.i" DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/openmmtorch.i" WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" @@ -16,7 +20,7 @@ add_custom_command( # Compile the Python module. -add_custom_target(PythonInstall DEPENDS "${WRAP_FILE}") +add_custom_target(PythonInstall DEPENDS "${WRAP_FILE}" "${CMAKE_CURRENT_SOURCE_DIR}/setup.py") set(NN_PLUGIN_HEADER_DIR "${CMAKE_SOURCE_DIR}/openmmapi/include") set(NN_PLUGIN_LIBRARY_DIR "${CMAKE_BINARY_DIR}") configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py ${CMAKE_CURRENT_BINARY_DIR}/setup.py) diff --git a/python/openmmtorch.i b/python/openmmtorch.i index 2bbe11b4..bfabe86e 100644 --- a/python/openmmtorch.i +++ b/python/openmmtorch.i @@ -12,6 +12,7 @@ #include "OpenMMDrude.h" #include "openmm/RPMDIntegrator.h" #include "openmm/RPMDMonteCarloBarostat.h" +#include %} /* @@ -26,12 +27,33 @@ } } +%typemap(in) const torch::jit::Module&(torch::jit::Module module) { + py::object o = py::reinterpret_borrow($input); + module = torch::jit::as_module(o).value(); + $1 = &module; +} + +%typemap(out) const torch::jit::Module& { + auto fileName = std::tmpnam(nullptr); + $1->save(fileName); + $result = py::module::import("torch.jit").attr("load")(fileName).release().ptr(); + //This typemap assumes that torch does not require the file to exist after construction + std::remove(fileName); +} + +%typecheck(SWIG_TYPECHECK_POINTER) const torch::jit::Module& { + py::object o = py::reinterpret_borrow($input); + $1 = torch::jit::as_module(o).has_value() ? 1 : 0; +} + namespace TorchPlugin { class TorchForce : public OpenMM::Force { public: TorchForce(const std::string& file); + TorchForce(const torch::jit::Module& module); const std::string& getFile() const; + const torch::jit::Module& getModule() const; void setUsesPeriodicBoundaryConditions(bool periodic); bool usesPeriodicBoundaryConditions() const; void setOutputsForces(bool); diff --git a/python/setup.py b/python/setup.py index b3d09d57..6b57525e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -5,12 +5,13 @@ import platform openmm_dir = '@OPENMM_DIR@' +torch_include_dirs = '@TORCH_INCLUDE_DIRS@'.split(';') nn_plugin_header_dir = '@NN_PLUGIN_HEADER_DIR@' nn_plugin_library_dir = '@NN_PLUGIN_LIBRARY_DIR@' torch_dir, _ = os.path.split('@TORCH_LIBRARY@') # setup extra compile and link arguments on Mac -extra_compile_args = ['-std=c++11'] +extra_compile_args = ['-std=c++14'] extra_link_args = [] if platform.system() == 'Darwin': @@ -20,9 +21,9 @@ extension = Extension(name='_openmmtorch', sources=['TorchPluginWrapper.cpp'], libraries=['OpenMM', 'OpenMMTorch'], - include_dirs=[os.path.join(openmm_dir, 'include'), nn_plugin_header_dir], + include_dirs=[os.path.join(openmm_dir, 'include'), nn_plugin_header_dir] + torch_include_dirs, library_dirs=[os.path.join(openmm_dir, 'lib'), nn_plugin_library_dir], - runtime_library_dirs=[os.path.join(openmm_dir, 'lib'), torch_dir], + runtime_library_dirs=[os.path.join(openmm_dir, 'lib')], extra_compile_args=extra_compile_args, extra_link_args=extra_link_args ) @@ -32,4 +33,3 @@ py_modules=['openmmtorch'], ext_modules=[extension], ) - diff --git a/python/tests/TestSerializeTorchForce.py b/python/tests/TestSerializeTorchForce.py new file mode 100644 index 00000000..c7282724 --- /dev/null +++ b/python/tests/TestSerializeTorchForce.py @@ -0,0 +1,48 @@ +import torch +import shutil +import pytest +from openmm import XmlSerializer, OpenMMException +from openmmtorch import TorchForce +import os +import tempfile + +class ForceModule(torch.nn.Module): + """A simple module that can be serialized""" + def forward(self, positions): + return torch.sum(positions**2) + + +class ForceModule2(torch.nn.Module): + """A dummy module distict from ForceModule""" + def forward(self, positions): + return torch.sum(positions**3) + + +def createAndSerialize(model_filename, serialized_filename): + module = torch.jit.script(ForceModule()) + module.save(model_filename) + torch_force = TorchForce(model_filename) + stored = XmlSerializer.serialize(torch_force) + with open(serialized_filename, 'w') as f: + f.write(stored) + +def readXML(filename): + with open(filename, 'r') as f: + fileContents = f.read() + return fileContents + +def deserialize(filename): + other_force = XmlSerializer.deserialize(readXML(filename)) + +def test_serialize(): + with tempfile.TemporaryDirectory() as tempdir: + model_filename = os.path.join(tempdir, 'model.pt') + serialized_filename = os.path.join(tempdir, 'stored.xml') + createAndSerialize(model_filename, serialized_filename) + +def test_deserialize(): + with tempfile.TemporaryDirectory() as tempdir: + model_filename = os.path.join(tempdir, 'model.pt') + serialized_filename = os.path.join(tempdir, 'stored.xml') + createAndSerialize(model_filename, serialized_filename) + deserialize(serialized_filename) diff --git a/python/tests/TestTorchForce.py b/python/tests/TestTorchForce.py index c7e7708e..332a6cf4 100644 --- a/python/tests/TestTorchForce.py +++ b/python/tests/TestTorchForce.py @@ -6,12 +6,23 @@ import torch as pt from tempfile import NamedTemporaryFile -@pytest.mark.parametrize('model_file, output_forces,', - [('../../tests/central.pt', False), - ('../../tests/forces.pt', True)]) +@pytest.mark.parametrize('model_file,', + ['../../tests/central.pt', + '../../tests/forces.pt']) +def testConstructors(model_file): + force = ot.TorchForce(model_file) + model = pt.jit.load(model_file) + force = ot.TorchForce(pt.jit.load(model_file)) + model = force.getModule() + force = ot.TorchForce(model) + +@pytest.mark.parametrize('model_file, output_forces, use_module_constructor', + [('../../tests/central.pt', False, False,), + ('../../tests/forces.pt', True, False), + ('../../tests/forces.pt', True, True)]) @pytest.mark.parametrize('use_cv_force', [True, False]) @pytest.mark.parametrize('platform', ['Reference', 'CPU', 'CUDA', 'OpenCL']) -def testForce(model_file, output_forces, use_cv_force, platform): +def testForce(model_file, output_forces, use_module_constructor, use_cv_force, platform): if pt.cuda.device_count() < 1 and platform == 'CUDA': pytest.skip('A CUDA device is not available') @@ -24,7 +35,11 @@ def testForce(model_file, output_forces, use_cv_force, platform): system.addParticle(1.0) # Create a force - force = ot.TorchForce(model_file) + if use_module_constructor: + model = pt.jit.load(model_file) + force = ot.TorchForce(model) + else: + force = ot.TorchForce(model_file) assert not force.getOutputsForces() # Check the default force.setOutputsForces(output_forces) assert force.getOutputsForces() == output_forces diff --git a/serialization/src/TorchForceProxy.cpp b/serialization/src/TorchForceProxy.cpp index 4481b7ca..59ee977f 100644 --- a/serialization/src/TorchForceProxy.cpp +++ b/serialization/src/TorchForceProxy.cpp @@ -32,19 +32,60 @@ #include "TorchForceProxy.h" #include "TorchForce.h" #include "openmm/serialization/SerializationNode.h" +#include +#include +#include #include +#include using namespace TorchPlugin; using namespace OpenMM; using namespace std; +static string hexEncode(const string& input) { + stringstream ss; + ss << hex << setfill('0'); + for (const unsigned char& i : input) { + ss << setw(2) << static_cast(i); + } + return ss.str(); +} + +static string hexDecode(const string& input) { + string res; + res.reserve(input.size() / 2); + for (size_t i = 0; i < input.length(); i += 2) { + istringstream iss(input.substr(i, 2)); + uint64_t temp; + iss >> hex >> temp; + res += static_cast(temp); + } + return res; +} + +static string hexEncodeFromFileName(const string& filename) { + ifstream inputFile(filename, ios::binary); + stringstream inputStream; + inputStream << inputFile.rdbuf(); + return hexEncode(inputStream.str()); +} + TorchForceProxy::TorchForceProxy() : SerializationProxy("TorchForce") { } void TorchForceProxy::serialize(const void* object, SerializationNode& node) const { - node.setIntProperty("version", 1); + node.setIntProperty("version", 2); const TorchForce& force = *reinterpret_cast(object); node.setStringProperty("file", force.getFile()); + try { + auto tempFileName = std::tmpnam(nullptr); + force.getModule().save(tempFileName); + node.setStringProperty("encodedFileContents", hexEncodeFromFileName(tempFileName)); + std::remove(tempFileName); + } + catch (exception& ex) { + throw OpenMMException("TorchForceProxy: Could not serialize model. Failed with error: " + string(ex.what())); + } node.setIntProperty("forceGroup", force.getForceGroup()); node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions()); node.setBoolProperty("outputsForces", force.getOutputsForces()); @@ -55,9 +96,21 @@ void TorchForceProxy::serialize(const void* object, SerializationNode& node) con } void* TorchForceProxy::deserialize(const SerializationNode& node) const { - if (node.getIntProperty("version") != 1) + int storedVersion = node.getIntProperty("version"); + if (storedVersion > 2) throw OpenMMException("Unsupported version number"); - TorchForce* force = new TorchForce(node.getStringProperty("file")); + TorchForce* force; + if (storedVersion == 1) { + string fileName = node.getStringProperty("file"); + force = new TorchForce(fileName); + } else { + const string storedEncodedFile = node.getStringProperty("encodedFileContents"); + string fileName = tmpnam(nullptr); // A unique filename + ofstream(fileName) << hexDecode(storedEncodedFile); + auto model = torch::jit::load(fileName); + std::remove(fileName.c_str()); + force = new TorchForce(model); + } if (node.hasProperty("forceGroup")) force->setForceGroup(node.getIntProperty("forceGroup", 0)); if (node.hasProperty("usesPeriodic")) diff --git a/serialization/tests/TestSerializeTorchForce.cpp b/serialization/tests/TestSerializeTorchForce.cpp index bbde9ec9..eb5f65cd 100644 --- a/serialization/tests/TestSerializeTorchForce.cpp +++ b/serialization/tests/TestSerializeTorchForce.cpp @@ -34,18 +34,17 @@ #include "openmm/internal/AssertionUtilities.h" #include "openmm/serialization/XmlSerializer.h" #include +#include #include - +#include +#include using namespace TorchPlugin; using namespace OpenMM; using namespace std; extern "C" void registerTorchSerializationProxies(); -void testSerialization() { - // Create a Force. - - TorchForce force("module.pt"); +void serializeAndDeserialize(TorchForce force) { force.setForceGroup(3); force.addGlobalParameter("x", 1.3); force.addGlobalParameter("y", 2.221); @@ -61,7 +60,11 @@ void testSerialization() { // Compare the two forces to see if they are identical. TorchForce& force2 = *copy; - ASSERT_EQUAL(force.getFile(), force2.getFile()); + ostringstream bufferModule; + force.getModule().save(bufferModule); + ostringstream bufferModule2; + force2.getModule().save(bufferModule2); + ASSERT_EQUAL(bufferModule.str(), bufferModule2.str()); ASSERT_EQUAL(force.getForceGroup(), force2.getForceGroup()); ASSERT_EQUAL(force.getNumGlobalParameters(), force2.getNumGlobalParameters()); for (int i = 0; i < force.getNumGlobalParameters(); i++) { @@ -72,12 +75,26 @@ void testSerialization() { ASSERT_EQUAL(force.getOutputsForces(), force2.getOutputsForces()); } +void testSerializationFromModule() { + string fileName = "../../tests/forces.pt"; + torch::jit::Module module = torch::jit::load(fileName); + TorchForce force(module); + serializeAndDeserialize(force); +} + +void testSerializationFromFile() { + string fileName = "../../tests/forces.pt"; + TorchForce force(fileName); + serializeAndDeserialize(force); +} + int main() { try { registerTorchSerializationProxies(); - testSerialization(); + testSerializationFromFile(); + testSerializationFromModule(); } - catch(const exception& e) { + catch (const exception& e) { cout << "exception: " << e.what() << endl; return 1; }