Skip to content

Commit

Permalink
Merge pull request #439 from SomTambe/master
Browse files Browse the repository at this point in the history
Add gradients for FilledExtrapolations
  • Loading branch information
mkitti authored Jul 10, 2021
2 parents c5bd734 + f3b2d7f commit a8370bd
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/extrapolation/filled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ end
return ret
end

@inline function Interpolations.gradient(etp::FilledExtrapolation{T}, x::Vararg{Number,N}) where {T,N}
itp = parent(etp)
if checkbounds(Bool, itp, x...)
gradient(itp, x...)
else
ptype = promote_type(eltype(itp), map(typeof, x)...)
return zeros(SVector{ndims(etp), ptype})
end
end

expand_index_resid_etp(deg, fillvalue, (l, u), x, etp::FilledExtrapolation, xN) =
(l <= x <= u || Base.throw_boundserror(etp, xN))

Expand Down
33 changes: 33 additions & 0 deletions test/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,37 @@ using Test, Interpolations, DualNumbers, LinearAlgebra
end
end
end

@testset "Extrapolations" begin
# out of bounds extrapolation check where fillvalue=constant
dims = [1, 2, 3]
@testset "Constant Fillvalue" begin
dtypes = [Float32, Float64]
for dt in dtypes
for dm in dims
itp = interpolate(rand(dt, ntuple(_ -> 5, dm)...), BSpline(Linear()))
etp = extrapolate(itp, rand(dt)) # TODO: Add tests for different boundary conditions in extrapolate
cs = ntuple(_->3, dm)
@test @inferred(Interpolations.gradient(etp, cs...)) == @inferred(Interpolations.gradient(itp, cs...))
for ds in [ntuple(_->-0.5, dm), ntuple(_->5.5, dm), ntuple(_->6, dm)]
@test all(iszero, @inferred(Interpolations.gradient(etp, ds...)))
end
end
end
end
@testset "Boundary Conditions" begin
BC = (Flat, Line, Periodic, Reflect, Natural)
for dm in dims
for bc in BC
itp = interpolate(rand(5, 5), BSpline(Linear()))
etp = extrapolate(itp, bc()) # TODO: Add tests for different boundary conditions in extrapolate
cs = ntuple(_->3, 2)
@test @inferred(Interpolations.gradient(etp, cs...)) == @inferred(Interpolations.gradient(itp, cs...))
for ds in [ntuple(_->-0.5, 2), ntuple(_->5.5, 2), ntuple(_->6, 2)]
@test size(@inferred(Interpolations.gradient(etp, ds...))) == (2,)
end
end
end
end
end
end

0 comments on commit a8370bd

Please sign in to comment.