Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod
The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).

Improved performance of transformations of univariate distributions' samples to and from their vectorised forms.

## 0.38.9

Remove warning when using Enzyme as the AD backend.
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand Down Expand Up @@ -62,10 +64,12 @@ DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10.12, 1"
InteractiveUtils = "1"
InverseFunctions = "0.1.17"
JET = "0.9, 0.10, 0.11"
KernelAbstractions = "0.9.33"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogExpFunctions = "0.3.29"
MCMCChains = "6, 7"
MacroTools = "0.5.6"
MarginalLogDensities = "0.4.3"
Expand Down
229 changes: 177 additions & 52 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using LogExpFunctions: LogExpFunctions
using InverseFunctions: InverseFunctions, inverse

# Singleton type used to indicate that no default arguments are present.
struct NoDefault end
# The instance of `NoDefault` that the rest of the code can refer to by name.
const NO_DEFAULT = NoDefault()
Expand Down Expand Up @@ -261,42 +264,6 @@ invlink_transform(dist) = inverse(link_transform(dist))
# Helper functions for vectorize/reconstruct values #
#####################################################

"""
UnwrapSingletonTransform(input_size::InSize)

A transformation that unwraps a singleton array, returning a scalar.

The `input_size` field is the expected size of the input. In practice this only determines
the number of indices, since all dimensions must be 1 for a singleton. `input_size` is used
to check the validity of the input, but also to determine the correct inverse operation.

By default `input_size` is `(1,)`, in which case `tovec` is the inverse.
"""
struct UnwrapSingletonTransform{InSize} <: Bijectors.Bijector
    # Size the input array is required to have; also used to rebuild it in the inverse.
    input_size::InSize
end

# Without arguments, expect a one-element vector.
UnwrapSingletonTransform() = UnwrapSingletonTransform((1,))

# Check the size of `x` against `input_size`, then return its single element.
function (f::UnwrapSingletonTransform)(x)
    size(x) == f.input_size ||
        throw(DimensionMismatch("Expected input of size $(f.input_size), got $(size(x))"))
    return only(x)
end

# Forward direction: the unwrapped value together with a zero log-Jacobian term.
function Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x)
    unwrapped = f(x)
    return unwrapped, zero(LogProbType)
end

# Inverse direction: put `x` into a one-element vector and reshape it to the size the
# forward transform expects. The log-Jacobian term is again zero.
function Bijectors.with_logabsdet_jacobian(
    inv_f::Bijectors.Inverse{<:UnwrapSingletonTransform}, x
)
    target_size = inv_f.orig.input_size
    return reshape([x], target_size), zero(LogProbType)
end

