minor correction in sampling.py and starting.py #4458
CI failure seems unrelated - could you add the snippet from the issue as a test?
pymc3/sampling.py
    update_start_vals(start, model.test_point, model)
else:
    start = start[:]
This will still change the dictionaries in place, because slicing only copies the outer list:
a = [dict(a=1, b=2), dict(a=1, b=2)]
b = a[:]
for b_ in b:
    b_['c'] = 3
a
[{'a': 1, 'b': 2, 'c': 3}, {'a': 1, 'b': 2, 'c': 3}]
You can try start = [s.copy() for s in start]
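A minimal sketch of the difference, in plain Python with no PyMC3 involved (the variable names are made up for illustration): slicing shares the inner dicts with the original list, while copying each dict does not.

```python
original = [dict(a=1, b=2), dict(a=1, b=2)]

# Slicing copies only the outer list; both lists hold the same dict objects.
sliced = original[:]
assert sliced[0] is original[0]

# Copying each dict gives the new list its own dict objects.
copied = [s.copy() for s in original]
for c in copied:
    c["c"] = 3

# The new key shows up only in the copies.
assert copied[0] is not original[0]
assert original == [{"a": 1, "b": 2}, {"a": 1, "b": 2}]
```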
Yeah, I think @chandan5362's original suggestion to use deepcopy was good.
Since the function in question only adds new keys, but does not change keys already in place, a shallow copy (and nested shallow copy) should be fine. But I have no strong objections to deepcopy
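To illustrate the trade-off, here is a small sketch in plain Python. The key name `mu_log__` and the list-valued start value are assumptions chosen for the example, not taken from the PR. Adding a key to a shallow copy leaves the original alone; only mutating a shared value leaks through, and that is the case deepcopy guards against.

```python
from copy import deepcopy

start = {"mu": [0.0, 1.0]}

# Shallow copy: adding a new key does not touch the original dict.
shallow = start.copy()
shallow["mu_log__"] = -1.0  # hypothetical transformed-variable key
assert "mu_log__" not in start

# Mutating a value that both dicts share does leak through.
shallow["mu"][0] = 99.0
assert start["mu"][0] == 99.0

# Deep copy: nested values are copied too, so nothing leaks.
deep = deepcopy(start)
deep["mu"][0] = -5.0
assert start["mu"][0] == 99.0
```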
Yeah, we could use start.copy(), but to stay on the safe side we should probably use deepcopy, even though it is not strictly needed here. With deepcopy we also won't have to call start.copy() inside a list comprehension.
You're right, deepcopy(None) == None, so it can go even before any is not None or isinstance checks.
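A quick check of that claim in plain Python (the start values are made up for the example): a single deepcopy call covers None, a single dict, and a list of dicts, so no type check is needed before copying.

```python
from copy import deepcopy

# deepcopy passes None through unchanged.
assert deepcopy(None) is None

start_dict = {"x": 1.0}
start_list = [{"x": 1.0}, {"x": 2.0}]

# The same call works for a dict and for a list of dicts.
dict_copy = deepcopy(start_dict)
list_copy = deepcopy(start_list)

# Changes to the copies do not reach the originals.
dict_copy["y"] = 0.0
list_copy[0]["y"] = 0.0

assert start_dict == {"x": 1.0}
assert start_list == [{"x": 1.0}, {"x": 2.0}]
```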
This line can also be removed now.
pymc3/sampling.py
@@ -427,8 +427,10 @@ def sample(
    check_start_vals(model.test_point, model)
else:
    if isinstance(start, dict):
        start = {k: v for k, v in start.items()}
Why not simply start.copy()?
This is what I was suggesting initially. Anyway, I will replace that with deepcopy.
This line can be removed now.
Yeah sure, I will add it as a test.
I think
We have a couple of fragile tests that fail randomly, but nobody has yet had the time to dig in and check what can be done: https://github.com/pymc-devs/pymc3/issues?q=is%3Aopen+label%3Atests+flaky+test
In this case it was just an HTTPError, unrelated to the pymc3 tests.
pymc3/tests/test_sampling.py
    draws=100,
    start=start_dict,
)
assert len(start_dict) == 1
Perhaps just assert start_dict == {"X0_mu": 25}?
Also, is it possible to make a test case which hits the other branch (i.e. where there is a list of dicts)?
I think assert len(start_dict) == 1 would also work, as the sample method just adds the transformed_RV to the dictionary. Anyway, I will replace that.
I was also thinking of adding a test for a list of dicts, but I am not aware of any case where we would pass a list of dictionaries.
If you can come up with such a case, I will add a test for that too.
@ricardoV94, are you aware of any such case?
A list of start dicts can be passed to initialize each chain on a different parameter set.
You can switch the step method to pm.Metropolis, because it is much faster to initialize. Also reduce to tune=5, draws=10, chains=3 or so to speed up the test a bit.
The parameters of the distribution are not actually important for this test case.
Small suggestion: you can add a comment at the beginning of the test function referring to the original issue for some context
These tests are failing. I think I messed up my branch.
pymc3/tuning/starting.py
if start is None:
    start = model.test_point
else:
    start = {k: v for k, v in start.items()}
We only need the deepcopy above
Ohhh sorry, I just forgot to remove that🤦♂️
pymc3/tests/test_sampling.py
@@ -121,6 +121,37 @@ def test_iter_sample(self):
    for i, trace in enumerate(samps):
        assert i == len(trace) - 1, "Trace does not have correct length."
def test_sample_does_not_modify_start_as_list_of_dicts(self):
Does this test fail in master? It looks like no transforms would be added in this test model and therefore the dictionary wouldn't be changed anyway, but I could be wrong. I imagined it worked something like your test below but with a dictionary for each chain:
import numpy as np
import pymc3 as pm
start_dict = [{"X0_mu": 25}, {"X0_mu": 25}]
with pm.Model() as m:
    X0_mu = pm.Lognormal("X0_mu", mu=np.log(0.25), sd=0.10)
    trace = pm.sample(
        step=pm.Metropolis(),
        tune=5,
        draws=10,
        chains=2,
        start=start_dict,
    )
assert start_dict == ...
Oh I see, you made one parameter be missing on purpose in each chain...
"Does this test fail in master? It looks like no transforms would be added in this test model and therefore the dictionary wouldn't be changed anyway."
Unfortunately it does: the dictionary is updated with model.test_point even though there is no transformed variable to add.
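The behaviour described here can be pictured with a stand-in function. `fill_missing` below is a hypothetical helper written for this example, not PyMC3's code, and `test_point` merely plays the role of model.test_point. It shows how filling in missing parameters in place changes the caller's dicts even when no transformed variable is involved.

```python
def fill_missing(start, defaults):
    """Hypothetical stand-in: add every key of `defaults` that `start` lacks, in place."""
    start.update({k: v for k, v in defaults.items() if k not in start})


test_point = {"a": 0.0, "b": 0.0}       # plays the role of model.test_point
user_start = [{"a": 1.0}, {"b": 2.0}]   # each chain omits one parameter on purpose

for chain_start in user_start:
    fill_missing(chain_start, test_point)

# The caller's dicts have grown, which is what the test is meant to catch.
assert user_start == [{"a": 1.0, "b": 0.0}, {"a": 0.0, "b": 2.0}]
```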
pymc3/tests/test_sampling.py
    start=start_dict,
)
assert start_dict == {"X0_mu": 25}
Let's not overcomplicate the tests. Also we should test both pm.sample and pm.find_MAP.
- no need to reuse the complicated self.model model
- variable names and parameters don't matter
- distribution should be transformed by default (Uniform, Lognormal, ...)
- everything in the same test case so the compilation is done just once
with pm.Model():
    pm.Lognormal("untransformed")
    # test that find_MAP doesn't change the start dict
    start = {"untransformed": 2}
    pm.find_MAP(start=start, niter=5)
    assert start == {"untransformed": 2}
    # check that sample doesn't change it either
    start = {"untransformed": 0.5}
    ...
    # and also not if start is different for each chain
    start = [{"untransformed": 2}, {"untransformed": 0.5}]
    ...
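The pattern in this test, "call the function, then compare the argument with what was passed in", can be sketched without PyMC3. Everything below is hypothetical and only illustrates the snapshot-and-compare idea: `mutating` stands in for a function that edits its start argument in place, and `well_behaved` for one that copies first.

```python
from copy import deepcopy


def leaves_argument_unchanged(func, arg):
    """Return True if calling func(arg) leaves arg equal to a prior snapshot."""
    snapshot = deepcopy(arg)
    func(arg)
    return arg == snapshot


def mutating(start):
    """Hypothetical function that adds a key to every start dict in place."""
    dicts = [start] if isinstance(start, dict) else start
    for d in dicts:
        d["extra"] = 0.0


def well_behaved(start):
    """Same work as `mutating`, but on a private deep copy."""
    mutating(deepcopy(start))


# Covers both branches: a single dict and a list of per-chain dicts.
assert not leaves_argument_unchanged(mutating, {"x": 2})
assert leaves_argument_unchanged(well_behaved, {"x": 2})
assert leaves_argument_unchanged(well_behaved, [{"x": 2}, {"x": 0.5}])
```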
I did remove these lines, but they came back from nowhere (I might have fetched before committing).
Anyway, this time I will make sure that these lines do not come back.
It would be better to move this test out of the TestSample class.
Does this need a release note or is it small enough to not warrant one?
Nice!
"Does this need a release note or is it small enough to not warrant one?"
It was big enough for a user to notice, so I guess it warrants one?
@chandan5362 Can you add a line to the release notes too? Sorry I forgot about this before.
Yeah, sure.
Addresses issue #4456.