Test if a PyTorch module receives correct arguments (#50)
* Test if a PyTorch module receives correct arguments

* Skip if CUDA is not available
Raimondas Galvelis authored Jan 8, 2022
1 parent 77d6ac3 commit 4614813
Showing 1 changed file with 56 additions and 0 deletions.
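
For context, the "Skip if CUDA is not available" part of this change follows the standard pytest.skip pattern. A minimal standalone sketch, separate from the committed test (the name testDeviceAvailable is illustrative only):

    import pytest
    import torch as pt

    @pytest.mark.parametrize('deviceString', ['cpu', 'cuda:0'])
    def testDeviceAvailable(deviceString):
        # Skip gracefully when the requested CUDA device is not present.
        if deviceString == 'cuda:0' and pt.cuda.device_count() < 1:
            pytest.skip('A CUDA device is not available')
        # A trivial check that runs on whichever device was requested.
        assert pt.ones(3, device=deviceString).sum().item() == 3.0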
python/tests/TestTorchForce.py
@@ -3,6 +3,9 @@
import openmmtorch as ot
import numpy as np
import unittest
import pytest
import torch as pt
from tempfile import NamedTemporaryFile

class TestTorchForce(unittest.TestCase):

@@ -31,6 +34,59 @@ def testForce(self):
        assert np.allclose(-2*positions, state.getForces(asNumpy=True))


@pytest.mark.parametrize('deviceString', ['cpu', 'cuda:0', 'cuda:1'])
@pytest.mark.parametrize('precision', ['single', 'mixed', 'double'])
def testModuleArguments(deviceString, precision):

    if pt.cuda.device_count() < 1 and deviceString == 'cuda:0':
        pytest.skip('A CUDA device is not available')
    if pt.cuda.device_count() < 2 and deviceString == 'cuda:1':
        pytest.skip('Two CUDA devices are not available')

    class TestModule(pt.nn.Module):

        def __init__(self, device, dtype, positions):
            super().__init__()
            self.device = device
            self.dtype = dtype
            self.positions = pt.tensor(positions).to(self.device).to(self.dtype)

        def forward(self, positions):
            assert positions.device == self.device
            assert positions.dtype == self.dtype
            assert pt.all(positions == self.positions)
            return pt.sum(positions)

    with NamedTemporaryFile() as fd:

        numParticles = 10
        system = mm.System()
        positions = np.random.rand(numParticles, 3)
        for _ in range(numParticles):
            system.addParticle(1.0)

        device = pt.device(deviceString)
        if device.type == 'cpu' or precision == 'double':
            dtype = pt.float64
        else:
            dtype = pt.float32
        module = TestModule(device, dtype, positions)
        pt.jit.script(module).save(fd.name)
        force = ot.TorchForce(fd.name)
        system.addForce(force)

        integrator = mm.VerletIntegrator(1.0)
        platform = mm.Platform.getPlatformByName(device.type.upper())
        properties = {}
        if device.type == 'cuda':
            properties['DeviceIndex'] = str(device.index)
            properties['Precision'] = precision
        context = mm.Context(system, integrator, platform, properties)

        context.setPositions(positions)
        context.getState(getEnergy=True, getForces=True)


if __name__ == '__main__':
    unittest.main()
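
Note that the new test is a parametrized pytest function, so it is collected by pytest rather than by the unittest runner invoked under __main__. Assuming pytest is installed and the openmmtorch package is importable, the device/precision combinations can be run on their own with, for example:

    python -m pytest python/tests/TestTorchForce.py -k testModuleArguments -v

The cuda:0 and cuda:1 cases skip automatically on machines with fewer than one or two CUDA devices, via the pytest.skip calls above.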