"""
ReshapeTransform(input_size::InSize, output_size::OutSize)

Expand Down Expand Up @@ -370,14 +337,186 @@ function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y)
)
end

# `Only` extracts the single element of its argument via `x[]`; `NotOnly` wraps a value in
# a one-element vector. Each is declared to be the other's inverse.
struct Only end
struct NotOnly end

function (::Only)(x)
    return x[]
end

function (::NotOnly)(y)
    return [y]
end

InverseFunctions.inverse(::Only) = NotOnly()
InverseFunctions.inverse(::NotOnly) = Only()

# Both directions report a log-Jacobian term of zero. For real inputs the zero has the
# input's own (element) type; otherwise it falls back to `LogProbType`.
function Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector{T}) where {T<:Real}
    return (x[], zero(T))
end

function Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector)
    return (x[], zero(LogProbType))
end

function Bijectors.with_logabsdet_jacobian(::NotOnly, y::T) where {T<:Real}
    return ([y], zero(T))
end

function Bijectors.with_logabsdet_jacobian(::NotOnly, y)
    return ([y], zero(LogProbType))
end
# Maps a one-element real vector `y` to the scalar `exp(y[]) + lower`.
struct ExpOnly{L<:Real}
    lower::L
end

function (e::ExpOnly)(y::AbstractVector{<:Real})
    return exp(y[]) + e.lower
end

# The derivative of `exp(yi) + lower` with respect to `yi` is `exp(yi)`, so the log of its
# absolute value is `yi` itself.
function Bijectors.with_logabsdet_jacobian(e::ExpOnly, y::AbstractVector{<:Real})
    logjac = y[]
    value = exp(logjac) + e.lower
    return (value, logjac)
end

# The inverse takes the shifted log and wraps the result in a vector.
function InverseFunctions.inverse(e::ExpOnly)
    return LogVect(e.lower)
end
# Maps a real scalar `x` to the one-element vector `[log(x - lower)]`.
struct LogVect{L<:Real}
    lower::L
end

function (l::LogVect)(x::Real)
    return [log(x - l.lower)]
end

# The derivative of `log(x - lower)` is `1 / (x - lower)`, so the log-Jacobian term is the
# negative of the transformed value.
function Bijectors.with_logabsdet_jacobian(l::LogVect, x::Real)
    shifted_log = log(x - l.lower)
    return ([shifted_log], -shifted_log)
end

# The inverse exponentiates the single element and adds `lower` back.
function InverseFunctions.inverse(l::LogVect)
    return ExpOnly(l.lower)
end
# Maps a one-element real vector to a scalar lying between `lower` and `upper`. Which
# formula is used depends on which of the two bounds are finite:
#   both finite  -> scaled and shifted logistic
#   only lower   -> exp(y) + lower
#   only upper   -> upper - exp(y)
#   neither      -> the element itself
struct TruncateOnly{L<:Real,U<:Real}
    lower::L
    upper::U
end

function (t::TruncateOnly)(y::AbstractVector{<:Real})
    has_lower = isfinite(t.lower)
    has_upper = isfinite(t.upper)
    if has_lower && has_upper
        return ((t.upper - t.lower) * LogExpFunctions.logistic(y[])) + t.lower
    elseif has_lower
        return exp(y[]) + t.lower
    elseif has_upper
        return t.upper - exp(y[])
    end
    return y[]
end

# Returns the transformed scalar together with the log absolute derivative of the
# transform with respect to the single input element.
function Bijectors.with_logabsdet_jacobian(
    t::TruncateOnly, y::AbstractVector{T}
) where {T<:Real}
    has_lower = isfinite(t.lower)
    has_upper = isfinite(t.upper)
    if has_lower && has_upper
        width = t.upper - t.lower
        yi = y[]
        value = (width * LogExpFunctions.logistic(yi)) + t.lower
        # TODO: Bijectors uses this:
        #     absy = abs(yi)
        #     return log(bma) - absy - (2 * log1pexp(-absy))
        # Check if it's more numerically stable. Don't immediately see a reason why, but I
        # assume there's a reason for it.
        logjac = log(width) + yi - (2 * LogExpFunctions.log1pexp(yi))
        return value, logjac
    elseif has_lower
        yi = y[]
        return exp(yi) + t.lower, yi
    elseif has_upper
        yi = y[]
        return t.upper - exp(yi), yi
    end
    # No finite bound: the value passes through and the log-Jacobian term is zero.
    return y[], zero(T)
end

function InverseFunctions.inverse(t::TruncateOnly)
    return UntruncateVect(t.lower, t.upper)
end

# Maps a real scalar lying between `lower` and `upper` to a one-element vector. Which
# formula is used depends on which of the two bounds are finite:
#   both finite  -> logit((x - lower) / (upper - lower))
#   only lower   -> log(x - lower)
#   only upper   -> log(upper - x)
#   neither      -> x itself
struct UntruncateVect{L<:Real,U<:Real}
    lower::L
    upper::U
end

function (u::UntruncateVect)(x::Real)
    lbounded, ubounded = isfinite(u.lower), isfinite(u.upper)
    return [
        if lbounded && ubounded
            LogExpFunctions.logit((x - u.lower) / (u.upper - u.lower))
        elseif lbounded
            log(x - u.lower)
        elseif ubounded
            log(u.upper - x)
        else
            x
        end,
    ]
end

# Returns `(vector, logjac)`: the one-element vector that the callable above would produce,
# together with the log absolute derivative of the transform with respect to `x`.
function Bijectors.with_logabsdet_jacobian(u::UntruncateVect, x::Real)
    lbounded, ubounded = isfinite(u.lower), isfinite(u.upper)
    return if lbounded && ubounded
        bma = u.upper - u.lower
        xma = x - u.lower
        xma_over_bma = xma / bma
        # d/dx logit((x - a) / (b - a)) = (b - a) / ((x - a) * (b - x))
        [LogExpFunctions.logit(xma_over_bma)], -log(xma_over_bma * (u.upper - x))
    elseif lbounded
        log_xma = log(x - u.lower)
        [log_xma], -log_xma
    elseif ubounded
        log_bmx = log(u.upper - x)
        [log_bmx], -log_bmx
    else
        # Identity case. This must return the same `(vector, logjac)` pair as the other
        # branches; a bare `zero(x)` would drop the value and break tuple destructuring.
        [x], zero(x)
    end
end

InverseFunctions.inverse(u::UntruncateVect) = TruncateOnly(u.lower, u.upper)

# For these distributions the linked vector form is just the value wrapped in a
# one-element vector: `Only` unwraps it and `NotOnly` wraps it, with no further transform.
for D in [
    Distributions.Cauchy,
    Distributions.Chernoff,
    Distributions.Gumbel,
    Distributions.JohnsonSU,
    Distributions.Laplace,
    Distributions.Logistic,
    Distributions.NoncentralT,
    Distributions.Normal,
    Distributions.NormalCanon,
    Distributions.NormalInverseGaussian,
    Distributions.PGeneralizedGaussian,
    Distributions.SkewedExponentialPower,
    Distributions.SkewNormal,
    Distributions.TDist,
]
    @eval from_linked_vec_transform(::$D) = Only()
    @eval to_linked_vec_transform(::$D) = NotOnly()
end
# For these distributions the linked vector form is `log(x - minimum(d))`, and the way
# back is `exp(y) + minimum(d)`.
for D in [
    Distributions.BetaPrime,
    Distributions.Chi,
    Distributions.Chisq,
    Distributions.Erlang,
    Distributions.Exponential,
    Distributions.FDist,
    # Wikipedia's definition of the Frechet distribution allows for a location parameter,
    # which could cause its minimum to be nonzero. However, Distributions.jl's `Frechet`
    # does not implement this, so we can lump it in here.
    Distributions.Frechet,
    Distributions.Gamma,
    Distributions.InverseGamma,
    Distributions.InverseGaussian,
    Distributions.Kolmogorov,
    Distributions.Lindley,
    Distributions.LogNormal,
    Distributions.NoncentralChisq,
    Distributions.NoncentralF,
    Distributions.Rayleigh,
    Distributions.Rician,
    Distributions.StudentizedRange,
    Distributions.Weibull,
]
    @eval from_linked_vec_transform(d::$D) = ExpOnly(minimum(d))
    @eval to_linked_vec_transform(d::$D) = LogVect(minimum(d))
end
# Fallback for continuous univariate distributions without a dedicated method: build the
# transform from the distribution's own `minimum` and `maximum`.
function to_linked_vec_transform(d::Distributions.ContinuousUnivariateDistribution)
    lo, hi = minimum(d), maximum(d)
    return UntruncateVect(lo, hi)
end

function from_linked_vec_transform(d::Distributions.ContinuousUnivariateDistribution)
    lo, hi = minimum(d), maximum(d)
    return TruncateOnly(lo, hi)
end

# Unlinked vector form of any univariate distribution: the value in a one-element vector.
function from_vec_transform(::Distributions.UnivariateDistribution)
    return Only()
end

function to_vec_transform(::Distributions.UnivariateDistribution)
    return NotOnly()
end

# Discrete univariate distributions use the same wrap/unwrap pair when linked.
function from_linked_vec_transform(::DiscreteUnivariateDistribution)
    return Only()
end

function to_linked_vec_transform(::DiscreteUnivariateDistribution)
    return NotOnly()
end

"""
from_vec_transform(x)

