
Commit e196c8d

Torch2 compatibility (#106)
* Update ci
* port the latest CI versions to the self-hosted runner
* forgot to remove bit from version I copy
* see if cuda 11.7 works
* Update swig typecheck to use py::isinstance instead of as_module
* run CUDA tests on GPU runner
* Update TorchForce swig input typemap to avoid as_module
* Fix formatting
* Switch to using micromamba
* Make sure temporary files are deleted even if there was an error
* Update swig typemaps to not use a temp file
* Revert "Merge remote-tracking branch 'origin/feat/add_aws_gpu_runner' into torch2"
  This reverts commit 0a58de4, reversing changes made to a1c16c6.
* run gpu tests on this branch
* actually test with pytorch2
* Revert "actually test with pytorch2"
  This reverts commit 4e7a4fb.
* Revert "Merge remote-tracking branch 'origin/feat/add_aws_gpu_runner' into torch2"
  This reverts commit 0a58de4, reversing changes made to a1c16c6.

---------

Co-authored-by: Mike Henry <[email protected]>
1 parent b76deb4 commit e196c8d

2 files changed: +18 -13 lines

.github/workflows/CI.yml (+4 -4)
@@ -33,13 +33,13 @@ jobs:
           pytorch-version: "1.11.*"
 
         # Latest supported versions
-        - name: Linux (CUDA 11.2, Python 3.10, PyTorch 1.12)
+        - name: Linux (CUDA 11.8, Python 3.10, PyTorch 2.0)
           os: ubuntu-22.04
-          cuda-version: "11.2.2"
+          cuda-version: "11.8.0"
           gcc-version: "10.3.*"
-          nvcc-version: "11.2"
+          nvcc-version: "11.8"
           python-version: "3.10"
-          pytorch-version: "1.12.*"
+          pytorch-version: "2.0.*"
 
         - name: MacOS (Python 3.9, PyTorch 1.9)
           os: macos-11

python/openmmtorch.i (+14 -9)
@@ -14,6 +14,7 @@
 #include "openmm/RPMDIntegrator.h"
 #include "openmm/RPMDMonteCarloBarostat.h"
 #include <torch/csrc/jit/python/module_python.h>
+#include <torch/csrc/jit/serialization/import.h>
 %}
 
 /*
@@ -28,23 +29,27 @@
     }
 }
 
-%typemap(in) const torch::jit::Module&(torch::jit::Module module) {
+%typemap(in) const torch::jit::Module&(torch::jit::Module mod) {
     py::object o = py::reinterpret_borrow<py::object>($input);
-    module = torch::jit::as_module(o).value();
-    $1 = &module;
+    py::object pybuffer = py::module::import("io").attr("BytesIO")();
+    py::module::import("torch.jit").attr("save")(o, pybuffer);
+    std::string s = py::cast<std::string>(pybuffer.attr("getvalue")());
+    std::stringstream buffer(s);
+    mod = torch::jit::load(buffer);
+    $1 = &mod;
 }
 
 %typemap(out) const torch::jit::Module& {
-    auto fileName = std::tmpnam(nullptr);
-    $1->save(fileName);
-    $result = py::module::import("torch.jit").attr("load")(fileName).release().ptr();
-    //This typemap assumes that torch does not require the file to exist after construction
-    std::remove(fileName);
+    std::stringstream buffer;
+    $1->save(buffer);
+    auto pybuffer = py::module::import("io").attr("BytesIO")(py::bytes(buffer.str()));
+    $result = py::module::import("torch.jit").attr("load")(pybuffer).release().ptr();
 }
 
 %typecheck(SWIG_TYPECHECK_POINTER) const torch::jit::Module& {
     py::object o = py::reinterpret_borrow<py::object>($input);
-    $1 = torch::jit::as_module(o).has_value() ? 1 : 0;
+    py::handle ScriptModule = py::module::import("torch.jit").attr("ScriptModule");
+    $1 = py::isinstance(o, ScriptModule);
 }
 
 namespace std {
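
For context, both typemaps now move the TorchScript module between Python and C++ by serializing it through an in-memory buffer, rather than using torch::jit::as_module or a temporary file. The following standalone sketch shows the same round trip outside of SWIG; it assumes pybind11 and LibTorch headers plus an already-running Python interpreter, and the helper names (pyToModule, moduleToPy, isScriptModule) are illustrative rather than part of the plugin's API.

    // Minimal sketch of the in-memory round trip used by the new typemaps.
    // Assumes pybind11 + LibTorch and a running Python interpreter.
    #include <pybind11/pybind11.h>
    #include <torch/script.h>
    #include <sstream>
    #include <string>

    namespace py = pybind11;

    // Python torch.jit.ScriptModule -> C++ torch::jit::Module,
    // serialized through an io.BytesIO buffer instead of a temporary file.
    torch::jit::Module pyToModule(const py::object& scriptModule) {
        py::object pybuffer = py::module::import("io").attr("BytesIO")();
        py::module::import("torch.jit").attr("save")(scriptModule, pybuffer);
        std::string bytes = py::cast<std::string>(pybuffer.attr("getvalue")());
        std::stringstream buffer(bytes);
        return torch::jit::load(buffer);
    }

    // C++ torch::jit::Module -> Python torch.jit.ScriptModule, going the other way.
    py::object moduleToPy(const torch::jit::Module& module) {
        std::stringstream buffer;
        module.save(buffer);
        py::object pybuffer = py::module::import("io").attr("BytesIO")(py::bytes(buffer.str()));
        return py::module::import("torch.jit").attr("load")(pybuffer);
    }

    // Type check mirroring the new %typecheck: accept any torch.jit.ScriptModule.
    bool isScriptModule(const py::object& o) {
        py::handle ScriptModule = py::module::import("torch.jit").attr("ScriptModule");
        return py::isinstance(o, ScriptModule);
    }

Because the bytes never touch the filesystem, the old caveat about torch.jit.load needing the temporary file to still exist no longer applies, and there is no file left to clean up if an error occurs.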
