Skip to content

Commit fa493fb

Browse files
authored
Clamp cdf and ccdf of Truncated (#1865)
1 parent fe57164 commit fa493fb

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

src/truncate.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ function logpdf(d::Truncated, x::Real)
166166
end
167167

168168
function cdf(d::Truncated, x::Real)
169-
result = (cdf(d.untruncated, x) - d.lcdf) / d.tp
169+
result = clamp((cdf(d.untruncated, x) - d.lcdf) / d.tp, 0, 1)
170+
# Special cases for values outside of the support to avoid e.g. NaN issues with `Binomial`
170171
return if d.lower !== nothing && x < d.lower
171172
zero(result)
172173
elseif d.upper !== nothing && x >= d.upper
@@ -188,7 +189,8 @@ function logcdf(d::Truncated, x::Real)
188189
end
189190

190191
function ccdf(d::Truncated, x::Real)
191-
result = (d.ucdf - cdf(d.untruncated, x)) / d.tp
192+
result = clamp((d.ucdf - cdf(d.untruncated, x)) / d.tp, 0, 1)
193+
# Special cases for values outside of the support to avoid e.g. NaN issues with `Binomial`
192194
return if d.lower !== nothing && x <= d.lower
193195
one(result)
194196
elseif d.upper !== nothing && x > d.upper

test/truncate.jl

+18
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,21 @@ end
214214

215215
@test isa(quantile(d, ForwardDiff.Dual(1.,0.)), ForwardDiff.Dual)
216216
end
217+
218+
@testset "cdf outside of [0, 1] (#1854)" begin
219+
dist = truncated(Normal(2.5, 0.2); lower=0.0)
220+
@test @inferred(cdf(dist, 3.741058503233821e-17)) === 0.0
221+
@test @inferred(ccdf(dist, 3.741058503233821e-17)) === 1.0
222+
@test @inferred(cdf(dist, 1.4354474178676617e-18)) === 0.0
223+
@test @inferred(ccdf(dist, 1.4354474178676617e-18)) === 1.0
224+
@test @inferred(cdf(dist, 8.834854780587132e-18)) === 0.0
225+
@test @inferred(ccdf(dist, 8.834854780587132e-18)) === 1.0
226+
227+
dist = truncated(
228+
Normal(2.122039143928797, 0.07327367710864985);
229+
lower = 1.9521656132878236,
230+
upper = 2.8274429454898398,
231+
)
232+
@test @inferred(cdf(dist, 2.82)) === 1.0
233+
@test @inferred(ccdf(dist, 2.82)) === 0.0
234+
end

0 commit comments

Comments
 (0)