Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.1"
version = "0.7.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
8 changes: 4 additions & 4 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ julia> using ChainRulesTestUtils;

julia> test_frule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_frule: two2three on Float64,Float64 | 5 5
test_frule: two2three on Float64,Float64 | 6 6
```

### Testing the `rrule`
Expand All @@ -71,7 +71,7 @@ The call will test the `rrule` for function `f` at the point `x`, and similarly
```jldoctest ex; output = false
julia> test_rrule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_rrule: two2three on Float64,Float64 | 6 6
test_rrule: two2three on Float64,Float64 | 7 7
```

## Scalar example
Expand All @@ -98,11 +98,11 @@ call.
```jldoctest ex; output = false
julia> test_scalar(relu, 0.5);
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 7 7
test_scalar: relu at 0.5 | 9 9

julia> test_scalar(relu, -0.5);
Test Summary: | Pass Total
test_scalar: relu at -0.5 | 7 7
test_scalar: relu at -0.5 | 9 9
```

## Specifying Tangents
Expand Down
24 changes: 13 additions & 11 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
end

"""
test_frule(f, inputs...; kwargs...)
test_frule(f, args..; kwargs...)

# Arguments
- `f`: Function for which the `frule` should be tested.
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
Non-differentiable arguments, such as indices, should have `ẋ` set as `NoTangent()`.
Expand All @@ -87,7 +87,7 @@ end
"""
function test_frule(
f,
inputs...;
args...;
output_tangent=Auto(),
fdm=_fdm,
check_inferred::Bool=true,
Expand All @@ -99,18 +99,18 @@ function test_frule(
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

@testset "test_frule: $f on $(_string_typeof(inputs))" begin
@testset "test_frule: $f on $(_string_typeof(args))" begin
_ensure_not_running_on_functor(f, "test_frule")

xẋs = auto_primal_and_tangent.(inputs)
xẋs = auto_primal_and_tangent.(args)
xs = primal.(xẋs)
ẋs = tangent.(xẋs)
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
_test_inferred(frule, (NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
end
res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
@test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
Ω_ad, dΩ_ad = res
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
test_approx(Ω_ad, Ω; isapprox_kwargs...)
Expand All @@ -135,11 +135,11 @@ function test_frule(
end

"""
test_rrule(f, inputs...; kwargs...)
test_rrule(f, args...; kwargs...)

# Arguments
- `f`: Function to which rule should be applied.
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.
Expand All @@ -155,7 +155,7 @@ end
"""
function test_rrule(
f,
inputs...;
args...;
output_tangent=Auto(),
fdm=_fdm,
check_inferred::Bool=true,
Expand All @@ -167,11 +167,11 @@ function test_rrule(
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

@testset "test_rrule: $f on $(_string_typeof(inputs))" begin
@testset "test_rrule: $f on $(_string_typeof(args))" begin
_ensure_not_running_on_functor(f, "test_rrule")

# Check correctness of evaluation.
xx̄s = auto_primal_and_tangent.(inputs)
xx̄s = auto_primal_and_tangent.(args)
xs = primal.(xx̄s)
accumulated_x̄ = tangent.(xx̄s)
if check_inferred && _is_inferrable(f, xs...; fkwargs...)
Expand All @@ -191,6 +191,8 @@ function test_rrule(
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NoTangent() # No internal fields
msg = "The pullback should return 1 cotangent for each primal input."
@test_msg msg length(x̄s_ad) == length(args)

# Correctness testing via finite differencing.
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
Expand Down
16 changes: 16 additions & 0 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,22 @@ end
@test fails(() -> test_frule(my_identity2, 2.2))
@test fails(() -> test_rrule(my_identity2, 2.2))
end

@testset "wrong number of outputs #167" begin
foo(x, y) = x + 2y

function ChainRulesCore.frule((_, ẋ, ẏ), ::typeof(foo), x, y)
return foo(x, y), ẋ + 2ẏ, NoTangent() # extra derivative
end

function ChainRulesCore.rrule(::typeof(foo), x, y)
foo_pullback(dz) = NoTangent(), dz # missing derivative
return foo(x,y), foo_pullback
end

@test fails(() -> test_frule(foo, 2.1, 2.1))
@test fails(() -> test_rrule(foo, 21.0, 32.0))
end
end

@testset "Tuple primal that is not equal to differential backing" begin
Expand Down