-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pushfwd-inverses #98
Pushfwd-inverses #98
Conversation
…l into pushfwd-inverses
I think I tracked down the problem: julia> μ = StdUniform()
StdUniform()
julia> ν = pushfwd(logit, StdUniform())
PushforwardMeasure(LogExpFunctions.logit, StdUniform())
julia> f = transport_to(ν, μ)
TransportFunction(PushforwardMeasure(LogExpFunctions.logit, StdUniform()), StdUniform())
julia> x = rand(μ)
0.1916567450883606
julia> y = f(x)
-1.4392808027126194
julia> inverse(f)(y)
logistic (generic function with 2 methods) |
Codecov ReportBase: 50.49% // Head: 51.53% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## dev #98 +/- ##
==========================================
+ Coverage 50.49% 51.53% +1.04%
==========================================
Files 41 41
Lines 1115 1139 +24
==========================================
+ Hits 563 587 +24
Misses 552 552
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
Getting closer. Current error is ArgumentError: Measure ν of type StdUniform has static(1) DOF but μ of type PushforwardMeasure has MeasureBase.NoDOF{PushforwardMeasure{InverseFunctions.FunctionWithInverse{typeof(logit), typeof(LogExpFunctions.logistic)}, InverseFunctions.FunctionWithInverse{typeof(LogExpFunctions.logistic), typeof(logit)}, StdUniform, MeasureBase.WithVolCorr}}() DOF |
Ah, darn, we don't have |
Yes, but luckily it's pretty easy to work around. Current problem is a DOF mismatch, which seems a little weirder. Looking into that now |
Tests are passing! I needed to add this: getdof(μ::PushforwardMeasure) = getdof(transport_origin(μ)) |
Think we need to drop the case (f, μ, ν_ref) == ((-) ∘ log1p ∘ (-), StdUniform(), StdExponential()) since 1.7 fails type inference for it. Guessing it would do better with FunctionChains.jl |
Should pass again in v1.9 (JuliaLang/julia#45715), but we can definitely use FunctionChains, it shouldn't add any measurable load time. |
I don't think we even need it as a dependency, it was just for one particular test |
Ok, then let's maybe use an |
It's passing now, and we'd have to either unwind all of the structure or have a complicated condition to test whether to check The one failing case was on v1.7, and only for If we made the change, it would really only be testing that long chains of function compositions are type-stable, nothing about measures in particular. |
Also, a single pushforward using that composed function works just fine. The only problem was in the "pushforward-of-pushforward" test I added, and then only in v1.7 |
* drop debugging code * small update for Likelihood, and a test * fixing up likelihoods * improve `basemeasure_depth` dispatch * still some trouble with inferred basemeasure_depth * clean up `For` dispatch * simplify _logdensityof * optimize for Returns{True} case * rework basemeasure_depth * aggressive tests passing!! * drop type-level stuff * drop help * license * affero * copyright notice * merge * Drop Create Commons * cleanup after merge * update support computations * insupport(d::SuperpositionMeasure, x) * dorp ParamWeighted * insupport(d::FactoredBase, x) * export unsafe_logdensityof * call promote_type instead of promote_rule * logdensity_def for named tuple product measures * type annotation for now * debugging * drop shows * speed up mapped arrays * throw an error for `Union{}` types * MT tests passing * updates * get tests passing * MIT license for MeasureBase * bump version * cleanup * spacing * Move ConditionalMeasure to MeasureBase * add LogarithmicNumbers * export basemeasure_sequence * update superpose * fix logdensity_rel * remove FIXME (it's fixed!!) * logdensityof(d::Density, x) * simplify insupport(::Lebesgue, ::Real) * clean up * assume insupport yields Bool * change logdensity_rel fall-through to warning and return NaN * update logdensity_rel * drop old code * fix warning * export logdensity_rel * logdensity_def(μ::Dirac, ν::Dirac, x) * logdensity_def methods * drop `static` * ]add StatsFuns * Fixing up superposition * [compat] entries * trying to speed things up * bugfixes * logdensity_rel tests * logdensity_rel tests * drop qualifier, and add a test * more tests * type constraint in "logdensityof(μ::AbstractMeasure, x)" (was piracy, oops) * add some docs * docs * docs * typo * moar speed * don't export Test * some more updates * logdensity_rel for products * `kleisli` docs * update instance_type * instance_type => Core.Typeof * `powermeasure` bug fix * fix logdensity_rel bug * get `commonbase` to take x type into account * test powers * commonbase docstring * deprecate instance_type * avoid breakage * switch || terms * @ifelse macro * simplify logdensity_rel * give up on this @ifelse business * bump version * working on likelihoods * update likelihood * powerweightedmeasure * powerweighted update * more powerweighted methods * bugfix * dropFactoredBase * drop FactoredBase * (::ProductMeasure) | constraint * update conditional measure * update Dirac * move conditional.jl down in the `include`s * Kleisli => TransitionKernel * simplify logdensity_def(::PowerMeasure, x) * rename kleisli.jl to kernel.jl * update Dirac tests * update Half * get tests passing * update kernel * Update Project.toml * no call-site inlining * restrict single-arg `kernel` to <:ParameterizedMeasure * export log_likelihood_ratio * Drop DensityKind(::Likelihood), at least for now * isfinite(x) instead of x>-Inf * add `condition` constructor * EOF newline * simplify logdensity_def for power measures * finishing up * updates * kernel stuff * kernel stuff * update showe methods * ass a TODO * use `dot` instead of `sum` * drop old code * typo * formatting * cleanup * kernel updates * uncomment * bugfix * drop old code * pretty printing * exports, cleanup * drop old for.jl * Make DensityKind(::AbstractLikelihood) = IsDensity() * update Compat version * Make likelihoods work with Distributions * _map(f, x::MappedArrays.ReadonlyMappedArray) * export productmeasure * AbstractMeasure(::AbstractMeasure) * fixedrng * StdNormal * add SpecialFunctions * no need to qualify * update basemeasure * include stdnormal * include fixedrng * update tests * using SpecialFunctions * fixing transport_def * transport_def bugfix * StdMeasure(::typeof(randn)) * checked_arg for LebesgueMeasure * NoTransformOrigin => NoTransportOrigin * transport interface for pushforwards * transporting pushforwards * Use LebesgueMeasure for basemeasure * updates * make testvalue fall back on FixedRNG approach * un-break testvalue * CI for Juila 1.8 * fixes * `rand` on a pushforward calls rand on its parent * LebesgueMeasure => LebesgueBase CountingMeasure => CountingBase * tests passing! * change `invoke` type * Change `test_interface` to check for 2-arg testvalue * manually-specifed inverses * more pushfwd stuff * A little less wrong * add mass interface * pullback * mass interface * working on mass interface * add some `massof` methods * Maybe <:Number is better for invalidations? * float instead of Int * logmassof * transports for proxies * drop latent-joint.jl * drop exports * Drop `logmassof` for now * reorganize Lebesgue measure * IntervalSets * proxy(::Lebesgue{MeasureBase.RealNumbers}) = LebesgueBase() * calling a "useproxy" measure calls its proxy * StdUniform()(s::Interval) * typo * (m::AbstractMeasure)(s::Interval) * bugfix * comment * IntervalSets version constraint * update dynamic_basemeasure_depth * format * Calling a measure calls `massof` * work on massof * AbstractSuperpositionMeasure * fix typo * typo * format * docstrings * remove massof(::PowerWeightedMeasure) method * make `massof` better * update testvalue * formatting * update _massof * Update transports for weighted measures * add chain rules * invariant mass under transport * typo * bugfix * hasmethod => Tricks.static_hasmethod * `massof` methods * roll back tranports for WeightedMeasure * Improve transport implementation and add product support (#97) * Improve default transport implementation Increases type stability. * Rename NoTransformOrigin to NoTransportOrigin * Add rrule for _origin_depth * Fix ambiguities when forwarding NoTransportOrigin and NoTransportOrigin * Define getdof for product measures * Generalize test_transport to tuple-valued measures * Implement transport for tuple-based products Co-authored-by: Chad Scherrer <[email protected]> * `@useproxy` delegates `massof` * drop CI for nightly * callable densities (#85) * callable densities * separate `Density` and `LogDensity`, etc * bugfix * move some code around * format * updates * working on densities * update CI * bugfix * formatting * reorg * fix typos * Drop LogDensityMeasure and refactor * docstring * inner type constructor with assertion * type parameters * 2-arg density_rel and logdensity_rel * oops * properties * typo * fix ambiguity * bugfix * bugfix * updates * drop densityof and logdensityof for AbstractDensity * updates * update tests * update * formatting * bad calls throw errors * drop CI for nightly * Pushfwd-inverses (#98) * use InverseFunctions.setinverse * bug fixes * bugfix * bugfix * pushfwd of a pushfwd * format * drop old comment * drop CI for nightly * working on pushfwd * bugfix * inverse(f) => ν.finv * separate logdensity functions from transport API * format * don't unwrap FunctionWithInverse * drop redundant method * leave logdensityof alone, instead write unsafe_logdensityof * more tests * more work on tests * tests * tests * still messing with tests * tests passing * small edits * formatting * add some more failing tests * add atol to isapprox in test * getdof(μ::PushforwardMeasure) = getdof(transport_origin(μ)) * update atol * small fix * drop ((-) ∘ log1p ∘ (-), StdUniform(), StdExponential()) * remove duplicate method * remove duplicate `include` * simplify getdof(::PushforwardMeasure) * Stieltjes measure function (#100) * smf * more smf stuff * transport_to * Lebesgue smf * smf for std measures * format * transport_to * oops * smfinv * bugfix * more fixes * minor refactoring * formatting * change x to p * bugfix * smfinv(::StdLogistic, p) * add NoSMF and NoSMFInverse * roll back some changes * transport_def methods * formatting * another rollback * make transport_def depend on smf(inv) * update smf and transports for ::Half * change `include` order * test_smf * more tests * tests * add tests * formatting * Drop unneeded type parameters * smfinv => invsmf * add some inverses * Base.Fix1 versions * some more methods * drop redundant `transport_def`s * update `pushfwd` * change name * add type * formatting * fix docstring * depend on FunctinoChains * Use fchain * simplify transport_def for StdLogistic * simplify transport_def for StdNormal * drop redundant method Co-authored-by: Oliver Schulz <[email protected]>
* clean up `For` dispatch * simplify _logdensityof * optimize for Returns{True} case * rework basemeasure_depth * aggressive tests passing!! * drop type-level stuff * drop help * license * affero * copyright notice * merge * Drop Create Commons * cleanup after merge * update support computations * insupport(d::SuperpositionMeasure, x) * dorp ParamWeighted * insupport(d::FactoredBase, x) * export unsafe_logdensityof * call promote_type instead of promote_rule * logdensity_def for named tuple product measures * type annotation for now * debugging * drop shows * speed up mapped arrays * throw an error for `Union{}` types * MT tests passing * updates * get tests passing * MIT license for MeasureBase * bump version * cleanup * spacing * Move ConditionalMeasure to MeasureBase * add LogarithmicNumbers * export basemeasure_sequence * update superpose * fix logdensity_rel * remove FIXME (it's fixed!!) * logdensityof(d::Density, x) * simplify insupport(::Lebesgue, ::Real) * clean up * assume insupport yields Bool * change logdensity_rel fall-through to warning and return NaN * update logdensity_rel * drop old code * fix warning * export logdensity_rel * logdensity_def(μ::Dirac, ν::Dirac, x) * logdensity_def methods * drop `static` * ]add StatsFuns * Fixing up superposition * [compat] entries * trying to speed things up * bugfixes * logdensity_rel tests * logdensity_rel tests * drop qualifier, and add a test * more tests * type constraint in "logdensityof(μ::AbstractMeasure, x)" (was piracy, oops) * add some docs * docs * docs * typo * moar speed * don't export Test * some more updates * logdensity_rel for products * `kleisli` docs * update instance_type * instance_type => Core.Typeof * `powermeasure` bug fix * fix logdensity_rel bug * get `commonbase` to take x type into account * test powers * commonbase docstring * deprecate instance_type * avoid breakage * switch || terms * @ifelse macro * simplify logdensity_rel * give up on this @ifelse business * bump version * working on likelihoods * update likelihood * powerweightedmeasure * powerweighted update * more powerweighted methods * bugfix * dropFactoredBase * drop FactoredBase * (::ProductMeasure) | constraint * update conditional measure * update Dirac * move conditional.jl down in the `include`s * Kleisli => TransitionKernel * simplify logdensity_def(::PowerMeasure, x) * rename kleisli.jl to kernel.jl * update Dirac tests * update Half * get tests passing * update kernel * Update Project.toml * no call-site inlining * restrict single-arg `kernel` to <:ParameterizedMeasure * export log_likelihood_ratio * Drop DensityKind(::Likelihood), at least for now * isfinite(x) instead of x>-Inf * add `condition` constructor * EOF newline * simplify logdensity_def for power measures * finishing up * updates * kernel stuff * kernel stuff * update showe methods * ass a TODO * use `dot` instead of `sum` * drop old code * typo * formatting * cleanup * kernel updates * uncomment * bugfix * drop old code * pretty printing * exports, cleanup * drop old for.jl * Make DensityKind(::AbstractLikelihood) = IsDensity() * update Compat version * Make likelihoods work with Distributions * _map(f, x::MappedArrays.ReadonlyMappedArray) * export productmeasure * AbstractMeasure(::AbstractMeasure) * fixedrng * StdNormal * add SpecialFunctions * no need to qualify * update basemeasure * include stdnormal * include fixedrng * update tests * using SpecialFunctions * fixing transport_def * transport_def bugfix * StdMeasure(::typeof(randn)) * checked_arg for LebesgueMeasure * NoTransformOrigin => NoTransportOrigin * transport interface for pushforwards * transporting pushforwards * Use LebesgueMeasure for basemeasure * updates * make testvalue fall back on FixedRNG approach * un-break testvalue * CI for Juila 1.8 * fixes * `rand` on a pushforward calls rand on its parent * LebesgueMeasure => LebesgueBase CountingMeasure => CountingBase * tests passing! * change `invoke` type * Change `test_interface` to check for 2-arg testvalue * manually-specifed inverses * more pushfwd stuff * A little less wrong * add mass interface * pullback * mass interface * working on mass interface * add some `massof` methods * Maybe <:Number is better for invalidations? * float instead of Int * logmassof * transports for proxies * drop latent-joint.jl * drop exports * Drop `logmassof` for now * reorganize Lebesgue measure * IntervalSets * proxy(::Lebesgue{MeasureBase.RealNumbers}) = LebesgueBase() * calling a "useproxy" measure calls its proxy * StdUniform()(s::Interval) * typo * (m::AbstractMeasure)(s::Interval) * bugfix * comment * IntervalSets version constraint * update dynamic_basemeasure_depth * format * Calling a measure calls `massof` * work on massof * AbstractSuperpositionMeasure * fix typo * typo * format * docstrings * remove massof(::PowerWeightedMeasure) method * make `massof` better * update testvalue * formatting * update _massof * Update transports for weighted measures * add chain rules * invariant mass under transport * typo * bugfix * hasmethod => Tricks.static_hasmethod * `massof` methods * roll back tranports for WeightedMeasure * Improve transport implementation and add product support (#97) * Improve default transport implementation Increases type stability. * Rename NoTransformOrigin to NoTransportOrigin * Add rrule for _origin_depth * Fix ambiguities when forwarding NoTransportOrigin and NoTransportOrigin * Define getdof for product measures * Generalize test_transport to tuple-valued measures * Implement transport for tuple-based products Co-authored-by: Chad Scherrer <[email protected]> * `@useproxy` delegates `massof` * drop CI for nightly * callable densities (#85) * callable densities * separate `Density` and `LogDensity`, etc * bugfix * move some code around * format * updates * working on densities * update CI * bugfix * formatting * reorg * fix typos * Drop LogDensityMeasure and refactor * docstring * inner type constructor with assertion * type parameters * 2-arg density_rel and logdensity_rel * oops * properties * typo * fix ambiguity * bugfix * bugfix * updates * drop densityof and logdensityof for AbstractDensity * updates * update tests * update * formatting * bad calls throw errors * drop CI for nightly * Pushfwd-inverses (#98) * use InverseFunctions.setinverse * bug fixes * bugfix * bugfix * pushfwd of a pushfwd * format * drop old comment * drop CI for nightly * working on pushfwd * bugfix * inverse(f) => ν.finv * separate logdensity functions from transport API * format * don't unwrap FunctionWithInverse * drop redundant method * leave logdensityof alone, instead write unsafe_logdensityof * more tests * more work on tests * tests * tests * still messing with tests * tests passing * small edits * formatting * add some more failing tests * add atol to isapprox in test * getdof(μ::PushforwardMeasure) = getdof(transport_origin(μ)) * update atol * small fix * drop ((-) ∘ log1p ∘ (-), StdUniform(), StdExponential()) * remove duplicate method * remove duplicate `include` * simplify getdof(::PushforwardMeasure) * Stieltjes measure function (#100) * smf * more smf stuff * transport_to * Lebesgue smf * smf for std measures * format * transport_to * oops * smfinv * bugfix * more fixes * minor refactoring * formatting * change x to p * bugfix * smfinv(::StdLogistic, p) * add NoSMF and NoSMFInverse * roll back some changes * transport_def methods * formatting * another rollback * make transport_def depend on smf(inv) * update smf and transports for ::Half * change `include` order * test_smf * more tests * tests * add tests * formatting * Drop unneeded type parameters * smfinv => invsmf * add some inverses * Base.Fix1 versions * some more methods * drop redundant `transport_def`s * update `pushfwd` * change name * add type * formatting * fix docstring * depend on FunctinoChains * Use fchain * simplify transport_def for StdLogistic * simplify transport_def for StdNormal * drop redundant method * update test_interface * useproxy for smf * update test_smf * bump version Co-authored-by: Oliver Schulz <[email protected]>
Requires
InverseFunctions = "0.1.8"
, and reworkspushfwd(f, finv, m)
to useInverseFunctions.setinverse