Model tensors not being loaded correctly #43
@larsbratholm thank you for reporting the issue and giving very detailed info on how to reproduce it. Recently, I have observed similar issues and was investigating, but due to the holiday season there probably won't be much progress until September.
@larsbratholm I have investigated the issues:
```python
import torch

class ParallelModel(torch.nn.Module):
    """
    Simple model where the energy is a sum of two linear
    transformations of the positions.
    Each linear model is on a different device.
    """
    def __init__(self, device0, device1):
        super().__init__()
        self.device0 = device0
        self.device1 = device1
        self.l0 = torch.nn.Linear(3 * 23, 1).to(self.device0)
        self.l1 = torch.nn.Linear(3 * 23, 1).to(self.device1)

    def forward(self, positions):
        flattened_float_positions = positions.flatten().to(torch.float)
        # Launch both linear layers asynchronously, one per device
        futures0 = torch.jit.fork(self.l0, flattened_float_positions.to(self.device0))
        futures1 = torch.jit.fork(self.l1, flattened_float_positions.to(self.device1))
        # Gather both partial energies on the first device
        energy = torch.jit.wait(futures0) + torch.jit.wait(futures1).to(self.device0)
        return energy
```
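A hypothetical driver for the model above (assuming two visible GPUs; the names here are illustrative and not part of the original repository):

```python
# Hypothetical usage of ParallelModel, assuming two GPUs are visible.
model = ParallelModel("cuda:0", "cuda:1")
positions = torch.randn(23, 3, dtype=torch.double, device="cuda:0")
print(model(positions))  # a single-element energy tensor on cuda:0
```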
```python
def run_model_loaded_jit(model, pdb):
    """
    Run the model after jitting it, saving and reloading.
    """
    # positions_from_pdb is a helper from the reproduction repository
    positions = positions_from_pdb(pdb)
    jit_model = torch.jit.script(model)
    jit_model.save("model.pt")
    del jit_model
    # Reload the scripted model and move it to the GPU
    loaded_jit_model = torch.jit.load("model.pt").cuda()
    loaded_jit_model(positions.float())
    print("Successfully ran loaded model with jit")
```

The problem is that TorchScript type inference gets confused, so you need to cast the positions tensor explicitly:

```python
class DoubleModel(torch.nn.Module):
    """
    Simple model where the energy is a linear
    transformation of the positions.
    The linear model uses double precision.
    """
    def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(3 * 23, 1).to(torch.double).cuda()

    def forward(self, positions):
        # Explicit cast: makes the input dtype unambiguous to TorchScript
        positions = positions.to(torch.double)
        energy = self.l0(positions.flatten())
        print(self.l0.weight)
        return energy
```
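For completeness, a sketch of exercising `DoubleModel` with a single-precision input, before and after a save/load round trip (names here are illustrative, not from the repository):

```python
# Sketch: the explicit cast inside forward() lets a float32 input work
# against the double-precision layer, before and after scripting.
model = DoubleModel()
positions = torch.randn(23, 3).cuda()  # float32 input
energy = model(positions)              # cast inside forward handles the dtype

scripted = torch.jit.script(model)
scripted.save("double_model.pt")
reloaded = torch.jit.load("double_model.pt")
print(reloaded(positions))             # same result after save/load
```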
It would make sense if openmm-torch adds the device reassignment deliberately. Is it possible to put this under the control of the user instead, so that devices are not reassigned? The suggested change to the double example does solve the issue. I am surprised that it is only an issue when the TorchScript model is loaded through openmm-torch, but I presume that there must be some differences in how the jit model is loaded in PyTorch and in C++.
I think it's just a difference in the dtype of the input.
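To illustrate the point, a minimal standalone snippet (not from the issue): a double-precision layer rejects a float32 input until the input is cast.

```python
import torch

# A double-precision layer fails on a float32 input ...
l = torch.nn.Linear(3, 1).to(torch.double)
x = torch.randn(3)             # float32
# l(x) raises: RuntimeError: expected scalar type Double but found Float
y = l(x.to(torch.double))      # ... but works after an explicit cast
```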
Yes, OpenMM-Torch moves the model and all its parameters to the device where it will be running.
No, OpenMM-Torch doesn't yet support multi-device execution, so all the tensors have to be on the same device.
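Given that behavior, one way to make a model robust to whatever device and dtype it ends up with is to normalize the input against the model's own parameters inside `forward`. This is a sketch of the pattern under that assumption, not an OpenMM-Torch API:

```python
import torch

class RobustModel(torch.nn.Module):
    # Sketch: cast the input to match the layer's parameters, so the model
    # works wherever openmm-torch places it (an assumed pattern, not an
    # openmm-torch API).
    def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(3 * 23, 1)

    def forward(self, positions):
        w = self.l0.weight
        positions = positions.to(dtype=w.dtype, device=w.device)
        return self.l0(positions.flatten())
```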
I wrote some tests (#50) to check if arguments to the PyTorch model are passed as intended. We need to update the documentation too, so users won't be confused.
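A sketch of the kind of check such a test might perform (illustrative only; see #50 for the actual tests):

```python
import torch

class ArgumentProbe(torch.nn.Module):
    # Records what openmm-torch actually passes in: dtype, device, shape.
    def forward(self, positions):
        print(positions.dtype, positions.device, positions.shape)
        # Return a zero energy so the probe doesn't perturb the system
        return torch.sum(positions) * 0.0
```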
Ok, thanks for the clarification. The behavior seems intentional, so I will close the issue, as there doesn't seem to be a bug after all.
I am not sure if this is an issue with how PyTorch saves JIT-scripted models in Python and loads them in C++, or if it is openmm-torch specific, but I am having issues with model tensors not being loaded correctly in openmm-torch. This is all tested with openmm-torch 0.03, OpenMM 7.5.1, and PyTorch 1.8 from conda-forge.
A simple example is if the torch model includes some parameters that are converted to double precision during initialization:
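A minimal sketch of such a model, following the `DoubleModel` shown earlier in this thread (reconstructed for illustration):

```python
import torch

# Sketch of a model whose parameters are cast to double during __init__
# (modeled on the DoubleModel discussed above).
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(3 * 23, 1).to(torch.double).cuda()

    def forward(self, positions):
        return self.l0(positions.flatten())
```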
If I initialize the model and pass a positions tensor, `self.l0` is still double precision. This is still true if I compile the model with jit script, as well as if I save the compiled model and reload it in Python. However, if I load the same model with TorchForce and add the force to the OpenMM system, `self.l0` has been converted to single precision (even though the positions tensor passed by OpenMM is double). This is also observed when parameters are assigned to different CUDA devices, where they will all be assigned to cuda:0 once loaded by openmm-torch. I have created a repository that reproduces both the precision issue and the CUDA device issue described above, which can be found here.
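A sketch of the loading path in question, using the openmm-torch Python API (the system setup here is illustrative; `Model` is the sketch above):

```python
import torch
from simtk import openmm          # OpenMM 7.5-era import path
from openmmtorch import TorchForce

# Save the scripted model, then load it through openmm-torch.
torch.jit.script(Model()).save("model.pt")

system = openmm.System()
for _ in range(23):
    system.addParticle(1.0)       # illustrative masses

force = TorchForce("model.pt")    # openmm-torch loads the TorchScript file
system.addForce(force)            # at this point the parameters end up
                                  # in single precision on cuda:0
```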