Skip to content

Commit 72a123a

Browse files
committed
optimise univariate transforms
1 parent accb515 commit 72a123a

File tree

2 files changed

+19
-50
lines changed

2 files changed

+19
-50
lines changed

HISTORY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod
4949
The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
5050
This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
5151

52+
Improved performance of transformations of univariate distributions' samples to and from their vectorised forms.
53+
5254
## 0.38.9
5355

5456
Remove warning when using Enzyme as the AD backend.

src/utils.jl

Lines changed: 17 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -261,42 +261,6 @@ invlink_transform(dist) = inverse(link_transform(dist))
261261
# Helper functions for vectorize/reconstruct values #
262262
#####################################################
263263

264-
"""
265-
UnwrapSingletonTransform(input_size::InSize)
266-
267-
A transformation that unwraps a singleton array, returning a scalar.
268-
269-
The `input_size` field is the expected size of the input. In practice this only determines
270-
the number of indices, since all dimensions must be 1 for a singleton. `input_size` is used
271-
to check the validity of the input, but also to determine the correct inverse operation.
272-
273-
By default `input_size` is `(1,)`, in which case `tovec` is the inverse.
274-
"""
275-
struct UnwrapSingletonTransform{InSize} <: Bijectors.Bijector
276-
input_size::InSize
277-
end
278-
279-
UnwrapSingletonTransform() = UnwrapSingletonTransform((1,))
280-
281-
function (f::UnwrapSingletonTransform)(x)
282-
if size(x) != f.input_size
283-
throw(DimensionMismatch("Expected input of size $(f.input_size), got $(size(x))"))
284-
end
285-
return only(x)
286-
end
287-
288-
function Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x)
289-
return f(x), zero(LogProbType)
290-
end
291-
292-
function Bijectors.with_logabsdet_jacobian(
293-
inv_f::Bijectors.Inverse{<:UnwrapSingletonTransform}, x
294-
)
295-
f = inv_f.orig
296-
result = reshape([x], f.input_size)
297-
return result, zero(LogProbType)
298-
end
299-
300264
"""
301265
ReshapeTransform(input_size::InSize, output_size::OutSize)
302266
@@ -370,14 +334,26 @@ function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y)
370334
)
371335
end
372336

337+
struct Only end
338+
struct NotOnly end
339+
(::Only)(x) = x[]
340+
(::NotOnly)(y) = [y]
341+
function Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector{T}) where {T<:Real}
342+
return (x[], zero(T))
343+
end
344+
Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector) = (x[], zero(LogProbType))
345+
Bijectors.inverse(::Only) = NotOnly()
346+
Bijectors.with_logabsdet_jacobian(::NotOnly, y::T) where {T<:Real} = ([y], zero(T))
347+
Bijectors.with_logabsdet_jacobian(::NotOnly, y) = ([y], zero(LogProbType))
348+
373349
"""
374350
from_vec_transform(x)
375351
376352
Return the transformation from the vector representation of `x` to original representation.
377353
"""
378354
from_vec_transform(x::AbstractArray) = from_vec_transform_for_size(size(x))
379355
from_vec_transform(C::Cholesky) = ToChol(C.uplo) ReshapeTransform(size(C.UL))
380-
from_vec_transform(::Real) = UnwrapSingletonTransform()
356+
from_vec_transform(::Real) = Only()
381357

382358
"""
383359
from_vec_transform_for_size(sz::Tuple)
@@ -395,7 +371,7 @@ Return the transformation from the vector representation of a realization from
395371
distribution `dist` to the original representation compatible with `dist`.
396372
"""
397373
from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist))
398-
from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform()
374+
from_vec_transform(::UnivariateDistribution) = Only()
399375
from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ReshapeTransform(size(dist))
400376

401377
struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}}
@@ -441,7 +417,7 @@ end
441417
# This function returns the length of the vector that the function from_vec_transform
442418
# expects. This helps us determine which segment of a concatenated vector belongs to which
443419
# variable.
444-
_input_length(from_vec_trfm::UnwrapSingletonTransform) = 1
420+
_input_length(::Only) = 1
445421
_input_length(from_vec_trfm::ReshapeTransform) = prod(from_vec_trfm.output_size)
446422
function _input_length(trfm::ProductNamedTupleUnvecTransform)
447423
return sum(_input_length from_vec_transform, values(trfm.dists))
@@ -477,18 +453,9 @@ function from_linked_vec_transform(dist::Distribution)
477453
f_vec = from_vec_transform(inverse(f_invlink), size(dist))
478454
return f_invlink f_vec
479455
end
480-
481-
# UnivariateDistributions need to be handled as a special case, because size(dist) is (),
482-
# which makes the usual machinery think we are dealing with a 0-dim array, whereas in
483-
# actuality we are dealing with a scalar.
484-
# TODO(mhauru) Hopefully all this can go once the old Gibbs sampler is removed and
485-
# VarNamedVector takes over from Metadata.
486456
function from_linked_vec_transform(dist::UnivariateDistribution)
487-
f_invlink = invlink_transform(dist)
488-
f_vec = from_vec_transform(inverse(f_invlink), size(dist))
489-
f_combined = f_invlink f_vec
490-
sz = Bijectors.output_size(f_combined, size(dist))
491-
return UnwrapSingletonTransform(sz) f_combined
457+
# This is a performance optimisation
458+
return Only() invlink_transform(dist)
492459
end
493460
function from_linked_vec_transform(dist::Distributions.ProductNamedTupleDistribution)
494461
return invlink_transform(dist)

0 commit comments

Comments
 (0)