Return the transformation from the vector representation of `x` to original representation.
"""
# Arrays: defer to the size-based constructor.
from_vec_transform(x::AbstractArray) = from_vec_transform_for_size(size(x))
# Cholesky factors: reshape to the size of the stored factor, then rebuild the `Cholesky`.
from_vec_transform(C::Cholesky) = ToChol(C.uplo) ∘ ReshapeTransform(size(C.UL))
# Real scalars: `Only` extracts the single element of the vector. This is the sole
# definition for `::Real`; a second definition with the same signature would overwrite
# it, which Julia rejects during module precompilation.
from_vec_transform(::Real) = Only()

"""
from_vec_transform_for_size(sz::Tuple)
Expand All @@ -395,7 +534,6 @@ Return the transformation from the vector representation of a realization from
distribution `dist` to the original representation compatible with `dist`.
"""
# Generic distributions: defer to the size-based constructor.
from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist))
# NOTE: no method for `::UnivariateDistribution` is defined here. That signature already
# has a definition returning `Only()`; redefining it at this point would overwrite that
# method and conflict with it.
# LKJCholesky: reshape to the size of `dist`, then rebuild the `Cholesky`.
from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist))

struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}}
Expand Down Expand Up @@ -441,7 +579,7 @@ end
# `_input_length(f)` gives the length of the vector that `f`, a transform produced by
# `from_vec_transform`, expects as input. It is used to work out which segment of a
# concatenated vector belongs to which variable.
function _input_length(::UnwrapSingletonTransform)
    return 1
end

function _input_length(::Only)
    return 1
end

function _input_length(from_vec_trfm::ReshapeTransform)
    return prod(from_vec_trfm.output_size)
end
function _input_length(trfm::ProductNamedTupleUnvecTransform)
return sum(_input_length ∘ from_vec_transform, values(trfm.dists))
Expand Down Expand Up @@ -477,19 +615,6 @@ function from_linked_vec_transform(dist::Distribution)
f_vec = from_vec_transform(inverse(f_invlink), size(dist))
return f_invlink ∘ f_vec
end

# Univariate distributions get their own method because `size(dist)` is `()`. The generic
# machinery would read that as a 0-dimensional array, whereas the value is really a
# scalar, so the result is unwrapped explicitly at the end.
# TODO(mhauru) Hopefully all this can go once the old Gibbs sampler is removed and
# VarNamedVector takes over from Metadata.
function from_linked_vec_transform(dist::UnivariateDistribution)
    invlink = invlink_transform(dist)
    dist_size = size(dist)
    unvec = from_vec_transform(inverse(invlink), dist_size)
    unvec_then_invlink = invlink ∘ unvec
    # The size the composed transform produces tells the unwrapper what to expect.
    out_size = Bijectors.output_size(unvec_then_invlink, dist_size)
    return UnwrapSingletonTransform(out_size) ∘ unvec_then_invlink
end
# For product named-tuple distributions the inverse link transform is used directly, with
# no separate un-vectorising step composed in front of it.
function from_linked_vec_transform(dist::Distributions.ProductNamedTupleDistribution)
    linked_to_original = invlink_transform(dist)
    return linked_to_original
end
Expand Down
Loading