
Register the overloads added by CustomDist in worker processes #7241

Open · EliasRas wants to merge 13 commits into main

Conversation

@EliasRas (Author) commented Apr 6, 2024

Description

Currently sample_smc can fail with a NotImplementedError if it's used with a model defined using CustomDist. If a CustomDist is used without the dist parameter, the overloads for _logprob, _logcdf and _support_point are registered only in the main process.

This PR adds an initializer which registers the overloads in the worker processes of the pool used in sample_smc.
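
For illustration, here is a minimal sketch of the initializer pattern described above (the names and the shape of the registration data are hypothetical, not the PR's actual code):

import multiprocessing

def _initialize_worker(registrations):
    # Hypothetical initializer: re-run, inside the worker, the dispatch
    # registrations that CustomDist performed in the main process. Each
    # entry pairs a singledispatch registry (e.g. _logprob) with the RV
    # type and the user-supplied function.
    for dispatcher, rv_type, func in registrations:
        dispatcher.register(rv_type)(func)

def make_pool(registrations, cores):
    # The initializer runs once per worker, before any sampling work.
    ctx = multiprocessing.get_context("spawn")
    return ctx.Pool(cores, initializer=_initialize_worker, initargs=(registrations,))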

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7241.org.readthedocs.build/en/7241/

welcome bot commented Apr 6, 2024

💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

@ricardoV94 (Member)

Hi @EliasRas, I'll need some time to review this one properly. Thanks for taking the initiative

ricardoV94 self-assigned this Apr 9, 2024
@EliasRas (Author)

Looks like I messed up by rebasing instead of merging and introduced plenty of unnecessary commits to this feature. Does it need to be fixed?

@twiecki (Member) commented May 26, 2024

Yes, that needs to be fixed; it happens to everyone. One approach is to start clean and cherry-pick your commits.

def _random(mu, rng=None, size=None):  # header reconstructed from CustomDist's random convention
    return rng.normal(loc=mu, scale=1, size=size)

with pm.Model():
    mu = pm.CustomDist("mu", 0, logp=_logp, random=_random)
@ricardoV94 (Member) commented May 27, 2024

There are two kinds of CustomDist: if you pass dist instead of random, you get a different Op type, which I guess would still fail after this PR.

Edit: I see you mentioned this in your top message

@ricardoV94 (Member)

Why does it fail with random but not dist? Testing locally, it seems to fail with both for me.

@EliasRas (Author)

I just tested both the code I used in the linked issue and your example below. I still get no errors using pm.Potential or the dist argument. Originally I was using pymc==5.10.0, but now I tested with pymc==5.13.0. Maybe there are differences between Windows and Linux, if you're on Linux?

I'm not completely sure why using dist works for me, but based on some quick testing, DistributionMeta.__new__ is called when e.g. Normal is defined, and the overloads for built-in distributions are registered there. I'm not well versed in multiprocessing or the way Python does importing, but my hunch is that the worker processes automatically import things from pymc and the overloads get registered as a side effect. For a user-defined logprob etc. this is not the case, since the registration isn't done during importing.
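
A toy demonstration of that hunch, with illustrative names only (no pymc involved): a registration made at import time is recreated when a spawned worker re-imports the module, while a registration made at runtime is not.

import multiprocessing
from functools import singledispatch

@singledispatch
def _logprob_demo(op):
    raise NotImplementedError(type(op))

class UserOp:
    pass

def check():
    # Runs in the worker: the module body above was re-executed on spawn,
    # but the runtime registration in the __main__ block was not.
    return _logprob_demo(UserOp())

if __name__ == "__main__":
    # Runtime registration, analogous to what CustomDist does:
    _logprob_demo.register(UserOp)(lambda op: 0.0)
    print(_logprob_demo(UserOp()))  # works in the main process
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(1) as pool:
        pool.apply(check)  # raises NotImplementedError in the worker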

@ricardoV94 (Member) commented May 27, 2024

I think it's more complicated than this. The following example has a specific dispatch registration, but no RV that shows up in the graph:

import pymc as pm

def _logp(value, mu):
    return -((value - mu) ** 2)

def _dist(mu, size=None):
    return pm.Normal.dist(mu, 1, size=size)

with pm.Model():
    mu = pm.Normal("mu", 0)
    pm.Potential("term", pm.logp(pm.CustomDist.dist(mu, logp=_logp, dist=_dist), [1, 2]))
    pm.sample_smc(draws=6, cores=1)        

It also fails even with a single core

@EliasRas (Author)

> It also fails even with a single core

22e8f0b refactored sample_smc, and I think errors should now pop up even with a single core, since the sampling is always done in another process. Previously this was the case only when cores>1, since there were separate run_chains_parallel and run_chains_sequential.

@ricardoV94 (Member) commented May 27, 2024

Somehow, in main, I am getting ConnectionResetError: [Errno 104] Connection reset by peer even for unrelated models without any sort of CustomDist

@ricardoV94 (Member)

> Somehow, in main, I am getting ConnectionResetError: [Errno 104] Connection reset by peer even for unrelated models without any sort of CustomDist

Okay, it's something about the new progress bar and the PyCharm interactive Python console. If I run from IPython or a terminal it works. But it also works in main for me?

@ricardoV94 (Member)

I cannot reproduce a failure with your test locally (after avoiding the PyCharm issue) nor in a Colab environment: https://colab.research.google.com/drive/1I1n6c9IlmXknIfhxC5s7sAQghv0vfRSY?usp=sharing

Can you share more details about your environment/setup?

@EliasRas (Author)

> Can you share more details about your environment/setup?

I added the output of conda list to the "PyMC version information" section of #7224. I'm running the code using VSCode if that matters. Do you need anything else?

Basically I followed the install instructions and the pull request tutorial when installing. Might have also pip installed a couple of extra packages here and there.

@ricardoV94 (Member)

> I added the output of conda list to the "PyMC version information" section of #7224. I'm running the code using VSCode if that matters. Do you need anything else?

We should have at least one person reproduce the problem, because I cannot. It may be a VSCode environment issue. Ideally we wouldn't have to change the codebase.

@EliasRas (Author)

The test does fail without the changes when I run it from the Miniforge Prompt, though.

@ricardoV94 (Member)

> The test does fail without the changes when I run it from the Miniforge Prompt, though.

Not sure what the Miniforge Prompt is; can we try to reproduce here on the CI then? Push just the test, without the fixes, into a new PR and we'll run it to see if we can reproduce.

EliasRas marked this pull request as ready for review June 3, 2024
@EliasRas (Author)

Is there anything that needs to be done here besides running the tests?

twiecki requested a review from ricardoV94 June 11, 2024
@twiecki (Member) commented Jun 11, 2024

> Is there anything that needs to be done here besides running the tests?

Sorry for the delay, just kicked off tests.

codecov bot commented Jun 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.19%. Comparing base (c8b22df) to head (d2669f2).
Report is 2 commits behind head on main.


@@            Coverage Diff             @@
##             main    #7241      +/-   ##
==========================================
+ Coverage   92.18%   92.19%   +0.01%     
==========================================
  Files         103      103              
  Lines       17259    17282      +23     
==========================================
+ Hits        15910    15933      +23     
  Misses       1349     1349              
Files                  Coverage Δ
pymc/smc/sampling.py   99.34% <100.00%> (+0.11%) ⬆️

@lucianopaz (Contributor)

Thanks @EliasRas, I haven't been able to reproduce this yet, but that's just because I'm in the middle of switching workstations and haven't gotten everything set up yet.
Your fix looks fine to me, and I understand what you identified as the cause of the issue: the dispatching mechanism isn't registering the logp and other methods for the dynamically created class. I think this highlights a caveat in pymc's and pytensor's design: spawned processes may not have all the dispatch registrations that the main process has. I imagine this is mostly a problem on Windows, where multiprocessing can only spawn new processes, whereas Linux-based systems default to forks, which in principle copy over the memory contents of the main process. I'm not sure what happens under macOS, because I think it cannot use fork multiprocessing for some reason either.
With this design caveat in mind, I'm not sure if it's better to have a package-level utility function that serves as a sort of book-keeper, something that can handle communicating the extra dispatch registrations needed to ensure that child processes use the correct dispatching functions. I'm curious to know what @ricardoV94 thinks about this. I don't think this PR should have to tackle that kind of work, but we can discuss whether it's necessary here, and maybe later open an issue and a separate PR (also maybe in pytensor, where dispatching is used for transpilation/compilation and maybe at some point for lazy gradients?).
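
For reference, the platform defaults behind this point (a small check, nothing PyMC-specific):

import multiprocessing

# Default start methods (as of Python 3.8+):
#   Linux   -> "fork":  children inherit the parent's memory, registrations included
#   Windows -> "spawn": fresh interpreter, so only import-time registrations exist
#   macOS   -> "spawn": fork is considered unsafe with some system libraries
print(multiprocessing.get_start_method())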

@aseyboldt (Member)

I guess the underlying reason for the failure is that pickling of DensityDist doesn't work out of the box? Sounds like for some reason the dispatch functions don't get registered when the object is unpickled. But wouldn't it be cleaner to override the pickling behavior of this class then? We could override the __getstate__ and __setstate__ methods to that effect.
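
A rough sketch of that suggestion (the class and attribute names are invented for illustration; this is not the actual CustomDist implementation):

from pymc.logprob.abstract import _logprob

class DensityRV:
    # Hypothetical Op-like class carrying a user-supplied logp.
    def __init__(self, logp_fn):
        self._logp_fn = logp_fn
        self._register()

    def _register(self):
        # Attach the user-supplied logp to the dispatcher for this class.
        _logprob.register(type(self))(self._logp_fn)

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Runs on unpickle, so a spawned worker also gets the overload.
        self._register()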

@lucianopaz (Contributor)

> I guess the underlying reason for the failure is that pickling of DensityDist doesn't work out of the box?

I don't think the problem is about pickling. DensityDist ends up returning an Op that can be cloudpickled. If I recall correctly, it can't be plain-pickled because the Op class is created on the fly. In the process of creating the Op, the dispatchers get populated with the callables that are supplied as inputs to the distribution class. As far as I understand, those registrations are detached from the RV Op, and that's why they never get populated in a spawned process.
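
The "created on the fly" part is easy to see with plain classes (a toy demo, nothing pymc-specific):

import pickle
import cloudpickle

# pickle stores instances by referencing their class's qualified name, which
# fails for classes created at runtime, roughly the situation with
# CustomDist's dynamically created Op class.
DynamicOp = type("DynamicOp", (), {})

try:
    pickle.dumps(DynamicOp())
except (pickle.PicklingError, AttributeError) as err:
    print("pickle failed:", err)

# cloudpickle serializes the class definition itself, so this round-trips.
restored = cloudpickle.loads(cloudpickle.dumps(DynamicOp()))
print(type(restored).__name__)  # DynamicOp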

@aseyboldt (Member)

I don't mean that the pickling itself throws an error (it doesn't), but that it would be the responsibility of the DensityDist object to ensure that the set-up it needs (i.e. registering the logp) is done when it is unpickled.

For instance, the following fails with the NotImplementedError and has nothing to do with SMC, so I guess the solution shouldn't be specific to SMC:

import pymc as pm
import cloudpickle
import multiprocessing


def use_logp_func(pickled_model):
    # In the child process, unpickling the model does not re-register the
    # user-supplied logp, so building and compiling the logp graph here
    # raises NotImplementedError.
    model = cloudpickle.loads(pickled_model)
    logp = model.logp()
    func = pm.pytensorf.compile_pymc(model.value_vars, logp)
    print(func(1.0))


if __name__ == "__main__":
    with pm.Model() as model:

        def logp(value):
            return -(value**2)

        pm.DensityDist("x", logp=logp)

    # Works in the main process, where the overload was registered.
    logp = model.logp()
    func = pm.pytensorf.compile_pymc(model.value_vars, logp)
    pickled_model = cloudpickle.dumps(model)

    # "spawn" starts a fresh interpreter, as Windows does by default.
    ctx = multiprocessing.get_context("spawn")
    process = ctx.Process(target=use_logp_func, args=(pickled_model,))
    process.start()
    process.join()

@lucianopaz (Contributor)

I completely agree that this problem isn't unique to SMC and is a design caveat that needs to be addressed more comprehensively.
I think that we can kind of patch some things:

  1. Make Model objects' __setstate__ and __getstate__ repopulate the dispatch registries.
  2. Get CustomDist RV Ops to have these methods defined somehow (maybe via closures) so that they repopulate the dispatch registries.

I'm not sure if these two methods can cover all use patterns though.

@ricardoV94 (Member) commented Jun 12, 2024

Alternatively, we could pass the functions needed to each process, which is more like what pm.sample does.

This also avoids recompiling the same functions multiple times?
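
A rough sketch of that alternative (hypothetical structure; pm.sample's real machinery is more involved): build and compile the logp graph once in the parent, then ship the compiled function, so no dispatch lookup ever happens in a worker.

import multiprocessing
import cloudpickle
import pymc as pm

def _run_chain(payload, value):
    # The worker only unpickles an already-compiled pytensor function; the
    # logp graph was built in the parent, so no dispatch registrations are
    # needed here.
    logp_fn = cloudpickle.loads(payload)
    return logp_fn(value)

def eval_in_workers(model, values, cores=2):
    # Assumes a model with a single value variable, for simplicity.
    logp_fn = pm.pytensorf.compile_pymc(model.value_vars, model.logp())
    payload = cloudpickle.dumps(logp_fn)
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(cores) as pool:
        return pool.starmap(_run_chain, [(payload, v) for v in values])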

@EliasRas (Author)

@lucianopaz
Point 1 is pretty straightforward, but could you explain what you meant by point 2? How would it be different from overriding __getstate__ and __setstate__?

@EliasRas (Author)

I started working on __getstate__ and __setstate__, but I realized that I can't just copy the current implementation due to circular imports. Would either of these be a good idea?

  1. Create a utils module (or something similar) in pymc.model that handles the custom methods and registers them.
  2. Add a method to RVs that registers the custom methods if necessary. CustomDistRV and CustomSymbolicDistRV would have the registrations, and RandomVariable and SymbolicRandomVariable would just pass.

@ricardoV94 (Member) commented Jul 17, 2024

I think we should explore an alternative where we compile the functions SMC needs and fork afterwards, like pm.sample does.

This approach seems more brittle?

It would also avoid re-compiling the same functions in each chain

@EliasRas (Author) commented Jul 17, 2024

I agree on the re-compiling part, but shouldn't this still be fixed? It feels like an arbitrary decision to "disallow" using multiprocessing this way only on Windows, even if it is a bad way.

@ricardoV94 (Member) commented Jul 17, 2024

> I agree on the re-compiling part, but shouldn't this still be fixed? It feels like an arbitrary decision to "disallow" using multiprocessing this way only on Windows, even if it is a bad way.

I think this limitation is likely deeper than what you're addressing here. As @aseyboldt and @lucianopaz mentioned, dynamic dispatching is a recurring theme in our codebase and pytensor's.

However, I don't agree with their solutions.

@ricardoV94 (Member) commented Jul 17, 2024

Using the class that's being dispatched on to register the dispatches during pickling seems at odds with the point of dispatching. The class shouldn't have to know what's being dispatched upon.

For instance, we also have icdf methods; what if someone dispatched on one from the outside, would pickling work for it? Or would the setstate/getstate need to know about icdf (as well as any other dispatch that may not even be part of PyMC)?

It's also not a PyMC model responsibility. CustomDist can be defined just fine outside of a model.

@EliasRas (Author) commented Jul 17, 2024

Thank you for taking the time to explain. I'll start working on the compilation approach.

@ricardoV94 (Member)

However, since this fixes currently broken behavior, I think we can go ahead and merge it as a temporary patch?

@ricardoV94 (Member) left a comment

Can you open a follow-up issue for the compile-once-and-then-fork approach?

@EliasRas (Author)

I started working on compiling the functions in the main process. Should I close this or is there any chance of this getting merged before the newer PR?

@EliasRas (Author) commented Sep 4, 2024

> I started working on compiling the functions in the main process. Should I close this or is there any chance of this getting merged before the newer PR?

@ricardoV94 @lucianopaz @aseyboldt @twiecki Just following up on this. Should I close the PR?

Development

Successfully merging this pull request may close these issues:

BUG: pymc.sample_smc fails with pymc.CustomDist