Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.1"
version = "0.7.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 5 additions & 1 deletion src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ function test_frule(
end
res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
msg = "The frule should return (y, ∂y), not $res."
@test_msg msg res isa Tuple && length(res) == 2
Ω_ad, dΩ_ad = res
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
test_approx(Ω_ad, Ω; isapprox_kwargs...)
Expand Down Expand Up @@ -191,6 +192,9 @@ function test_rrule(
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NoTangent() # No internal fields
msg = "The pullback must return `(∂self, ∂args...)`, where `length(∂args) == " *
"length(args)`, and `args` are the arguments of the primal function."
@test_msg msg length(x̄s_ad) == length(inputs)

# Correctness testing via finite differencing.
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
Expand Down
18 changes: 18 additions & 0 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,24 @@ end
@test fails(() -> test_frule(my_identity2, 2.2))
@test fails(() -> test_rrule(my_identity2, 2.2))
end

@testset "wrong number of outputs #167" begin
foo(x, y) = x + 2y

function ChainRulesCore.frule((_, ẋ, ẏ), ::typeof(foo), x, y)
return foo(x, y), ẋ + 2ẏ, NoTangent() # extra derivative
#return foo(x, y), ẋ + 2ẏ # correct expression
end

function ChainRulesCore.rrule(::typeof(foo), x, y)
foo_pullback(dz) = NoTangent(), dz # missing derivative
#foo_pullback(dz) = NoTangent(), dz, 2dz # correct expression
return foo(x,y), foo_pullback
end

@test fails(() -> test_frule(foo, 2.1, 2.1))
@test fails(() -> test_rrule(foo, 21.0, 32.0))
end
end

@testset "Tuple primal that is not equal to differential backing" begin
Expand Down