Skip to content

Commit

Permalink
Merge pull request #28 from JuliaMath/CR/dual_svector
Browse files Browse the repository at this point in the history
Loosen the HCubature element type bound for SVectors to allow Dual numbers
  • Loading branch information
stevengj authored Jul 2, 2020
2 parents 2de7a27 + c2f6478 commit 66074a8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 18 deletions.
26 changes: 12 additions & 14 deletions src/HCubature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ end
cubrule(::Val{0}, ::Type{T}) where {T} = Trivial()
countevals(::Trivial) = 1

function hcubature_(f, a::SVector{n,T}, b::SVector{n,T}, norm, rtol_, atol, maxevals, initdiv) where {n, T<:AbstractFloat}
function hcubature_(f, a::SVector{n,T}, b::SVector{n,T}, norm, rtol_, atol, maxevals, initdiv) where {n, T<:Real}
rtol = rtol_ == 0 == atol ? sqrt(eps(T)) : rtol_
(rtol < 0 || atol < 0) && throw(ArgumentError("invalid negative tolerance"))
maxevals < 0 && throw(ArgumentError("invalid negative maxevals"))
Expand Down Expand Up @@ -121,19 +121,15 @@ function hcubature_(f, a::SVector{n,T}, b::SVector{n,T}, norm, rtol_, atol, maxe
return I,E
end

function hcubature_(f, a::SVector{n,T}, b::SVector{n,S},
norm, rtol, atol, maxevals, initdiv) where {n, T<:Real, S<:Real}
function hcubature_(f, a::AbstractVector{T}, b::AbstractVector{S},
norm, rtol, atol, maxevals, initdiv) where {T<:Real, S<:Real}
length(a) == length(b) || throw(DimensionMismatch("endpoints $a and $b must have the same length"))
F = float(promote_type(T, S))
return hcubature_(f, SVector{n,F}(a), SVector{n,F}(b), norm, rtol, atol, maxevals, initdiv)
return hcubature_(f, SVector{length(a),F}(a), SVector{length(a),F}(b), norm, rtol, atol, maxevals, initdiv)
end
function hcubature_(f, a::AbstractVector{<:Real}, b::AbstractVector{<:Real},
norm, rtol, atol, maxevals, initdiv)
n = length(a)
n == length(b) || throw(DimensionMismatch("endpoints $a and $b must have the same length"))
hcubature_(f, SVector{n}(a), SVector{n}(b), norm, rtol, atol, maxevals, initdiv)
function hcubature_(f, a::Tuple{Vararg{Real,n}}, b::Tuple{Vararg{Real,n}}, norm, rtol, atol, maxevals, initdiv) where {n}
hcubature_(f, SVector{n}(float.(a)), SVector{n}(float.(b)), norm, rtol, atol, maxevals, initdiv)
end
hcubature_(f, a::Tuple{Vararg{Real,n}}, b::Tuple{Vararg{Real,n}}, norm, rtol, atol, maxevals, initdiv) where {n} =
hcubature_(f, SVector{n}(a), SVector{n}(b), norm, rtol, atol, maxevals, initdiv)

"""
hcubature(f, a, b; norm=norm, rtol=sqrt(eps), atol=0, maxevals=typemax(Int), initdiv=1)
Expand Down Expand Up @@ -197,8 +193,10 @@ Alternatively, for 1d integrals you can import the [`QuadGK`](@ref) module
and call the [`quadgk`](@ref) function, which provides additional flexibility
e.g. in choosing the order of the quadrature rule.
"""
hquadrature(f, a, b; norm=norm, rtol::Real=0, atol::Real=0,
maxevals::Integer=typemax(Int), initdiv::Integer=1) =
hcubature_(x -> f(x[1]), SVector(a), SVector(b), norm, rtol, atol, maxevals, initdiv)
function hquadrature(f, a::T, b::S; norm=norm, rtol::Real=0, atol::Real=0,
maxevals::Integer=typemax(Int), initdiv::Integer=1) where {T<:Real, S<:Real}
F = float(promote_type(T, S))
hcubature_(x -> f(x[1]), SVector{1,F}(a), SVector{1,F}(b), norm, rtol, atol, maxevals, initdiv)
end

end # module
4 changes: 2 additions & 2 deletions src/gauss-kronrod.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Gauss-Kronrod quadrature rule, via the QuadGK package, since Genz-Malik
# rule does not handle the 1d case. We will just use a fixed-order (7) G-K rule.

struct GaussKronrod{T<:AbstractFloat}
struct GaussKronrod{T<:Real}
x::Vector{T}
w::Vector{T}
wg::Vector{T}
Expand All @@ -11,7 +11,7 @@ end
# call QuadGK.kronrod every time.
const gkcache = Dict{Type, GaussKronrod}()

function GaussKronrod(::Type{T}) where {T<:AbstractFloat}
function GaussKronrod(::Type{T}) where {T<:Real}
haskey(gkcache, T) && return gkcache[T]::GaussKronrod{T}
gkcache[T] = g = GaussKronrod{T}(QuadGK.kronrod(T,7)...)
return g
Expand Down
4 changes: 2 additions & 2 deletions src/genz-malik.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ end
to an `n`-dimensional Genz-Malik cubature rule over coordinates
of type `T`.
"""
struct GenzMalik{n,T<:AbstractFloat}
struct GenzMalik{n,T<:Real}
p::NTuple{4,Vector{SVector{n,T}}} # points for the last 4 G-M weights
w::NTuple{5,T} # weights for the 5 terms in the G-M rule
w′::NTuple{4,T} # weights for the embedded lower-degree rule
Expand All @@ -72,7 +72,7 @@ const gmcache = Dict{Tuple{Int,Type}, GenzMalik}()
Construct an n-dimensional Genz-Malik rule for coordinates of type `T`.
"""
function GenzMalik(v::Val{n}, ::Type{T}=Float64) where {n, T<:AbstractFloat}
function GenzMalik(v::Val{n}, ::Type{T}=Float64) where {n, T<:Real}
haskey(gmcache, (n,T)) && return gmcache[n,T]::GenzMalik{n,T}

n < 2 && throw(ArgumentError("invalid dimension $n: GenzMalik rule requires dimension > 2"))
Expand Down

0 comments on commit 66074a8

Please sign in to comment.