|
14 | 14 | #include "openmm/RPMDIntegrator.h"
|
15 | 15 | #include "openmm/RPMDMonteCarloBarostat.h"
|
16 | 16 | #include <torch/csrc/jit/python/module_python.h>
|
| 17 | +#include <torch/csrc/jit/serialization/import.h> |
17 | 18 | %}
|
18 | 19 |
|
19 | 20 | /*
|
|
28 | 29 | }
|
29 | 30 | }
|
30 | 31 |
|
31 |
| -%typemap(in) const torch::jit::Module&(torch::jit::Module module) { |
| 32 | +%typemap(in) const torch::jit::Module&(torch::jit::Module mod) { |
32 | 33 | py::object o = py::reinterpret_borrow<py::object>($input);
|
33 |
| - module = torch::jit::as_module(o).value(); |
34 |
| - $1 = &module; |
| 34 | + py::object pybuffer = py::module::import("io").attr("BytesIO")(); |
| 35 | + py::module::import("torch.jit").attr("save")(o, pybuffer); |
| 36 | + std::string s = py::cast<std::string>(pybuffer.attr("getvalue")()); |
| 37 | + std::stringstream buffer(s); |
| 38 | + mod = torch::jit::load(buffer); |
| 39 | + $1 = &mod; |
35 | 40 | }
|
36 | 41 |
|
37 | 42 | %typemap(out) const torch::jit::Module& {
|
38 |
| - auto fileName = std::tmpnam(nullptr); |
39 |
| - $1->save(fileName); |
40 |
| - $result = py::module::import("torch.jit").attr("load")(fileName).release().ptr(); |
41 |
| - //This typemap assumes that torch does not require the file to exist after construction |
42 |
| - std::remove(fileName); |
| 43 | + std::stringstream buffer; |
| 44 | + $1->save(buffer); |
| 45 | + auto pybuffer = py::module::import("io").attr("BytesIO")(py::bytes(buffer.str())); |
| 46 | + $result = py::module::import("torch.jit").attr("load")(pybuffer).release().ptr(); |
43 | 47 | }
|
44 | 48 |
|
45 | 49 | %typecheck(SWIG_TYPECHECK_POINTER) const torch::jit::Module& {
|
46 | 50 | py::object o = py::reinterpret_borrow<py::object>($input);
|
47 |
| - $1 = torch::jit::as_module(o).has_value() ? 1 : 0; |
| 51 | + py::handle ScriptModule = py::module::import("torch.jit").attr("ScriptModule"); |
| 52 | + $1 = py::isinstance(o, ScriptModule); |
48 | 53 | }
|
49 | 54 |
|
50 | 55 | namespace std {
|
|
0 commit comments