diff --git a/src/test_utils.jl b/src/test_utils.jl index 83a868bf2d..6c3573b418 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -276,7 +276,7 @@ for T in (:(Core.Method), :(Core.CodeInstance), :(Core.MethodInstance)) @eval function has_equal_data_internal( x::$T, y::$T, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} ) - x == y + return x == y end end @@ -777,11 +777,19 @@ function test_frule_performance( # Test allocations in primal. f(x...) - @test (count_allocs(f, x...)) == 0 + @static if VERSION >= v"1.12" + @test count_allocs(f, x...) == 0 + else + @test (@allocations f(x...)) == 0 + end # Test allocations in forwards-mode. __forwards(rule, f_ḟ, x_ẋ...) - @test (count_allocs(__forwards, rule, f_ḟ, x_ẋ...)) == 0 + @static if VERSION >= v"1.12" + @test count_allocs(__forwards, rule, f_ḟ, x_ẋ...) == 0 + else + @test (@allocations __forwards(rule, f_ḟ, x_ẋ...)) == 0 + end end end @@ -820,15 +828,22 @@ function test_rrule_performance( # Test allocations in primal. f(x...) - - @test count_allocs(f, x...) == 0 + @static if VERSION >= v"1.12" + @test count_allocs(f, x...) == 0 + else + @test (@allocations f(x...)) == 0 + end # Test allocations in round-trip. f_f̄_fwds = to_fwds(f_f̄) x_x̄_fwds = map(to_fwds, x_x̄) __forwards_and_backwards(rule, f_f̄_fwds, x_x̄_fwds...) - count_allocs(__forwards_and_backwards, rule, f_f̄_fwds, x_x̄_fwds...) - @test count_allocs(__forwards_and_backwards, rule, f_f̄_fwds, x_x̄_fwds...) == 0 + @static if VERSION >= v"1.12" + @test count_allocs(__forwards_and_backwards, rule, f_f̄_fwds, x_x̄_fwds...) == 0 + else + @test (@allocations __forwards_and_backwards(rule, f_f̄_fwds, x_x̄_fwds...)) == + 0 + end end end @@ -1381,12 +1396,57 @@ function test_get_tangent_field_performance(t::Union{MutableTangent,Tangent}) end end -# Function barrier to ensure inference in value types. -function count_allocs(f::F, x::Vararg{Any,N}) where {F,N} - test_hook(count_allocs, f, x...) do - @static if VERSION >= v"1.12-" - Base.allocations(f, x...) - else +# This faff is needed to work around the fact that `Base.allocations(() -> f(x...))` reports +# spurious allocations when any of `x` is a DataType (which necessitates a manual expansion +# of the splat), and `Base.allocations(f, x...)` also reports spurious allocations sometimes +# when the arguments `x` aren't interpolated (which necessitates a closure). The only way to +# make it work is to generate the code `Base.allocations(() -> f(x1, x2))`, etc., for each +# arity of `f` (up to a reasonable limit). It would be nicer to use a generated function for +# this, but generated functions can't contain closures. +function __allocs end +function count_allocs end +const __MAX_ARGS_ALLOCS = 10 +@static if VERSION >= v"1.12-" + for nargs in 0:__MAX_ARGS_ALLOCS + args = [Symbol("x", i) for i in 1:nargs] + types = [Symbol("X", i) for i in 1:nargs] + sigs = [:($(args[i])::$(types[i])) for i in 1:nargs] + fexpr = quote + function count_allocs(f::F, $(sigs...)) where {F,$(types...)} + test_hook(count_allocs, f, $(args...)) do + stats = Base.gc_num() + @noinline clos = () -> f($(args...)) + clos() + diff = Base.GC_Diff(Base.gc_num(), stats) + return Base.gc_alloc_count(diff) + end + end + end + eval(fexpr) + end + # Catch-all method for when there are more than __MAX_ARGS_ALLOCS arguments. The risk of + # using Vararg here on Julia 1.12 is that it leads to incomplete specialisation when any + # of the arguments are DataTypes, which can cause spurious allocations. See e.g. + # https://discourse.julialang.org/t/specialization-on-vararg-of-types/108251. + function count_allocs(f::F, x::Vararg{Any,N}) where {F,N} + test_hook(count_allocs, f, x...) do + # This method should only be hit if N > __MAX_ARGS_ALLOCS, but we can check + # nonetheless + N > __MAX_ARGS_ALLOCS && + @warn "using varargs method for `count_allocs` since there were $N arguments; this may lead to spurious allocations being reported if any arguments are `DataType`s" + stats = Base.gc_num() + @noinline clos = () -> f(x...) + clos() + diff = Base.GC_Diff(Base.gc_num(), stats) + return Base.gc_alloc_count(diff) + end + end +else + # Fallback for Julia <= 1.11. Note that this will report spurious allocations if any of + # `x` are `DataType`s so it is best to just call `@allocations f(x...)` directly instead + # of `count_allocs(f, x...)`. + function count_allocs(f::F, x::Vararg{Any,N}) where {F,N} + test_hook(count_allocs, f, x...) do @allocations f(x...) end end