-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move predict
from Turing
#716
Conversation
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
The reason is some tests implicitly rely on the variance of the posterior samples. Discarding some initial samples fixes this. |
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Pull Request Test Coverage Report for Build 12435577043Details
💛 - Coveralls |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #716 +/- ##
==========================================
+ Coverage 85.93% 86.05% +0.12%
==========================================
Files 36 36
Lines 4280 4325 +45
==========================================
+ Hits 3678 3722 +44
- Misses 602 603 +1 ☔ View full report in Codecov by Sentry. |
We had a fast discussion on this today at the meeting. Tor raised that we should probably implement Also although we don't use |
Specifically, I was thinking |
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
src/model.jl
Outdated
varinfos::AbstractArray{<:AbstractVarInfo}; | ||
include_all=false, | ||
) | ||
predictive_samples = Array{PredictiveSample}(undef, size(varinfos)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need the PredictiveSample
here?
My original suggestion was just to use Vector{<:OrderedDict}
for the return-value (an abstractly typed PredictiveSample
doesn't really offer anything beyond this, does it?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't think too deep about this. A new type certainly is easier to dispatch on, but may not be necessary. Let me look into it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But we don't need to dispatch on this, do we?
Also, maybe it makes more sense to follow the convetion of return the same type as the input type, i.e. in this case we should return a AbstractArray{<:AbstractVarInfo}
and in the Chains
case we return Chains
Otherwise stuff is starting to look nice though:) |
src/model.jl
Outdated
varinfos::AbstractArray{<:AbstractVarInfo}; | ||
include_all=false, | ||
) | ||
predictive_samples = similar(varinfos, OrderedDict{Symbol,Any}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a resaon why you're using Symbol
instead of VarName
here? Seems better to use VarName
, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this is confusing. The OrderedDict
here is actually
OrderedDict{Symbol, Any}(
values => ..., # a vector of Tuples (varname, value)
logp =>
)
using NamedTuple now, and use better field names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But if your keeping other information than just the realizations (which, tbh, is IMO all we need here), why aren't we just returning the varinfos themselves (I suggested this is in the other comment here)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah gotcha 👍 , I totally misunderstood: I was reading the AbstractVector
part but somehow ignored {<:AbstractVarInfo}
part.
I can get behind the idea of using a vector of VarInfo
for predict
and return a vector of VarInfo
s. But I think the interface need to be spec-ed more. For instance, ideally we want be more clear on questions like: in the returned VarInfo
s, should the VarName
be varname leaves or as appeared in the model; should the values in the returned VarInfo in transformed or constrained space; how exactly should model
and input VarInfo
s conform to each other.
I am a bit short for time now, so after some thoughts, I think it's probably a good idea now to just keep all the logic in MCMCChainsExt
and maintain exactly the same interface Turing.jl
has now. Then in the future, we can work on to improve predict
interface.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hould the VarName be varname leaves or as appeared in the model
As appeared in the model:)
should the values in the returned VarInfo in transformed or constrained spac
Constrained space.
how exactly should model and input VarInfos conform to each other.
Confused; what do you mean?
Overall, I'm still a bit confused by this discussion: Turing.jl's predict
literally does: iterate over chain, create varinfo, evaluate model on varinfo, and extract variables from varinfo.
So, why do we not just do
# In DynamicPPL.jl proper:
function predict(rng::Random.AbstractRNG, model::Model, chain::AbstractVector{<:AbstractVarInfo})
varinfo = DynamicPPL.VarInfo(model)
return map(chain) do varinfo_params
DynamicPPL.setval_and_resample!(varinfo, varinfo_params)
model(rng, varinfo)
return deepcopy(varinfo)
end
end
which is effectively what Turing.jl's predict
does before converting into a Chains
?
EDIT: This is ignoring the values_as_in_model
which apparently is used in Turing.jl's predict
, though, as mentioned in the other comment, it's very unclear if that's what we want here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, sorry it was a bit confusing.
I am thinking that it'll be more intuitive for predict
to hold that
predicted_vis = predict(rng, model, varinfos)
then
_varinfos = predict(rng, model, predicted_vis)
returns varinfos
that looks like varinfos
.
But if the values are in constrained space, can this break?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also does the above code return varinfo
s with values in constainted space?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a few comments:+1:
The current impl has different behavior from Turing.predict
in a few different ways, so we should address these issues before merging.
test/Project.toml
Outdated
@@ -2,6 +2,7 @@ | |||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | |||
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" | |||
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" | |||
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, this doesn't quite seem worth it to test predict
, no? What's the reasoning here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't add anything or change the implementation in this PR.
Agree AHMC is heavy dep, but tests like https://github.com/TuringLang/DynamicPPL.jl/blob/fd1277b7201477448d3257cab65557b850bcf5b4/test/ext/DynamicPPLMCMCChainsExt.jl#L48C1-L55C45
rely on quality of samples
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, but should just replace them with samples from the prior or something. This is just checking that the statistics are correct; it doesn't matter if these statistics are from the prior or posterior 🤷
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, would it be really bad to make AdvancedHMC
be a test dependency of DynamicPPL? (again, I don't like this either, but it's not too bad, I would be for adding an issue for removing this dependency later than tempering more with this PR anymore)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't look at this PR properly until Wednesday, but in https://github.com/TuringLang/DynamicPPL.jl/pull/733/files#diff-3981168ff1709b3f48c35e40f491c26d9b91fc29373e512f1272f3b928cea6c0 I wrote a function that generates a chain by sampling from the prior. (It's called make_chain_from_prior
if the link doesn't bring you to the right place)
Feel free to take it if you think it's useful :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sunxd3, @penelopeysm the posterior of Bayesian linear regression can be obtained in closed form (i.e. it is a Gaussian, see here). I suggest that
- add this BLR model to DynamicPPL test models
- implement its analytical posterior
- sample from the analytical posterior directly and drop the AHMC deps.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Though the closed-form posterior is a good idea, there's really no need to run this test on posterior samples:) These were just some stats that were picked to have something to compare to; prior chain is the way to go I think 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prior chain make sense: should we generate samples from prior, take out samples of a particular variable, and try to predict it?
src/model.jl
Outdated
function predict(model::Model, chain; include_all=false) | ||
# this is only defined in `ext/DynamicPPLMCMCChainsExt.jl` | ||
# TODO: add other methods for different type of `chain` arguments: e.g., `VarInfo`, `NamedTuple`, and `OrderedDict` | ||
return predict(Random.default_rng(), model, chain; include_all) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If so, we should definitively inform the user of this, no? Otherwise they'll just be like "oh why is this not defined?"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we want to export predict
right now, so predict
is only available through Turing.jl
, give or take.
would function not defined
be meaningful enough if user give other types of input?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If Turing exports it, it's better for DynamicPPL to export it, too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, I was proposing delaying this until a good predict
spec is reached
test/ext/DynamicPPLMCMCChainsExt.jl
Outdated
m_lin_reg = linear_reg(xs_train, ys_train) | ||
chain_lin_reg = sample( | ||
DynamicPPL.LogDensityFunction(m_lin_reg), | ||
AdvancedHMC.NUTS(0.65), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really doesn't seem necessary to use NUTS
here. Just construct a Chains
by hand or something, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same reason as above: some tests relies on the quality of the samples
ext/DynamicPPLMCMCChainsExt.jl
Outdated
|
||
# Examples | ||
```jldoctest | ||
julia> using DynamicPPL, AbstractMCMC, AdvancedHMC, ForwardDiff; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here: no need to use AdvancedHMC
(or any of the other packages), just construct the Chains
by hand.
This also doesn't actually show that you need to import MCMCChains
for this to work, which might be a good idea
) | ||
model(rng, varinfo, DynamicPPL.SampleFromPrior()) | ||
|
||
vals = DynamicPPL.values_as_in_model(model, varinfo) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is actually changing the behavior from Turing.jl's implementation. This will result in also including variables used in :=
statements, which is not currently done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ooooh nice catch; thanks! Hmm, uncertain if this is desired behavior though 😕
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I saw your issue on :=
, totally understand the concern here. But if we are not exporting predict
, we can change this in near future, also we might want to use fix
in the future, so the behavior will be right then.
We would need to make a minor release of Turing
if we change this now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But isn't this the purpose of this PR? To move the predict
from Turing.jl to DynamicPPL.jl?
also we might want to use fix in the future
Whether we're using fix
or not is just an internal impl detail, and is not relevant for its usage, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But isn't this the purpose of this PR? To move the predict from Turing.jl to DynamicPPL.jl?
Ideally, I would want this PR to do a proper implementation of predict
in DynamicPPL. But now, I am okay with the PR being only a first step towards that.
Whether we're using fix or not is just an internal impl detail, and is not relevant for its usage, right?
what I was trying to say is that, with fix
it should have the right behavior (with regards to :=
). Of course not the only way to reach the desired behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Improving it in a separate PR sounds good, but please create an issue to track @torfjelde's comment.
src/model.jl
Outdated
the samples in `chain`. This is useful when you want to sample only new variables from the posterior | ||
predictive distribution. | ||
""" | ||
function predict(model::Model, chain; include_all=false) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In Turing.jl we're currently overloading StatsBase.predict
, so we should probably do the same here, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agree with this, but probably not time yet. Definitely after TuringLang/AbstractPPL.jl#81 is merged 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But is this PR then held up until that PR is merged then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, that PR doesn't really matter; overloading StatsBase.predict
here and now just means that we'll immediately be compliant with the AbstractPPL.jl interface when that PR merges?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Grey area: for me it is okay, because this PR is just about introduce a Turing
-faced predict
, not a user faced one yet. At the moment predict
is not a public API yet
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If nothing significant is missing in TuringLang/AbstractPPL.jl#81, let's merge it and overload AbstractPPL.predict
here.
@sunxd3, let's get this merged in the next few days. |
will do, on top of my priority list |
Some regression test for TuringLang/Turing.jl#1352 are removed, as far as I can tell, it should be covered by tests of |
edit: this is inaccurate, |
@yebai @torfjelde @penelopeysm I think this should be ready for another look |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests are failing because fix, condition are exported by AbstractPPL, while DynamicPPL currently doesn't actually import these from AbstractPPL.
Let's fix these in this PR if possible.
I think the tests are run, but the codecov thinks the code in |
This PR migrates the
predict
function from Turing.jl to DynamicPPL while maintaining its existing interface and core implementation. Sincepredict
returns aMCMCChains.Chain
, the implementation is placed inMCMCChainsExt
, similar togenerated_quantities
.The purpose of the PR is not to add a "proper"
predict
implementation for DynamicPPL just yet, but as a first step towards that. Some improvements we should make in the future:StatsBase.predict
NamedTuple
,OrderedDict
,VarInfo
, etc.values_as_in_model
is probably wrong (ref Better support for:=
Turing.jl#2409)