Model tensors not being loaded correctly #43

Closed
larsbratholm opened this issue Jul 21, 2021 · 7 comments
Labels: documentation


@larsbratholm

I am not sure if this is an issue with how PyTorch saves TorchScript models in Python and loads them in C++, or if it is openmm-torch specific, but I am having issues with model tensors not being loaded correctly in openmm-torch. This is all tested with openmm-torch 0.03, OpenMM 7.5.1 and PyTorch 1.8 from conda-forge.

A simple example is if the torch model includes some parameters that are converted to double precision during initialization:

self.l0 = torch.nn.Linear(n_in, 1).to(torch.double).cuda()

If I initialize the model and pass a positions tensor, self.l0 is still double precision. This remains true if I compile the model with torch.jit.script, and also if I save the compiled model and reload it in Python. However, if I load the same model with TorchForce and add the force to the OpenMM system, self.l0 has been converted to single precision (even though the positions tensor passed by OpenMM is double). The same happens when parameters are assigned to different CUDA devices: they are all reassigned to cuda:0 once loaded by openmm-torch.

I have created a repository that reproduces both the precision and CUDA device issues described above, which can be found here.
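
Roughly, the model reaches OpenMM like this (a minimal sketch, assuming a serialized model.pt and an existing openmm.System named system; the actual scripts live in the repository):

import torch
from openmmtorch import TorchForce

# Compile the model with TorchScript and serialize it to disk
module = torch.jit.script(model)  # model: the torch.nn.Module described above
module.save("model.pt")

# openmm-torch loads the serialized module in C++ and evaluates it as a force
force = TorchForce("model.pt")
system.addForce(force)  # system: an existing openmm.System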

raimis commented Jul 30, 2021

@larsbratholm thank you for reporting the issue and giving very detailed information on how to reproduce it.

Recently, I have observed similar issues and have been investigating, but due to the holiday season there probably won't be much progress until September.

raimis commented Sep 30, 2021

@larsbratholm I have investigated the issues:

  • OpenMM-Torch is designed to run inference on a single GPU only, so the ParallelModel won't work as intended, and we cannot fix that easily.
class ParallelModel(torch.nn.Module):
    """
    Simple model where the energy is a sum of two linear
    transformations of the positions.
    Each linear model is on a different device.
    """

    def __init__(self, device0, device1):
        super().__init__()
        self.device0 = device0
        self.device1 = device1
        self.l0 = torch.nn.Linear(3 * 23, 1).to(self.device0)
        self.l1 = torch.nn.Linear(3 * 23, 1).to(self.device1)

    def forward(self, positions):
        flattened_float_positions = positions.flatten().to(torch.float)
        futures0 = torch.jit.fork(self.l0, flattened_float_positions.to(self.device0))
        futures1 = torch.jit.fork(self.l1, flattened_float_positions.to(self.device1))

        energy = torch.jit.wait(futures0) + torch.jit.wait(futures1).to("cuda:0")
        return energy
  • The double precision issue. To emulate what is happening in OpenMM-Torch, run_model_loaded_jit should be like this:
def run_model_loaded_jit(model, pdb):
    """
    Run the model after jitting it, saving and reloading
    """
    positions = positions_from_pdb(pdb)
    jit_model = torch.jit.script(model)
    jit_model.save("model.pt")
    del jit_model
    loaded_jit_model = torch.jit.load("model.pt").cuda()
    loaded_jit_model(positions.float())
    print("Succesfully ran loaded model with jit")

The problem is that the TorchScript type inference gets confused, so you need to type-cast the positions tensor explicitly:

class DoubleModel(torch.nn.Module):
    """
    Simple model where the energy is a linear
    transformation of the positions.
    The linear model uses double precision.
    """

    def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(3 * 23, 1).to(torch.double).cuda()

    def forward(self, positions):
        positions = positions.to(torch.double)
        energy = self.l0(positions.flatten())
        print(self.l0.weight)
        return energy

And print(self.l0.weight) shows that the double tensor isn't converted to a float one (a quick end-to-end check is sketched at the end of this comment).

  • Regarding the device assignment, there might be a bug:

This is also observed when parameters are assigned to different cuda devices, where they will all be assigned to cuda:0 once loaded by openmm-torch.
I'm investigating more...
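
As an end-to-end check of the cast fix outside of OpenMM (a minimal sketch, assuming a CUDA device is available; the file name and tensor shape are illustrative):

import torch

model = DoubleModel()  # with the explicit cast in forward(), as above
torch.jit.script(model).save("double_model.pt")

loaded = torch.jit.load("double_model.pt").cuda()  # roughly what openmm-torch does
positions = torch.randn(23, 3, device="cuda", dtype=torch.float32)
energy = loaded(positions)
print(energy.dtype)  # torch.float64, so self.l0 stayed in double precision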

@larsbratholm

It would make sense that if openmm-torch adds the .cuda() call behind the scenes to the following line, this would cause all tensors to be reassigned to cuda:0.

loaded_jit_model = torch.jit.load("model.pt").cuda()

Is it possible to put this under the control of the user instead, so that devices are not reassigned?

The suggested change to the double example does solve the issue. I am surprised that it is only an issue when the TorchScript model is loaded through openmm-torch, but I presume there must be some difference in how the JIT model is loaded in Python and in C++.

@peastman

I presume there must be some difference in how the JIT model is loaded in Python and in C++.

I think it's just a difference in the dtype of the input. The dtype of positions will be float32 if you tell OpenMM to use single or mixed precision mode, and float64 in double precision mode. PyTorch on CUDA is very particular about not mixing dtypes in the arguments to many ops; running on CPU, it's more flexible.
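
For instance (a minimal illustration, not from the plugin, assuming a CUDA device is available):

import torch

layer = torch.nn.Linear(3, 1).to(torch.double).cuda()
x = torch.randn(3, device="cuda", dtype=torch.float32)

try:
    layer(x)  # double-precision weights with a float32 input: dtype mismatch on CUDA
except RuntimeError as err:
    print(err)

layer(x.to(torch.double))  # fine once the input is cast explicitly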

raimis commented Dec 14, 2021

It would make sense that if openmm-torch adds the .cuda() call behind the scenes to the following line, this would cause all tensors to be reassigned to cuda:0.

Yes, OpenMM-Torch moves the model and all its parameters to the device where it will be running.

Is it possible to put this under the control of the user instead, so that devices are not reassigned?

No, OpenMM-Torch doesn't support multi-device execution yet, so all the tensors have to be on the same device.
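
For example (a sketch reusing the ParallelModel class quoted above; the device name is illustrative), constructing the model with a single device sidesteps the reassignment:

import torch

device = torch.device("cuda:0")
model = ParallelModel(device, device)  # both linear layers on the same device
torch.jit.script(model).save("model.pt")  # openmm-torch keeps everything on its device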

raimis commented Dec 14, 2021

I wrote some tests (#50) to check whether the arguments to the PyTorch model are passed as intended.

We need to update the documentation too, so users won't be confused.

raimis added the documentation label on Dec 14, 2021
@larsbratholm

OK, thanks for the clarification. The behavior seems intentional, so I will close the issue, as there doesn't seem to be a bug after all.
