Simple TorchScript test fails #219

Closed
RaulPPelaez opened this issue Sep 21, 2023 · 5 comments

@RaulPPelaez
Collaborator

RaulPPelaez commented Sep 21, 2023

This test runs exactly 4 iterations and then fails with an error:

import torch
from torchmdnet.models.model import create_model

def test_really_simple():
    n_atoms = 10
    # Random atomic numbers drawn from H, C, N, O, F
    zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
    z = zs[torch.randint(0, len(zs), (n_atoms,))]
    pos = torch.randn(len(z), 3)
    # Split the atoms into two molecules (batch indices 0 and 1)
    batch = torch.zeros(len(z), dtype=torch.long)
    batch[len(batch) // 2 :] = 1
    args = {"model": "tensornet",
            "embedding_dimension": 128,
            "num_layers": 2,
            "num_rbf": 32,
            "rbf_type": "expnorm",
            "trainable_rbf": False,
            "activation": "silu",
            "cutoff_lower": 0.0,
            "cutoff_upper": 5.0,
            "max_z": 100,
            "max_num_neighbors": 128,
            "equivariance_invariance_group": "O(3)",
            "prior_model": None,
            "atom_filter": -1,
            "derivative": True,  # also return the negative gradient w.r.t. pos
            "output_model": "Scalar",
            "reduce_op": "sum",
            "precision": 32}
    model = create_model(args).to(device="cuda")
    z = z.to("cuda")
    # The derivative w.r.t. positions requires pos to track gradients
    pos = pos.to("cuda").requires_grad_(True)
    batch = batch.to("cuda")
    model = torch.jit.script(model).to(device="cuda")
    # The first 4 iterations succeed; the next call raises the error below
    for i in range(0, 10):
        print("Running iteration {}".format(i))
        y, neg_dy = model(z, pos, batch)

The error:

self = RecursiveScriptModule(
  original_name=TorchMD_Net
  (representation_model): RecursiveScriptModule(
    original_name=...      (1): RecursiveScriptModule(original_name=SiLU)
      (2): RecursiveScriptModule(original_name=Linear)
    )
  )
)
args = (tensor([7, 6, 7, 7, 8, 7, 8, 7, 8, 8], device='cuda:0'), tensor([[-0.6791,  0.2550, -0.8304],
        [-0.4754,  0.60...107, -1.6600,  0.7560]], device='cuda:0', requires_grad=True), tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], device='cuda:0')), kwargs = {}, forward_call = <torch.ScriptMethod object at 0x7f9d363cd4e0>

    def _call_impl(self, *args, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
                or _global_backward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
>           return forward_call(*args, **kwargs)
E           RuntimeError: The following operation failed in the TorchScript interpreter.
E           Traceback of TorchScript (most recent call last):
E             File "/shared/raul/torchmd-net/torchmdnet/models/model.py", line 292, in forward
E                       print("Shape of pos")
E                       print(pos.shape)
E                       dy = grad(
E                            ~~~~ <--- HERE
E                           [y],
E                           [pos],
E           RuntimeError: The following operation failed in the TorchScript interpreter.
E           Traceback of TorchScript (most recent call last):
E           RuntimeError: The following operation failed in the TorchScript interpreter.
E           Traceback of TorchScript (most recent call last):
E           RuntimeError: The size of tensor a (128) must match the size of tensor b (3) at non-singleton dimension 3

../../mambaforge/envs/openmmtorch-test/lib/python3.10/site-packages/torch/nn/modules/module.py:1501: RuntimeError

  • Without torch.jit.script, no error is produced.
  • The number of atoms does not seem to matter, and I only see this with TensorNet.
  • I tried the old Distance module to rule out problems with the C++ extensions, but the error is the same.
  • Even with only 10 atoms, each iteration takes progressively longer, on the order of 5 seconds.
  • 128 is the embedding dimension; changing it to 32 does indeed change the 128 in the error to 32.
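
The message itself is an ordinary broadcasting failure: shapes are aligned from the trailing dimension, and each pair of sizes must be equal or one of them must be 1. The pure-Python sketch below mimics that rule and PyTorch's wording; it is an illustration, not PyTorch's implementation. The shapes in the usage lines are hypothetical, chosen only so that an embedding-sized axis meets an axis of size 3 at dimension 3, as in the traceback.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast shape of two shapes, or raise with PyTorch-style wording.

    Illustrative re-implementation of the standard broadcasting rule,
    not PyTorch's own code.
    """
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    out = []
    for dim, (size_a, size_b) in enumerate(zip(a, b)):
        if size_a == size_b or size_b == 1:
            out.append(size_a)
        elif size_a == 1:
            out.append(size_b)
        else:
            raise RuntimeError(
                f"The size of tensor a ({size_a}) must match the size of "
                f"tensor b ({size_b}) at non-singleton dimension {dim}"
            )
    return tuple(out)


# Hypothetical shapes: a size-128 axis lands where a size-3 axis is expected
try:
    broadcast_shape((10, 1, 1, 128), (10, 128, 3, 3))
except RuntimeError as err:
    # prints: The size of tensor a (128) must match the size of tensor b (3) at non-singleton dimension 3
    print(err)
```

Changing 128 to 32 in the first shape changes the number in the message accordingly, which matches the observation that the reported size follows the embedding dimension.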

@RaulPPelaez
Collaborator Author

Saving and reloading the scripted model on every iteration avoids the issue, and is in fact faster than not doing it:

    import io
    for i in range(0, 10):
        print("Running iteration {}".format(i))
        # Serialize the scripted model to an in-memory buffer and load it back
        buffer = io.BytesIO()
        torch.jit.save(model, buffer)
        buffer.seek(0)
        model = torch.jit.load(buffer).to(device="cuda")
        y, neg_dy = model(z, pos, batch)

@RaulPPelaez
Collaborator Author

RaulPPelaez commented Sep 21, 2023

This is the offending line. Setting tensor_m = tensor instead, or using 0 layers, makes the problem go away:

tensor_m = scatter(msg, edge_index[0], dim=0, dim_size=natoms)

Replacing it with this does NOT solve the issue:

tensor_m = torch.zeros_like(tensor).scatter_add(0, (edge_index[0])[:,None, None, None].expand_as(msg), msg)

But I suspect that scatter uses scatter_add under the hood anyway...
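
For reference, a scatter with sum reduction over dim=0 adds each edge message into the row of the atom given by edge_index[0]. The pure-Python sketch below uses lists instead of tensors, and the names scatter_sum and scatter_add_dim0 are hypothetical. It assumes the scatter in question performs a sum reduction, and uses 2-D messages for brevity (the [:, None, None, None] in the replacement suggests msg has more trailing dimensions). It shows that both formulations compute the same result.

```python
def scatter_sum(src, index, dim_size):
    """Sum row i of `src` into row index[i] of a zero-initialized output."""
    width = len(src[0]) if src else 0
    out = [[0.0] * width for _ in range(dim_size)]
    for row, target in zip(src, index):
        for j, value in enumerate(row):
            out[target][j] += value
    return out


def scatter_add_dim0(out, index, src):
    """Mimic out.scatter_add_(0, index, src) for 2-D lists:
    out[index[i][j]][j] += src[i][j]."""
    for i, row in enumerate(src):
        for j, value in enumerate(row):
            out[index[i][j]][j] += value
    return out


msg = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # one message per edge
receivers = [0, 2, 0]                        # plays the role of edge_index[0]
natoms = 3

a = scatter_sum(msg, receivers, natoms)
# Expand the 1-D index to msg's shape, like index[:, None].expand_as(msg)
expanded = [[r] * len(row) for r, row in zip(receivers, msg)]
b = scatter_add_dim0([[0.0] * len(msg[0]) for _ in range(natoms)], expanded, msg)
print(a == b)  # True
```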

EDIT: Not keeping the graph when calling grad (create_graph=False, retain_graph=False) also makes the error go away:

        if self.derivative:
            grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
            dy = grad(
                [y],
                [pos],
                grad_outputs=grad_outputs,
                create_graph=False,
                retain_graph=False,
            )[0]

@RaulPPelaez
Collaborator Author

5 is a magic number we have run into before with CUDA Torch models; see, for instance, openmm/openmm-torch#122.

I only made the connection with this issue now, while reading about the torch.jit operator fusion mechanism with NVFuser:
https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/codegen/cuda/README.md#general-ideas-of-debug-no-fusion

I believe there is some kind of bad interaction between jit.script, the NVFuser backend, and autograd.

@RaulPPelaez
Collaborator Author

This issue goes away with PyTorch > 2.0.0, so I am going to assume it is a PyTorch bug. I will leave this open until there is a conda-forge package for PyTorch 2.1.
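
If a test suite needs to skip this check on affected installs, a version guard is one option. The sketch below is a hypothetical helper, not part of torchmd-net: it parses version strings of the form reported by torch.__version__ (e.g. "2.0.0+cu118") and, taking "PyTorch > 2.0.0" above literally, treats every release up to and including 2.0.0 as affected.

```python
import re


def parse_version(version):
    """Extract (major, minor, patch) from strings like '2.0.0+cu118' or '2.1.0a0+git1234'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def is_affected(version, last_bad=(2, 0, 0)):
    """True if `version` is at or below the last release assumed to fail."""
    return parse_version(version) <= last_bad


print(is_affected("2.0.0+cu118"))  # True
print(is_affected("2.1.0"))        # False
```

In a pytest suite, the result could feed pytest.mark.skipif(is_affected(torch.__version__), reason=...).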

@RaulPPelaez
Collaborator Author

Closing this as there is a package for 2.1 already.
