TorchForce-equipped System object errs upon serialization/deserialization #37

Closed
dominicrufa opened this issue May 27, 2021 · 11 comments · Fixed by #38

Comments

@dominicrufa

It appears that a TorchForce-equipped system loses its TorchForce-specific global parameter after the system is serialized and deserialized. Also, when I create a context from the deserialized system and try to retrieve the energy, I get a NULL error.

gist

zipped yaml droplet of conda env:
tre.zip

@peastman
Member

This has to do with the details of how SWIG generated wrappers work. Deserialization is implemented in C++. SWIG has to generate a Python wrapper object for each force in the system, but it only knows how to do that for force classes that are built into OpenMM, not ones that are defined by separate plugins. It just returns a generic Force object rather than a TorchForce. Pass that object to TorchForce.cast() to get a TorchForce.

@dominicrufa
Author

Alright. Slight confusion: once I deserialize a TorchForce-equipped system, I am basically replacing the Force object in the system with TorchForce.cast(Force)? If I do this, I get
Exception: the System object does not own its corresponding OpenMM object

@peastman
Member

> I am basically replacing the Force object in the system with TorchForce.cast(Force)?

I'm not quite sure what you mean by "replacing" it. If you want to call a TorchForce method on it, then instead of writing (for example)

force = system.getForce(0)

you would instead write

force = TorchForce.cast(system.getForce(0))

> If I do this, I get Exception: the System object does not own its corresponding OpenMM object

That sounds like something is messed up internally. Can you provide the code that leads to that exception?

@dominicrufa
Author

So in this gist, between which cells should I be calling TorchForce.cast? Also, should I cast before I serialize the system or after I deserialize it (as in In [19-20])? Sorry for the confusion.

@peastman
Member

I don't see anywhere you would need to cast it. Once you add the TorchForce to the System, you never reference it again.

The Python Force object is just a lightweight wrapper around a C++ object. Initially you create a TorchForce Python object, which wraps a TorchForce C++ object. That's stored in the System. But when you call getForce(), it doesn't know what kind of Python object to create, so it gives you a generic Force object. That's still a wrapper around the same C++ TorchForce object.
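The wrapper relationship described above can be pictured with a toy model in plain Python. It is not how OpenMM or SWIG is implemented; every class here (`NativeTorchForce`, `ToyForce`, `ToyTorchForce`, `ToySystem`) is a hypothetical stand-in. The point is only that the generic wrapper and the cast wrapper refer to the same underlying object.

```python
class NativeTorchForce:
    """Stand-in for the C++ object that actually holds the state."""
    def __init__(self, file):
        self.file = file


class ToyForce:
    """Generic wrapper: holds a handle, exposes no plugin-specific methods."""
    def __init__(self, native):
        self._native = native


class ToyTorchForce(ToyForce):
    """Typed wrapper: same handle, plus plugin-specific methods."""
    def getFile(self):
        return self._native.file

    @staticmethod
    def cast(force):
        # Re-wrap the same native object with the more specific wrapper.
        if not isinstance(force._native, NativeTorchForce):
            raise TypeError("not a TorchForce")
        return ToyTorchForce(force._native)


class ToySystem:
    """Stores native objects only; the Python wrapper type is not remembered."""
    def __init__(self):
        self._natives = []

    def addForce(self, force):
        self._natives.append(force._native)
        return len(self._natives) - 1

    def getForce(self, index):
        # The system only knows how to hand back the generic wrapper.
        return ToyForce(self._natives[index])


system = ToySystem()
system.addForce(ToyTorchForce(NativeTorchForce("model.pt")))

generic = system.getForce(0)
print(type(generic).__name__)            # ToyForce
print(hasattr(generic, "getFile"))       # False

typed = ToyTorchForce.cast(generic)
print(typed.getFile())                   # model.pt
print(typed._native is generic._native)  # True
```

Nothing is copied or replaced by the cast in this model: the System keeps the same native object throughout, and only the Python-side view changes.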

@dominicrufa
Author

> I don't see anywhere you would need to cast it. Once you add the TorchForce to the System, you never reference it again.

Right, I didn't think I would have to cast it. So it isn't obvious why I'm getting a null error in cell 25?

@peastman
Member

You're running into [this error](https://github.com/openmm/openmm/issues/3098). The exception just has the wrong message. Most likely it's a NaN.

@dominicrufa
Author

Duly noted. But in the gist, I create a context, set positions, and get a non-NaN energy before serializing/deserializing the system. The error only appears after I create a context with the deserialized system (and set the same positions as before).

@dominicrufa
Author

Also, I am adding a global parameter to the TorchForce. It appears in context.getParameters() before I serialize the system, but not after I create a context with the deserialized system (it is present in cell 17 but not in cell 26).

@peastman
Member

Here's the actual error that's being thrown:

OpenMMException: forward() is missing value for argument 'scale'. Declaration: forward(__torch__.ForceModule self, Tensor positions, Tensor scale) -> (Tensor)

It looks like the serialization proxy isn't saving and restoring the global parameters correctly. Let me fix that!
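To illustrate the class of bug described here, the following is a toy round trip in plain Python. It is not the plugin's serialization proxy; `ToyForce`, the XML layout, and the `include_parameters` switch are invented for this sketch. It shows how a proxy that writes the model file but skips the global parameters produces a force that deserializes cleanly yet has lost `scale`.

```python
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field


@dataclass
class ToyForce:
    """Hypothetical stand-in for a plugin force; not the real TorchForce."""
    file: str
    global_parameters: dict = field(default_factory=dict)  # name -> default


def serialize(force, include_parameters=True):
    """Write the force to an XML string.

    With include_parameters=False the global parameters are silently
    dropped, mimicking the behaviour reported in this issue.
    """
    node = ET.Element("Force", {"type": "ToyForce", "file": force.file})
    if include_parameters:
        for name, default in force.global_parameters.items():
            ET.SubElement(node, "GlobalParameter",
                          {"name": name, "default": repr(default)})
    return ET.tostring(node, encoding="unicode")


def deserialize(xml_text):
    """Rebuild a ToyForce from the XML written by serialize()."""
    node = ET.fromstring(xml_text)
    params = {p.get("name"): float(p.get("default"))
              for p in node.findall("GlobalParameter")}
    return ToyForce(file=node.get("file"), global_parameters=params)


original = ToyForce("model.pt", {"scale": 1.0})
lossy = deserialize(serialize(original, include_parameters=False))
faithful = deserialize(serialize(original))

print(lossy.global_parameters)     # {}
print(faithful.global_parameters)  # {'scale': 1.0}
```

A model whose `forward()` expects a `scale` argument would then receive no value for it after the lossy round trip, which matches the shape of the error message quoted above.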

@peastman
Member

Can you verify this fixes the problem for you? Then I can make a new release.
