-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a constructor to TorchForce that takes a torch::jit::Module (#97)
* Add version number as a member to TorchForceProxy * Encode the model file contents when serializing TorchForce * Add tests for new TorchForce serialization * Fix test not finding Python executable * Format include directives correctly * Hardcode TorchForceProxy version number * Fix formatting issues * Move Python serialization test to the correct place * Make function encodeFromFileName static * Update serialization python test to correctly remove temporary files after executing * Use the base64 encoding capabilities of openssl to serialize model file * Update TorchForce serializer * Add a constructor to TorchForce that takes a torch::jit::Module. TorchForce(string fileName) is implemented by delegating to the new constructor. Update serialization test accordingly to compare the module file name and the module itself. * Remove unnecessary include * Change i_file to file in TorchForce constructor * Add swig typemaps to new TorchForce constructor * Add setup.py as a dependency for the PythonInstall CMake rule * Fix swig out typemap for torch::jit::Module Now it is possible to call getModule() on a TorchForce object from Python, which will return a module of the same type as, for instance, torch.jit.load() * Remove commented line in CMakeLists.txt * Remove unnecessary dependency in setup.py * Add more tests for new constructor * Add some comments for the new constructor * Updates to TorchForce serialization * Use hex encoding instead of base64 for serialization. SSL no longer a direct dependency. * Remove unnecessary header * Update Python serialization test * Minor changes * Improve temporary path handling in python serialization tests * More informative exception when failing to serialize TorchForce * Remove unnecessary check in TorchForce serialization * Changes to C++ serialization tests * Changes to C++ serialization tests
- Loading branch information
1 parent
c76684f
commit 769302a
Showing
13 changed files
with
221 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import torch | ||
import shutil | ||
import pytest | ||
from openmm import XmlSerializer, OpenMMException | ||
from openmmtorch import TorchForce | ||
import os | ||
import tempfile | ||
|
||
class ForceModule(torch.nn.Module): | ||
"""A simple module that can be serialized""" | ||
def forward(self, positions): | ||
return torch.sum(positions**2) | ||
|
||
|
||
class ForceModule2(torch.nn.Module): | ||
"""A dummy module distict from ForceModule""" | ||
def forward(self, positions): | ||
return torch.sum(positions**3) | ||
|
||
|
||
def createAndSerialize(model_filename, serialized_filename): | ||
module = torch.jit.script(ForceModule()) | ||
module.save(model_filename) | ||
torch_force = TorchForce(model_filename) | ||
stored = XmlSerializer.serialize(torch_force) | ||
with open(serialized_filename, 'w') as f: | ||
f.write(stored) | ||
|
||
def readXML(filename): | ||
with open(filename, 'r') as f: | ||
fileContents = f.read() | ||
return fileContents | ||
|
||
def deserialize(filename): | ||
other_force = XmlSerializer.deserialize(readXML(filename)) | ||
|
||
def test_serialize(): | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
model_filename = os.path.join(tempdir, 'model.pt') | ||
serialized_filename = os.path.join(tempdir, 'stored.xml') | ||
createAndSerialize(model_filename, serialized_filename) | ||
|
||
def test_deserialize(): | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
model_filename = os.path.join(tempdir, 'model.pt') | ||
serialized_filename = os.path.join(tempdir, 'stored.xml') | ||
createAndSerialize(model_filename, serialized_filename) | ||
deserialize(serialized_filename) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.