Simple TorchScript test fails #219
Saving and loading the model each iteration solves the issue, and it is in fact faster than not doing it:

```python
for i in range(0, 10):
    print("Running iteration {}".format(i))
    import io
    buffer = io.BytesIO()
    torch.jit.save(model, buffer)
    buffer.seek(0)
    model = torch.jit.load(buffer).to(device="cuda")
    y, neg_dy = model(z, pos, batch)
```
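A minimal CPU-only sketch of the same save/load round-trip, with a toy scripted module standing in for the real model (the `Toy` class is hypothetical and the `.to("cuda")` step is omitted):

```python
import io
import torch

# Hypothetical stand-in for the real TensorNet model.
class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

model = torch.jit.script(Toy())
x = torch.ones(3)

# Round-trip the scripted module through an in-memory buffer,
# as in the workaround above.
buffer = io.BytesIO()
torch.jit.save(model, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

assert torch.equal(model(x), reloaded(x))
```

The round-trip discards the JIT's accumulated profiling/optimization state, which is presumably why it side-steps the failure.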
This is the offending line (`torchmd-net/torchmdnet/models/tensornet.py`, line 331 at `fd83954`). Making `tensor_m = tensor` or using 0 layers makes the problem go away.
Replacing it with this does NOT solve the issue:

```python
tensor_m = torch.zeros_like(tensor).scatter_add(
    0, (edge_index[0])[:, None, None, None].expand_as(msg), msg
)
```

But my suspicion is that `scatter_add` is the same thing being used by `scatter` anyway.

```python
if self.derivative:
    grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
    dy = grad(
        [y],
        [pos],
        grad_outputs=grad_outputs,
        create_graph=False,
        retain_graph=False,
    )[0]
```
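The `scatter_add` aggregation pattern from that line can be checked on CPU with small tensors (the shapes here are made up: 4 edges carrying messages of shape `(2, 3, 3)` each, aggregated onto 3 nodes):

```python
import torch

# Hypothetical sizes mimicking the tensornet.py pattern.
num_edges, num_nodes = 4, 3
msg = torch.randn(num_edges, 2, 3, 3)
edge_index0 = torch.tensor([0, 2, 1, 2])  # destination node of each edge

# Broadcast the per-edge destination index to the full message shape,
# then sum each edge's message into its destination node's slot.
index = edge_index0[:, None, None, None].expand_as(msg)
out = torch.zeros(num_nodes, 2, 3, 3).scatter_add(0, index, msg)

# Reference: explicit per-edge accumulation.
ref = torch.zeros(num_nodes, 2, 3, 3)
for e in range(num_edges):
    ref[edge_index0[e]] += msg[e]

assert torch.allclose(out, ref)
```

This confirms the expanded-index `scatter_add` is a straightforward per-node sum of edge messages; the bug is therefore unlikely to be in the aggregation semantics themselves.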
5 is a magic number we have encountered before when dealing with CUDA torch models; see for instance openmm/openmm-torch#122. I just made the connection with this issue while reading about the torch.jit fusion mechanism with NVFuser. What I believe is going on here is some kind of collision between `jit.script`, the NVFuser backend, and autograd.
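If the fuser-collision hypothesis is right, pinning the JIT to a different fuser backend should change the behavior. A sketch of how one might test that with `torch.jit.fuser`, which is a context manager selecting the backend (`"fuser1"` is the NNC fuser, `"fuser2"` is NVFuser); the `Toy` module is hypothetical, and reproducing the actual crash would of course require CUDA:

```python
import torch

# Hypothetical elementwise module; real fusion happens on GPU,
# but the selection mechanism is the same on CPU.
class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x) * torch.cos(x)

model = torch.jit.script(Toy())
x = torch.randn(8)

# Run past the JIT's profiling runs (hence > 5 iterations) with the
# fuser pinned to NNC instead of NVFuser.
with torch.jit.fuser("fuser1"):
    for _ in range(8):
        y = model(x)

assert torch.allclose(y, torch.sin(x) * torch.cos(x))
```

If the failure disappears under `"fuser1"` but reappears under `"fuser2"`, that would point squarely at NVFuser.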
This issue goes away with pytorch>2.0.0, so I am going to assume it is a bug there. Will leave this open until there is a conda-forge package for pytorch 2.1.
Closing this as there is a package for 2.1 already.
This test will run exactly 4 iterations and then print an error:
The error: