Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 73 additions & 13 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ for T in (:(Core.Method), :(Core.CodeInstance), :(Core.MethodInstance))
@eval function has_equal_data_internal(
x::$T, y::$T, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool}
)
x == y
return x == y
end
end

Expand Down Expand Up @@ -777,11 +777,19 @@ function test_frule_performance(

# Test allocations in primal.
f(x...)
@test (count_allocs(f, x...)) == 0
@static if VERSION >= v"1.12"
@test count_allocs(f, x...) == 0
else
@test (@allocations f(x...)) == 0
end

# Test allocations in forwards-mode.
__forwards(rule, f_ḟ, x_ẋ...)
@test (count_allocs(__forwards, rule, f_ḟ, x_ẋ...)) == 0
@static if VERSION >= v"1.12"
@test count_allocs(__forwards, rule, f_ḟ, x_ẋ...) == 0
else
@test (@allocations __forwards(rule, f_ḟ, x_ẋ...)) == 0
end
end
end

Expand Down Expand Up @@ -820,15 +828,22 @@ function test_rrule_performance(

# Test allocations in primal.
f(x...)

@test count_allocs(f, x...) == 0
@static if VERSION >= v"1.12"
@test count_allocs(f, x...) == 0
else
@test (@allocations f(x...)) == 0
end

# Test allocations in round-trip.
f_f̄_fwds = to_fwds(f_f̄)
x_x̄_fwds = map(to_fwds, x_x̄)
__forwards_and_backwards(rule, f_f̄_fwds, x_x̄_fwds...)
count_allocs(__forwards_and_backwards, rule, f_f̄_fwds, x_x̄_fwds...)
@test count_allocs(__forwards_and_backwards, rule, f_f̄_fwds, x_x̄_fwds...) == 0
@static if VERSION >= v"1.12"
@test count_allocs(__forwards_and_backwards, rule, f_f̄_fwds, x_x̄_fwds...) == 0
else
@test (@allocations __forwards_and_backwards(rule, f_f̄_fwds, x_x̄_fwds...)) ==
0
end
end
end

Expand Down Expand Up @@ -1381,12 +1396,57 @@ function test_get_tangent_field_performance(t::Union{MutableTangent,Tangent})
end
end

# Function barrier to ensure inference in value types.
function count_allocs(f::F, x::Vararg{Any,N}) where {F,N}
test_hook(count_allocs, f, x...) do
@static if VERSION >= v"1.12-"
Base.allocations(f, x...)
else
# This faff is needed to work around the fact that `Base.allocations(() -> f(x...))` reports
# spurious allocations when any of `x` is a DataType (which necessitates a manual expansion
# of the splat), and `Base.allocations(f, x...)` also reports spurious allocations sometimes
# when the arguments `x` aren't interpolated (which necessitates a closure). The only way to
# make it work is to generate the code `Base.allocations(() -> f(x1, x2))`, etc., for each
# arity of `f` (up to a reasonable limit). It would be nicer to use a generated function for
# this, but generated functions can't contain closures.
function __allocs end
function count_allocs end
const __MAX_ARGS_ALLOCS = 10
@static if VERSION >= v"1.12-"
for nargs in 0:__MAX_ARGS_ALLOCS
args = [Symbol("x", i) for i in 1:nargs]
types = [Symbol("X", i) for i in 1:nargs]
sigs = [:($(args[i])::$(types[i])) for i in 1:nargs]
fexpr = quote
function count_allocs(f::F, $(sigs...)) where {F,$(types...)}
test_hook(count_allocs, f, $(args...)) do
stats = Base.gc_num()
@noinline clos = () -> f($(args...))
clos()
diff = Base.GC_Diff(Base.gc_num(), stats)
return Base.gc_alloc_count(diff)
end
end
end
eval(fexpr)
end
# Catch-all method for when there are more than __MAX_ARGS_ALLOCS arguments. The risk of
# using Vararg here on Julia 1.12 is that it leads to incomplete specialisation when any
# of the arguments are DataTypes, which can cause spurious allocations. See e.g.
# https://discourse.julialang.org/t/specialization-on-vararg-of-types/108251.
function count_allocs(f::F, x::Vararg{Any,N}) where {F,N}
test_hook(count_allocs, f, x...) do
# This method should only be hit if N > __MAX_ARGS_ALLOCS, but we can check
# nonetheless
N > __MAX_ARGS_ALLOCS &&
@warn "using varargs method for `count_allocs` since there were $N arguments; this may lead to spurious allocations being reported if any arguments are `DataType`s"
stats = Base.gc_num()
@noinline clos = () -> f(x...)
clos()
diff = Base.GC_Diff(Base.gc_num(), stats)
return Base.gc_alloc_count(diff)
end
end
else
# Fallback for Julia <= 1.11. Note that this will report spurious allocations if any of
# `x` are `DataType`s so it is best to just call `@allocations f(x...)` directly instead
# of `count_allocs(f, x...)`.
function count_allocs(f::F, x::Vararg{Any,N}) where {F,N}
test_hook(count_allocs, f, x...) do
@allocations f(x...)
end
end
Expand Down
Loading