
Add a constructor to TorchForce that takes a torch::jit::Module #97

Merged
merged 34 commits into from
Feb 10, 2023

Conversation

RaulPPelaez
Copy link
Contributor

@RaulPPelaez RaulPPelaez commented Jan 30, 2023

With the changes in this pull request, TorchForce(const string& fileName) is implemented by delegating to a new constructor, TorchForce(const torch::jit::Module&).
On construction TorchForce now calls torch::jit::load, which introduces a restriction on the C++ side: when using the string constructor, the file has to exist.

Previously a TorchForce could be created with a non-existent module file name. The C++ serialization test relied on this, so I had to modify it accordingly.
I also added a comparison check of the module itself (in addition to the file name) to that test.

When the instance is constructed with a module directly, TorchForce::getFile() will return an empty string.
Regardless of the chosen constructor, TorchForce holds both a file name (even if empty) and a module.
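To make the delegation concrete, here is a minimal self-contained sketch of the pattern described above. Module and loadModule are stand-ins for torch::jit::Module and torch::jit::load (the real ones come from torch/script.h), and the member names are illustrative, not the actual ones in TorchForce:

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Stand-in for torch::jit::Module, so the sketch is self-contained.
struct Module {
    std::string name;
};

// Stand-in for torch::jit::load; the real call deserializes a
// TorchScript file and throws if it cannot be read.
Module loadModule(const std::string& fileName) {
    if (fileName.empty())
        throw std::runtime_error("file not found");
    return Module{fileName};
}

class TorchForce {
public:
    // Module-based constructor: the stored file name stays empty.
    explicit TorchForce(const Module& module) : module(module) {}
    // The string constructor delegates to the one above after loading
    // the module, which is why the file now has to exist on construction.
    explicit TorchForce(const std::string& fileName)
        : TorchForce(loadModule(fileName)) {
        file = fileName;
    }
    // Empty string when constructed from a Module directly.
    const std::string& getFile() const { return file; }
    const Module& getModule() const { return module; }
private:
    std::string file;
    Module module;
};
```

Either way the object ends up holding both a file name and a module, matching the behavior described above.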

Related to #66 .

@RaulPPelaez
Copy link
Contributor Author

RaulPPelaez commented Jan 30, 2023

It is not yet possible to use this constructor from the Python side. I do not know how to expose it to Python; doing so requires telling SWIG how to transform the object returned by torch.jit.script(model) in Python into the equivalent torch::jit::Module in C++.
As far as I understand from the docs, the only way to do this is to serialize the model in Python and deserialize it in C++.
But again, I do not know how to tell SWIG to do that.
Any suggestions about this?

@peastman
Copy link
Member

This page gives an example of passing a PyTorch module from Python to C++.

A complication is that PyTorch uses pybind11 to interface between Python and C++, while OpenMM uses SWIG. This page has some links to examples of interfacing between pybind11 and SWIG.

@peastman
Copy link
Member

I posted on the PyTorch discussion forum asking for help with creating the wrapper.

@RaulPPelaez
Copy link
Contributor Author

Thanks Peter. I actually posted a pybind11 version of your question to the PyTorch forums a few hours before you did. I figured it would help me understand how to carry it over to SWIG later, since I am familiar with pybind11.

@RaulPPelaez
Copy link
Contributor Author

I found some functions to go from a PyObject* to a Module, but I am clueless on the SWIG side.
Maybe you can help. I defined the following typemap in openmmtorch.i:

%typemap(in) torch::jit::Module {
  py::object o = py::reinterpret_borrow<py::object>($input);
  torch::jit::Module module = torch::jit::as_module(o).value();
  $1 = module;
}

This compiles, but my typemap is clearly not being used, as I get a warning:

openmm-torch/python/openmmtorch.i:40: Warning 472: Overloaded method TorchPlugin::TorchForce::TorchForce(torch::jit::Module) with no explicit typecheck typemap 
for arg 0 of type 'torch::jit::Module'   

Then trying to run from Python:

>>> import openmmtorch as ot; import torch as pt; ot.TorchForce(pt.jit.load("forces.pt"))
Warning: importing 'simtk.openmm' is deprecated.  Import 'openmm' instead.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/shared/raul/mambaforge/envs/openmm/lib/python3.9/site-packages/openmmtorch.py", line 70, in __init__
    _openmmtorch.TorchForce_swiginit(self, _openmmtorch.new_TorchForce(*args))
TypeError: Wrong number or type of arguments for overloaded function 'new_TorchForce'.
  Possible C/C++ prototypes are:
    TorchPlugin::TorchForce::TorchForce(std::string const &)
    TorchPlugin::TorchForce::TorchForce(torch::jit::Module)

(I changed the constructor to take an instance instead of a const reference to play around with this.)

@RaulPPelaez
Copy link
Contributor Author

Adding a typecheck like this does the trick:

%typecheck(SWIG_TYPECHECK_POINTER) torch::jit::Module {
  py::object o = py::reinterpret_borrow<py::object>($input);
  $1 = torch::jit::as_module(o).has_value()?1:0;
}

@peastman
Copy link
Member

I think you need to write it for ScriptModule, not Module?

@RaulPPelaez
Copy link
Contributor Author

See the definition of the as_module function in torch.
I believe torch::jit::Module corresponds to ScriptModule.

@RaulPPelaez
Copy link
Contributor Author

The SWIG typecheck I added allows using the TorchForce(const torch::jit::Module&) constructor from Python.
I also added a %typemap(out) to be able to call TorchForce::getModule(), but this function returns a type that is not the same as the one returned by torch.jit.load("model.pt"), so one cannot do:

import openmmtorch as ot
import torch as pt
m=pt.jit.load("forces.pt")
f = ot.TorchForce(m) # This is ok
m2 = f.getModule()
f2 = ot.TorchForce(m2) # This fails

The actual types in Python print as:

>>> print(m)
RecursiveScriptModule(original_name=Forces)
>>> print(m2)
ScriptObject

@peastman
Copy link
Member

peastman commented Feb 1, 2023

ScriptObject doesn't appear anywhere in the API docs. It looks like it's torch._C.ScriptObject, which subclasses pybind11_builtins.pybind11_object. That suggests it's an internal class created by pybind11, which isn't supposed to be visible to users.

In the typemap

%typemap(out) const torch::jit::Module&{
    py::object o = py::cast(const_cast<torch::jit::Module*>($1));
    $result = o.release().ptr();
}

is there some way of telling pybind11 what Python class the result should be represented as?

@RaulPPelaez
Copy link
Contributor Author

I have been researching this and my conclusion is that there is simply no way to go from a C++ torch::jit::Module to the corresponding Python type. The opposite direction is possible, although undocumented.
The functionality is just not there (and the documentation for the C++ Python bindings of torch is really scarce).
If we adhere to what is documented and part of the API, there is only one way to do this: serialize the module in C++ and deserialize it in Python. pybind11 lets us plug this into a SWIG typemap, like this:

%typemap(out) const torch::jit::Module&{
  auto fileName = std::tmpnam(nullptr);
  $1->save(fileName);
  $result = py::module::import("torch.jit").attr("load")(fileName).release().ptr();
  std::remove(fileName);
}

With this it is now possible to do the following:

import openmmtorch as ot
import torch
m=torch.jit.load("forces.pt")
f=ot.TorchForce(m)
m2= f.getModule()
f2 = ot.TorchForce(m2)
print(m)
print(m2)
print(m==m2)
print(m.save_to_buffer() == m2.save_to_buffer())

Which outputs:

RecursiveScriptModule(original_name=Forces)
RecursiveScriptModule(original_name=Forces)
False
True

I believe this is a sensible approach. The only downside I can think of is that you cannot modify the module used by TorchForce after its creation; i.e., getModule() gives you a clone, not a reference.
I am not sure, however, whether the desired API is meant to guarantee that.
What do you think?

 Now it is possible to call getModule() on a TorchForce object from
 Python, which will return a module of the same type as, for instance, torch.jit.load()
@RaulPPelaez
Copy link
Contributor Author

I merged my commits from #95. Now all serialization tests are passing and the issue from my previous message is fixed.

I think the only remaining issue is dealing with the base64 encoding. My current commits use the functions EVP_EncodeBlock and EVP_DecodeBlock from OpenSSL, but you wrote that adding OpenSSL as an explicit dependency is undesirable.
Can you think of an alternative? I do not think rolling a new encoding function from scratch is a good idea; it is a tricky algorithm already implemented by many other projects.

@RaulPPelaez
Copy link
Contributor Author

What about using hex encoding instead of base64? It is more verbose, but that hardly matters. Here is a possible implementation:

#include <cassert>
#include <cstdint>
#include <fstream>
#include <iomanip>
#include <sstream>
#include <string>

using namespace std;

// Encode each byte as two lowercase hex digits.
string hexEncode(const string &input) {
  stringstream ss;
  ss << hex << setfill('0');
  for (const unsigned char &i : input) {
    ss << setw(2) << static_cast<uint64_t>(i);
  }
  return ss.str();
}

// Decode a string of hex digit pairs back into raw bytes.
string hexDecode(const string &input) {
  string res;
  res.reserve(input.size() / 2);
  for (size_t i = 0; i < input.length(); i += 2) {
    istringstream iss(input.substr(i, 2));
    uint64_t temp;
    iss >> hex >> temp;
    res += static_cast<unsigned char>(temp);
  }
  return res;
}

// Read a file in binary mode and return its hex-encoded contents.
string hexEncodeFromFileName(const string &filename) {
  ifstream input_file(filename, ios::binary);
  stringstream input_stream;
  input_stream << input_file.rdbuf();
  return hexEncode(input_stream.str());
}

int main() {
  string encoded = hexEncodeFromFileName("input.txt");
  auto decoded = hexDecode(encoded);
  ofstream("output.txt", ios::binary) << decoded;
  // Round trip: the decoded bytes must match the original file.
  ifstream input_file("input.txt", ios::binary);
  stringstream input_stream;
  input_stream << input_file.rdbuf();
  assert(input_stream.str() == decoded);
  return 0;
}

It's just a few lines of standard C++ library calls, and it seems to work with the example model files in the repo.

@peastman
Copy link
Member

peastman commented Feb 6, 2023

Hex seems like a fine solution. Models are unlikely to be so large that the extra space is an issue.

@RaulPPelaez
Copy link
Contributor Author

Then I would say all functionality has been included and tested; this is ready to merge. Please review.

@peastman
Copy link
Member

peastman commented Feb 7, 2023

This is looking really good. I made a few suggestions, but they're all very minor ones.

@RaulPPelaez
Copy link
Contributor Author

Thanks Peter, I agree with your suggestions and addressed them in the latest commits.

@peastman
Copy link
Member

I think this is ready to merge. Thanks for the really nice new feature!
