Skip to content

Commit 00b0efa

Browse files
committed
Avoid stack overflows with non-standard float types; closes JuliaMath#76
1 parent 68e9db0 commit 00b0efa

File tree

3 files changed

+109
-7
lines changed

3 files changed

+109
-7
lines changed

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ OpenLibm_jll = "05823500-19ac-5b8b-9628-191a04bc5112"
1111
julia = "1.6"
1212

1313
[extras]
14+
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
1415
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1516

1617
[targets]
17-
test = ["Test"]
18+
test = ["DoubleFloats", "Test"]

src/NaNMath.jl

+43-3
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,67 @@ module NaNMath
33
using OpenLibm_jll
44
const libm = OpenLibm_jll.libopenlibm
55

6+
67
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
78
:lgamma, :log1p)
89
@eval begin
910
($f)(x::Float64) = ccall(($(string(f)),libm), Float64, (Float64,), x)
1011
($f)(x::Float32) = ccall(($(string(f,"f")),libm), Float32, (Float32,), x)
11-
($f)(x::Real) = ($f)(float(x))
12+
($f)(x::Float16) = Float16(($f)(Float32(x)))
13+
function ($f)(x::Real)
14+
xf = float(x)
15+
x === xf && throw(MethodError($f, (x,)))
16+
return ($f)(xf)
17+
end
1218
end
1319
end
20+
sin(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.sin(x) : T(NaN)
21+
cos(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.cos(x) : T(NaN)
22+
tan(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.tan(x) : T(NaN)
23+
asin(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.asin(x)
24+
acos(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.acos(x)
25+
acosh(x::T) where {T<:AbstractFloat} = x < 1 ? T(NaN) : Base.acosh(x)
26+
atanh(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.atanh(x)
27+
log(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log(x)
28+
log2(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log2(x)
29+
log10(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log10(x)
30+
# lgamma does not have a Base version; the MethodError above will suffice
31+
log1p(x::T) where {T<:AbstractFloat} = x < -1 ? T(NaN) : Base.log1p(x)
32+
1433

1534
# Would be more efficient to remove the domain check in Base.sqrt(),
1635
# but this doesn't seem easy to do.
1736
sqrt(x::T) where {T<:AbstractFloat} = x < 0.0 ? T(NaN) : Base.sqrt(x)
18-
sqrt(x::Real) = sqrt(float(x))
37+
function sqrt(x::Real)
38+
xf = float(x)
39+
x === xf && throw(MethodError(sqrt, (x,)))
40+
return sqrt(xf)
41+
end
1942

2043
# Don't override built-in ^ operator
2144
pow(x::Float64, y::Float64) = ccall((:pow,libm), Float64, (Float64,Float64), x, y)
2245
pow(x::Float32, y::Float32) = ccall((:powf,libm), Float32, (Float32,Float32), x, y)
46+
pow(x::Float16, y::Float16) = Float16(pow(Float32(x), Float32(y)))
2347
# We `promote` first before converting to floating pointing numbers to ensure that
2448
# e.g. `pow(::Float32, ::Int)` ends up calling `pow(::Float32, ::Float32)`
2549
pow(x::Number, y::Number) = pow(promote(x, y)...)
26-
pow(x::T, y::T) where {T<:Number} = pow(float(x), float(y))
50+
function pow(x::T, y::T) where {T<:Number}
51+
xf = float(x)
52+
yf = float(y)
53+
x === xf && y === yf && throw(MethodError(pow, (x,y)))
54+
return pow(xf, yf)
55+
end
56+
function pow(x::T, y::T) where {T<:AbstractFloat}
57+
try
58+
return x^y
59+
catch e
60+
if isa(e, DomainError)
61+
return T(NaN)
62+
else
63+
rethrow(e)
64+
end
65+
end
66+
end
2767

2868
"""
2969
NaNMath.sum(A)

test/runtests.jl

+64-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,56 @@
11
using NaNMath
22
using Test
3+
using DoubleFloats
4+
5+
6+
# https://github.com/JuliaMath/NaNMath.jl/issues/76
7+
@test_throws MethodError NaNMath.pow(1.0, 1.0+im)
8+
9+
10+
for T in (Float64, Float32, Float16, BigFloat)
11+
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
12+
:log1p) # Note: do :lgamma separately because it can't handle BigFloat
13+
@eval begin
14+
@test NaNMath.$f($T(2//3)) isa $T
15+
@test NaNMath.$f($T(3//2)) isa $T
16+
@test NaNMath.$f($T(-2//3)) isa $T
17+
@test NaNMath.$f($T(-3//2)) isa $T
18+
@test NaNMath.$f($T(Inf)) isa $T
19+
@test NaNMath.$f($T(-Inf)) isa $T
20+
end
21+
end
22+
end
23+
for T in (Float64, Float32, Float16)
24+
@test NaNMath.lgamma(T(2//3)) isa T
25+
@test NaNMath.lgamma(T(3//2)) isa T
26+
@test NaNMath.lgamma(T(-2//3)) isa T
27+
@test NaNMath.lgamma(T(-3//2)) isa T
28+
@test NaNMath.lgamma(T(Inf)) isa T
29+
@test NaNMath.lgamma(T(-Inf)) isa T
30+
end
31+
@test_throws MethodError NaNMath.lgamma(BigFloat(2//3))
332

433
@test isnan(NaNMath.log(-10))
34+
@test isnan(NaNMath.log(-10f0))
35+
@test isnan(NaNMath.log(Float16(-10)))
536
@test isnan(NaNMath.log1p(-100))
37+
@test isnan(NaNMath.log1p(-100f0))
38+
@test isnan(NaNMath.log1p(Float16(-100)))
639
@test isnan(NaNMath.pow(-1.5,2.3))
740
@test isnan(NaNMath.pow(-1.5f0,2.3f0))
841
@test isnan(NaNMath.pow(-1.5,2.3f0))
942
@test isnan(NaNMath.pow(-1.5f0,2.3))
43+
@test isnan(NaNMath.pow(Float16(-1.5),Float16(2.3)))
44+
@test isnan(NaNMath.pow(Float16(-1.5),2.3))
45+
@test isnan(NaNMath.pow(-1.5,Float16(2.3)))
46+
@test isnan(NaNMath.pow(Float16(-1.5),2.3f0))
47+
@test isnan(NaNMath.pow(-1.5f0,Float16(2.3)))
48+
@test isnan(NaNMath.pow(-1.5f0,BigFloat(2.3)))
49+
@test isnan(NaNMath.pow(BigFloat(-1.5),BigFloat(2.3)))
50+
@test isnan(NaNMath.pow(BigFloat(-1.5),2.3f0))
51+
@test isnan(NaNMath.pow(-1.5f0,Double64(2.3)))
52+
@test isnan(NaNMath.pow(Double64(-1.5),Double64(2.3)))
53+
@test isnan(NaNMath.pow(Double64(-1.5),2.3f0))
1054
@test NaNMath.pow(-1,2) isa Float64
1155
@test NaNMath.pow(-1.5f0,2) isa Float32
1256
@test NaNMath.pow(-1.5f0,2//1) isa Float32
@@ -16,11 +60,28 @@ using Test
1660
@test NaNMath.pow(-1.5,2//1) isa Float64
1761
@test NaNMath.pow(-1.5,2.3f0) isa Float64
1862
@test NaNMath.pow(-1.5,2.3) isa Float64
19-
@test isnan(NaNMath.sqrt(-5))
63+
@test NaNMath.pow(Float16(-1.5),2.3) isa Float64
64+
@test NaNMath.pow(Float16(-1.5),Float16(2.3)) isa Float16
65+
@test NaNMath.pow(-1.5,Float16(2.3)) isa Float64
66+
@test NaNMath.pow(Float16(-1.5),2.3f0) isa Float32
67+
@test NaNMath.pow(-1.5f0,Float16(2.3)) isa Float32
68+
@test NaNMath.pow(-1.5f0,BigFloat(2.3)) isa BigFloat
69+
@test NaNMath.pow(BigFloat(-1.5),BigFloat(2.3)) isa BigFloat
70+
@test NaNMath.pow(BigFloat(-1.5),2.3f0) isa BigFloat
71+
@test NaNMath.pow(-1.5f0,Double64(2.3)) isa Double64
72+
@test NaNMath.pow(Double64(-1.5),Double64(2.3)) isa Double64
73+
@test NaNMath.pow(Double64(-1.5),2.3f0) isa Double64
74+
@test NaNMath.sqrt(-5) isa Float64
2075
@test NaNMath.sqrt(5) == Base.sqrt(5)
76+
@test NaNMath.sqrt(-5f0) isa Float32
77+
@test NaNMath.sqrt(5f0) == Base.sqrt(5f0)
78+
@test NaNMath.sqrt(Float16(-5)) isa Float16
79+
@test NaNMath.sqrt(Float16(5)) == Base.sqrt(Float16(5))
80+
@test NaNMath.sqrt(BigFloat(-5)) isa BigFloat
81+
@test NaNMath.sqrt(BigFloat(5)) == Base.sqrt(BigFloat(5))
2182
@test isnan(NaNMath.sqrt(-3.2f0)) && NaNMath.sqrt(-3.2f0) isa Float32
22-
@test isnan(NaNMath.sqrt(-BigFloat(7.0))) && NaNMath.sqrt(-BigFloat(7.0)) isa BigFloat
23-
@test isnan(NaNMath.sqrt(-7)) && NaNMath.sqrt(-7) isa Float64
83+
@test isnan(NaNMath.sqrt(-BigFloat(7.0))) && NaNMath.sqrt(-BigFloat(7.0)) isa BigFloat
84+
@test isnan(NaNMath.sqrt(-7)) && NaNMath.sqrt(-7) isa Float64
2485
@inferred NaNMath.sqrt(5)
2586
@inferred NaNMath.sqrt(5.0)
2687
@inferred NaNMath.sqrt(5.0f0)

0 commit comments

Comments
 (0)