Add a constructor to TorchForce that takes a torch::jit::Module #97
TorchForce(string fileName) is implemented by delegating to the new constructor. Update serialization test accordingly to compare the module file name and the module itself.
It is not possible yet to use this constructor from the Python side. I do not know how to expose this constructor to Python, which requires telling swig how to transform the object returned by |
I posted on the PyTorch discussion forum asking for help about how to create the wrapper. |
Thanks Peter, I actually posted a pybind11 version of your question a few hours before you to the pytorch forums. I figured it would help me understand how to take it to swig later on, since I am familiar with pybind. |
I found some functions to go from a PyObject* to a Module, but I am clueless SWIG-side.
%typemap(in) torch::jit::Module {
    py::object o = py::reinterpret_borrow<py::object>($input);
    torch::jit::Module module = torch::jit::as_module(o).value();
    $1 = module;
}
This compiles, but my typemap is clearly not being used, as I get a warning:
Then trying to run from Python:
>>> import openmmtorch as ot; import torch as pt; ot.TorchForce(pt.jit.load("forces.pt"))
Warning: importing 'simtk.openmm' is deprecated. Import 'openmm' instead.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/shared/raul/mambaforge/envs/openmm/lib/python3.9/site-packages/openmmtorch.py", line 70, in __init__
_openmmtorch.TorchForce_swiginit(self, _openmmtorch.new_TorchForce(*args))
TypeError: Wrong number or type of arguments for overloaded function 'new_TorchForce'.
Possible C/C++ prototypes are:
TorchPlugin::TorchForce::TorchForce(std::string const &)
TorchPlugin::TorchForce::TorchForce(torch::jit::Module)
(I changed the constructor to take an instance instead of a const ref to play around with this.) |
Adding a typecheck like this does the trick:
|
I think you need to write it for ScriptModule, not Module? |
See the definition of the as_module function in torch. |
The SWIG typecheck(in) I added allows to use the
import openmmtorch as ot
import torch as pt
m=pt.jit.load("forces.pt")
f = ot.TorchForce(m) # This is ok
m2 = f.getModule()
f2 = ot.TorchForce(m2) # This fails
The actual types in Python are printed as:
>>> print(m)
RecursiveScriptModule(original_name=Forces)
>>> print(m2)
ScriptObject |
In the typemap
%typemap(out) const torch::jit::Module& {
    py::object o = py::cast(const_cast<torch::jit::Module*>($1));
    $result = o.release().ptr();
}
is there some way of telling pybind11 what Python class the result should be represented as? |
I have been researching, and my conclusion is that there is simply no way to go from a C++ torch::jit::Module to the corresponding Python type. Although undocumented, the opposite direction is possible.
%typemap(out) const torch::jit::Module& {
    auto fileName = std::tmpnam(nullptr);
    $1->save(fileName);
    $result = py::module::import("torch.jit").attr("load")(fileName).release().ptr();
    std::remove(fileName);
}
With this it is now possible to do the following:
import openmmtorch as ot
import torch
m=torch.jit.load("forces.pt")
f=ot.TorchForce(m)
m2= f.getModule()
f2 = ot.TorchForce(m2)
print(m)
print(m2)
print(m==m2)
print(m.save_to_buffer() == m2.save_to_buffer())
Which outputs:
RecursiveScriptModule(original_name=Forces)
False
True
I believe this is a sensible approach. There is only one downside I can think of, which is the fact that you cannot modify the module being used by TorchForce after its creation, i.e |
Now it is possible to call getModule() on a TorchForce object from Python, which will return a module of the same type as, for instance, torch.jit.load()
I merged my commits from #95. Now all serialization tests are passing and the issue from my previous message is fixed. I think the only dangling issue now is dealing with the base64 encoding. In my current commits the functions |
What about using hex encoding instead of base64? It is longer, but that hardly matters. Here is a possible implementation:
#include <cassert>
#include <cstdint>
#include <fstream>
#include <iomanip>
#include <sstream>
#include <string>
using namespace std;
// Encode each byte as two hexadecimal digits.
string hexEncode(const string &input) {
    stringstream ss;
    ss << hex << setfill('0');
    for (const unsigned char &i : input) {
        ss << setw(2) << static_cast<uint64_t>(i);
    }
    return ss.str();
}
// Decode pairs of hexadecimal digits back into bytes.
string hexDecode(const string &input) {
    string res;
    res.reserve(input.size() / 2);
    for (size_t i = 0; i < input.length(); i += 2) {
        istringstream iss(input.substr(i, 2));
        uint64_t temp;
        iss >> hex >> temp;
        res += static_cast<unsigned char>(temp);
    }
    return res;
}
// Read a whole file in binary mode and return its hex encoding.
string hexEncodeFromFileName(const string &filename) {
    ifstream input_file(filename, ios::binary);
    stringstream input_stream;
    input_stream << input_file.rdbuf();
    return hexEncode(input_stream.str());
}
int main() {
    string encoded = hexEncodeFromFileName("input.txt");
    auto decoded = hexDecode(encoded);
    ofstream("output.txt", ios::binary) << decoded;
    // Check that decoding reproduces the original file contents.
    ifstream input_file("input.txt", ios::binary);
    stringstream input_stream;
    input_stream << input_file.rdbuf();
    assert(input_stream.str() == decoded);
    return 0;
}
It's just a few lines of standard C++ library calls. Seems to work with the example model files in the repo. |
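For comparison, a hex codec along the same lines can also reject malformed input, which may matter if a serialized force is edited by hand. The following is only a sketch under stated assumptions: the names hexEncodeBytes, hexDigitValue, hexDecodeBytes and checkHexRoundTrip are hypothetical and do not come from this pull request, and throwing std::invalid_argument is one possible error policy, not the one the plugin uses.

```cpp
// Sketch only: all names below are hypothetical, not part of the pull request.
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>

// Encode every byte as two lowercase hexadecimal digits.
std::string hexEncodeBytes(const std::string& input) {
    static const char digits[] = "0123456789abcdef";
    std::string out;
    out.reserve(input.size() * 2);
    for (unsigned char c : input) {
        out.push_back(digits[c >> 4]);    // high nibble
        out.push_back(digits[c & 0x0F]);  // low nibble
    }
    return out;
}

// Map one hexadecimal digit to its value; anything else is an error.
int hexDigitValue(char c) {
    if (c >= '0' && c <= '9') return c - '0';
    if (c >= 'a' && c <= 'f') return c - 'a' + 10;
    if (c >= 'A' && c <= 'F') return c - 'A' + 10;
    throw std::invalid_argument("not a hex digit");
}

// Decode pairs of hexadecimal digits; odd-length input is rejected.
std::string hexDecodeBytes(const std::string& input) {
    if (input.size() % 2 != 0)
        throw std::invalid_argument("hex string has odd length");
    std::string out;
    out.reserve(input.size() / 2);
    for (std::size_t i = 0; i < input.size(); i += 2) {
        int value = hexDigitValue(input[i]) * 16 + hexDigitValue(input[i + 1]);
        out.push_back(static_cast<char>(value));
    }
    return out;
}

// Round-trip every possible byte value, including '\0'.
void checkHexRoundTrip() {
    std::string all;
    for (int b = 0; b < 256; ++b) all.push_back(static_cast<char>(b));
    assert(hexDecodeBytes(hexEncodeBytes(all)) == all);
}
```

Calling checkHexRoundTrip() exercises all 256 byte values, which is a cheap way to check that binary model data with embedded zero bytes survives encoding.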
SSL no longer a direct dependency.
Hex seems like a fine solution. Models are unlikely to be so large that the extra space is an issue. |
Then I would say all functionality has been included and tested, this is ready to merge. Please review. |
This is looking really good. I made a few suggestions, but they're all very minor ones. |
Thanks Peter, I agree with your suggestions and addressed them in the latest commits. |
I think this is ready to merge. Thanks for the really nice new feature! |
With the changes in this pull request, TorchForce(const string& fileName) is implemented by delegating to a new constructor, TorchForce(const torch::jit::Module&).
On construction, TorchForce now calls torch::jit::load, which introduces, on the C++ side, the restriction that the file has to exist when using the string constructor. Previously, TorchForce could be created by providing a non-existent module file name. This was being abused in the C++ serialization test, which I had to modify accordingly. I also added a comparison check for the module itself (in addition to the file name) to this test.
When the instance is constructed with a module directly, TorchForce::getFile() will return an empty string. Regardless of the chosen constructor, TorchForce holds both a file name (even if empty) and a model.
Related to #66.
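The constructor arrangement described above can be illustrated without PyTorch. This is a minimal sketch, not the plugin's code: FakeModule, loadFakeModule and SketchForce are hypothetical stand-ins for torch::jit::Module, torch::jit::load and TorchForce, and the exact way the real class stores the file name is an assumption.

```cpp
// Sketch only: FakeModule, loadFakeModule and SketchForce are hypothetical
// stand-ins for torch::jit::Module, torch::jit::load and TorchForce.
#include <cassert>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>

struct FakeModule {
    std::string bytes;  // serialized contents of the stand-in module
};

// Stand-in loader: fails when the file cannot be opened, which mirrors the
// "file has to exist" restriction described in the text.
FakeModule loadFakeModule(const std::string& fileName) {
    std::ifstream in(fileName, std::ios::binary);
    if (!in) throw std::runtime_error("cannot open " + fileName);
    std::ostringstream buffer;
    buffer << in.rdbuf();
    return FakeModule{buffer.str()};
}

class SketchForce {
public:
    // Module constructor: no file is involved, so the file name stays empty.
    explicit SketchForce(const FakeModule& module) : module(module) {}
    // File-name constructor: loads eagerly, delegates, then records the name.
    explicit SketchForce(const std::string& fileName)
        : SketchForce(loadFakeModule(fileName)) {
        file = fileName;
    }
    const std::string& getFile() const { return file; }
    const FakeModule& getModule() const { return module; }

private:
    std::string file;   // empty when constructed from a module
    FakeModule module;  // always present, whichever constructor was used
};

// A force built directly from a module reports an empty file name.
void checkModuleConstructor() {
    SketchForce force(FakeModule{"xyz"});
    assert(force.getFile().empty());
    assert(force.getModule().bytes == "xyz");
}
```

Because loading happens in the file-name constructor, a missing file surfaces as an exception at construction time rather than later, which is why a test that relied on a non-existent file name had to change.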