
reinstate reactant's tests #2609

Closed

CarloLucibello wants to merge 2 commits into master from cl/reactant

Conversation

@CarloLucibello
Member

No description provided.

CarloLucibello marked this pull request as draft May 30, 2025 00:09
@wsmoses
Contributor

wsmoses commented Jul 3, 2025

@CarloLucibello do you know when this started failing? This implies that somewhere the function var"#11#23"{var"#loss#37"} captures a global variable or something
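
A minimal sketch of the capture pattern being described (hypothetical names, not from the failing test): a function that reads a non-const global behaves like a closure over it, while the same computation with the data passed as an argument traces cleanly.

using Statistics: mean

W = randn(Float32, 4, 2)    # a non-const global

bad(x) = mean(W * x)        # reads the global W from inside the body
good(W, x) = mean(W * x)    # same computation, W enters as an argument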

@wsmoses
Contributor

wsmoses commented Sep 11, 2025

bump

@CarloLucibello
Member Author

@wsmoses
Contributor

wsmoses commented Dec 25, 2025

My earlier comment still applies. All data needs to be passed in as an argument, not as a captured global.
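
A hedged sketch of that distinction at the @jit call site, assuming m_re and x_re are a model and input already moved to the Reactant device:

using Statistics: mean
using Reactant

loss(m, x) = mean(m(x))

Reactant.@jit loss(m_re, x_re)   # fine: all data flows through the argument list

g = () -> loss(m_re, x_re)       # g closes over both m_re and x_re
Reactant.@jit g()                # risky: nothing is passed as an argument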

@CarloLucibello
Member Author

I can't see any global...

@CarloLucibello
Member Author

If the if test_reactant branching is removed from the code below, the error disappears:

using Flux
using Enzyme
# # using CUDA, cuDNN
using Reactant
using MLDataDevices
using Test 
using Random, Statistics, LinearAlgebra
using Zygote
using Functors

Reactant.set_default_backend("cpu")

# Compute (loss, gradients) with Enzyme: scalars are marked Active, while
# arrays/structs are Duplicated with a zeroed shadow to accumulate gradients.
function enzyme_withgradient(f, x...)
    args = []
    for xi in x
        if xi isa Number
            push!(args, Enzyme.Active(xi))
        else
            push!(args, Enzyme.Duplicated(xi, Enzyme.make_zero(xi)))
        end
    end
    ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
    ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return ret[2], g
end

reactant_withgradient(f, x...) = Reactant.@jit enzyme_withgradient(f, x...)

reactant_loss(loss, x...) = Reactant.@jit loss(x...)

function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
    fmapstructure_with_path(a, b) do kp, x, y
        # @show kp
        if x isa AbstractArray
            @test x ≈ y rtol=rtol atol=atol
        elseif x isa Number
            @test x ≈ y rtol=rtol atol=atol
        end
    end
end

function test_reactant_gradients(
            f, 
            xs...;
            rtol=1e-4, atol=1e-4,
            test_reactant = true,
            test_grad_f = true,
            test_grad_x = true,
            loss = (f, xs...) -> mean(f(xs...)),
            )

    Flux.trainmode!(f)

    l = loss(f, xs...)
    @test l isa Number

    if test_reactant
        reactant_dev = MLDataDevices.reactant_device(force=true)
        cpu_dev = cpu_device()
        xs_re = xs |> reactant_dev
        f_re = f |> reactant_dev
        l_re = reactant_loss(loss, f_re, xs_re...)
        @test l ≈ l_re rtol=rtol atol=atol
    end

    if test_grad_x
        y, g = Zygote.withgradient((xs...) -> loss(f, xs...), xs...)

        if test_reactant
            y_re, g_re = reactant_withgradient((xs...) -> loss(f_re, xs...), xs_re...)
            @test y ≈ y_re rtol=rtol atol=atol
            check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
        end
    end

    if test_grad_f
        y, g = Zygote.withgradient(f -> loss(f, xs...), f)

        if test_reactant
            y_re, g_re = reactant_withgradient(f -> loss(f, xs_re...), f_re)
            @test y ≈ y_re rtol=rtol atol=atol
            check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
        end
    end
    return true
end


m, x = Dense(2=>4), randn(Float32, 2)
loss(m, x) = mean(m(x))
r_dev = reactant_device(force=true)
r_m = m |> r_dev
r_x = x |> r_dev
@jit loss(r_m, r_x)
@jit enzyme_withgradient(loss, r_m, r_x)
@jit enzyme_withgradient(r_x -> loss(r_m, r_x), r_x)
@test test_reactant_gradients(m, x; loss) # ERROR

@wsmoses is that something that can be fixed or worked around?

@CarloLucibello
Member Author

Two other failures:

m, x = Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1)
loss(m, x) = mean(m(x))
r_dev = reactant_device(force=true)
r_m = m |> r_dev
r_x = x |> r_dev
@jit loss(r_m, r_x)
y, gs = Zygote.withgradient(loss, m, x)
r_y, r_gs = @jit enzyme_withgradient(loss, r_m, r_x)
@test y ≈ r_y rtol=1e-4 atol=1e-4
check_equal_leaves(gs, r_gs; rtol=1e-4, atol=1e-4)

@jit enzyme_withgradient(r_x -> loss(r_m, r_x), r_x)
# ERROR: MethodError: no method matching overloaded_conv!(::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::ConcretePJRTArray{…}, ::DenseConvDims{…})
# The function `overloaded_conv!` exists, but no method is defined for this combination of argument types.
# Closest candidates are:
#   overloaded_conv!(::AbstractArray{Reactant.TracedRNumber{T}, N}, ::AbstractArray{Reactant.TracedRNumber{T2}, N}, ::AbstractArray{Reactant.TracedRNumber{T3}, N}, ::DenseConvDims) where {T, T2, T3, N}
#    @ ReactantNNlibExt ~/.julia/packages/Reactant/UjybB/ext/ReactantNNlibExt/Implementations.jl:40
# Stacktrace:
#   [1] conv!(y::Reactant.TracedRArray{…}, x::Reactant.TracedRArray{…}, w::ConcretePJRTArray{…}, cdims::DenseConvDims{…}; kwargs::@Kwargs{})
#     @ ReactantNNlibExt ~/.julia/packages/Reactant/UjybB/ext/ReactantNNlibExt/Overlay.jl:3
#   [2] conv!
#     @ ~/.julia/packages/Reactant/UjybB/ext/ReactantNNlibExt/Overlay.jl:1 [inlined]
#   [3] (::Nothing)(none::typeof(conv!), none::Reactant.TracedRArray{…}, none::Reactant.TracedRArray{…}, none::ConcretePJRTArray{…}, none::DenseConvDims{…})
#     @ Reactant ./<missing>:0
#   [4] call_with_reactant
#     @ ./none:-1 [inlined]
#   [5] call_with_reactant(::Reactant.MustThrowError, ::typeof(conv!), ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::ConcretePJRTArray{…}, ::DenseConvDims{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#   [6] #conv#124
#     @ ~/.julia/packages/NNlib/ytFya/src/conv.jl:88 [inlined]
#   [7] (::Nothing)(none::NNlib.var"##conv#124", none::@Kwargs{}, none::typeof(conv), none::Reactant.TracedRArray{…}, none::ConcretePJRTArray{…}, none::DenseConvDims{…})
#     @ Reactant ./<missing>:0
#   [8] call_with_reactant
#     @ ./none:-1 [inlined]
#   [9] call_with_reactant(::Reactant.MustThrowError, ::NNlib.var"##conv#124", ::@Kwargs{}, ::typeof(conv), ::Reactant.TracedRArray{…}, ::ConcretePJRTArray{…}, ::DenseConvDims{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [10] conv
#     @ ~/.julia/packages/NNlib/ytFya/src/conv.jl:83 [inlined]
#  [11] (::Nothing)(none::typeof(conv), none::Reactant.TracedRArray{…}, none::ConcretePJRTArray{…}, none::DenseConvDims{…})
#     @ Reactant ./<missing>:0
#  [12] call_with_reactant
#     @ ./none:-1 [inlined]
#  [13] call_with_reactant(::Reactant.MustThrowError, ::typeof(conv), ::Reactant.TracedRArray{…}, ::ConcretePJRTArray{…}, ::DenseConvDims{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [14] Conv
#     @ ~/.julia/packages/Flux/WMUyh/src/layers/conv.jl:201 [inlined]
#  [15] (::Nothing)(none::Conv{…}, none::Reactant.TracedRArray{…})
#     @ Reactant ./<missing>:0
#  [16] call_with_reactant
#     @ ./none:-1 [inlined]
#  [17] call_with_reactant(::Reactant.MustThrowError, ::Conv{…}, ::Reactant.TracedRArray{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [18] loss
#     @ ~/juliadev/Flux/prova.jl:159 [inlined]
#  [19] (::Nothing)(none::typeof(loss), none::Conv{…}, none::Reactant.TracedRArray{…})
#     @ Reactant ./<missing>:0
#  [20] call_with_reactant
#     @ ./none:-1 [inlined]
#  [21] call_with_reactant(::typeof(loss), ::Conv{…}, ::Reactant.TracedRArray{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [22] #123
#     @ ~/juliadev/Flux/prova.jl:169 [inlined]
#  [23] (::Nothing)(none::var"#123#124", none::Reactant.TracedRArray{Float32, 4})
#     @ Reactant ./<missing>:0
#  [24] call_with_reactant
#     @ ./none:-1 [inlined]
#  [25] call_with_reactant(::var"#123#124", ::Reactant.TracedRArray{Float32, 4})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [26] make_mlir_fn(f::var"#123#124", args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Nothing, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
#     @ Reactant.TracedUtils ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:348
#  [27] make_mlir_fn
#     @ ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:277 [inlined]
#  [28] overload_autodiff(::ReverseMode{…}, f::Const{…}, ::Type{…}, args::Duplicated{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/Enzyme.jl:315
#  [29] autodiff
#     @ ~/.julia/packages/Reactant/UjybB/src/Overlay.jl:36 [inlined]
#  [30] (::Nothing)(none::typeof(autodiff), none::ReverseMode{…}, none::Const{…}, none::Type{…}, none::Tuple{…})
#     @ Reactant ./<missing>:0
#  [31] call_with_reactant
#     @ ./none:-1 [inlined]
#  [32] call_with_reactant(::typeof(autodiff), ::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [33] macro expansion
#     @ ~/.julia/packages/Reactant/UjybB/src/utils.jl:314 [inlined]
#  [34] applyiterate_with_reactant(::typeof(iterate), ::typeof(autodiff), ::Core.SimpleVector, ::Vector{Any})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:300
#  [35] enzyme_withgradient
#     @ ~/juliadev/Flux/test/test_utils.jl:32 [inlined]
#  [36] (::Nothing)(none::typeof(enzyme_withgradient), none::var"#123#124", none::Tuple{Reactant.TracedRArray{Float32, 4}})
#     @ Reactant ./<missing>:0
#  [37] call_with_reactant
#     @ ./none:-1 [inlined]
#  [38] call_with_reactant(::typeof(enzyme_withgradient), ::var"#123#124", ::Reactant.TracedRArray{Float32, 4})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [39] make_mlir_fn(f::typeof(enzyme_withgradient), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
#     @ Reactant.TracedUtils ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:348
#  [40] make_mlir_fn
#     @ ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:277 [inlined]
#  [41] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(enzyme_withgradient), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
#     @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:1679
#  [42] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
#     @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3630
#  [43] compile_xla
#     @ ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3602 [inlined]
#  [44] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
#     @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3706
#  [45] top-level scope
#     @ ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:2775
# Some type information was truncated. Use `show(err)` to see complete types.

@jit enzyme_withgradient(r_m -> loss(r_m, r_x), r_m)
# ERROR: MethodError: no method matching primitive_type(::Type{Reactant.TracedRNumber{Float32}})
# The function `primitive_type` exists, but no method is defined for this combination of argument types.

# Closest candidates are:
#   primitive_type(::Type{UInt16})
#    @ Reactant ~/.julia/packages/Reactant/UjybB/src/xla/Utils.jl:44
#   primitive_type(::Type{UInt8})
#    @ Reactant ~/.julia/packages/Reactant/UjybB/src/xla/Utils.jl:44
#   primitive_type(::Type{Reactant.F8E5M2})
#    @ Reactant ~/.julia/packages/Reactant/UjybB/src/xla/Utils.jl:44
#   ...

# Stacktrace:
#   [1] macro expansion
#     @ ~/.julia/packages/Reactant/UjybB/src/utils.jl:-1 [inlined]
#   [2] call_with_reactant(::typeof(Reactant.XLA.primitive_type), ::Type{Reactant.TracedRNumber{Float32}})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:947
#   [3] #similar#19
#     @ ~/.julia/packages/Reactant/UjybB/src/xla/PJRT/Buffer.jl:102 [inlined]
#   [4] (::Nothing)(none::Reactant.XLA.PJRT.var"##similar#19", none::Reactant.XLA.PJRT.Client, none::Nothing, none::Reactant.XLA.PJRT.Device, none::typeof(similar), none::Type{…}, none::Type, none::NTuple{…})
#     @ Reactant ./<missing>:0
#   [5] top-level scope
#     @ none:0
#   [6] similar
#     @ ~/.julia/packages/Reactant/UjybB/src/xla/PJRT/Buffer.jl:76 [inlined]
#   [7] similar
#     @ ~/.julia/packages/Reactant/UjybB/src/xla/PJRT/Buffer.jl:114 [inlined]
#   [8] similar
#     @ ~/.julia/packages/Reactant/UjybB/src/xla/PJRT/AsyncBuffer.jl:16 [inlined]
#   [9] (::Nothing)(none::typeof(similar), none::Reactant.XLA.PJRT.AsyncBuffer, none::Tuple{Type, NTuple{4, Int64}})
#     @ Reactant ./<missing>:0
#  [10] top-level scope
#     @ none:0
#  [11] #110
#     @ ~/.julia/packages/Reactant/UjybB/src/ConcreteRArray.jl:433 [inlined]
#  [12] (::Nothing)(none::Reactant.var"#110#111"{Reactant.TracedRNumber{…}, ConcretePJRTArray{…}, Tuple{…}}, none::Int64)
#     @ Reactant ./<missing>:0
#  [13] call_with_reactant
#     @ ./none:-1 [inlined]
#  [14] call_with_reactant(::Reactant.MustThrowError, ::Reactant.var"#110#111"{…}, ::Int64)
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [15] ntuple
#     @ ./ntuple.jl:50 [inlined]
#  [16] similar
#     @ ~/.julia/packages/Reactant/UjybB/src/ConcreteRArray.jl:431 [inlined]
#  [17] (::Nothing)(none::typeof(similar), none::ConcretePJRTArray{…}, none::Type{…}, none::NTuple{…})
#     @ Reactant ./<missing>:0
#  [18] call_with_reactant
#     @ ./none:-1 [inlined]
#  [19] call_with_reactant(::Reactant.MustThrowError, ::typeof(similar), ::ConcretePJRTArray{…}, ::Type{…}, ::NTuple{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [20] similar
#     @ ~/.julia/packages/Reactant/UjybB/src/ConcreteRArray.jl:427 [inlined]
#  [21] (::Nothing)(none::typeof(similar), none::ConcretePJRTArray{Float32, 4, 1}, none::Type{Reactant.TracedRNumber{Float32}})
#     @ Reactant ./<missing>:0
#  [22] call_with_reactant
#     @ ./none:-1 [inlined]
#  [23] call_with_reactant(::Reactant.MustThrowError, ::typeof(similar), ::ConcretePJRTArray{…}, ::Type{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [24] AbstractArray
#     @ ./array.jl:622 [inlined]
#  [25] (::Nothing)(none::Type{AbstractArray{Reactant.TracedRNumber{Float32}, 4}}, none::ConcretePJRTArray{Float32, 4, 1})
#     @ Reactant ./<missing>:0
#  [26] call_with_reactant
#     @ ./none:-1 [inlined]
#  [27] call_with_reactant(::Reactant.MustThrowError, ::Type{AbstractArray{…}}, ::ConcretePJRTArray{Float32, 4, 1})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [28] AbstractArray
#     @ ./boot.jl:677 [inlined]
#  [29] (::Nothing)(none::Type{AbstractArray{Reactant.TracedRNumber{Float32}}}, none::ConcretePJRTArray{Float32, 4, 1})
#     @ Reactant ./<missing>:0
#  [30] call_with_reactant
#     @ ./none:-1 [inlined]
#  [31] call_with_reactant(::Reactant.MustThrowError, ::Type{AbstractArray{…}}, ::ConcretePJRTArray{Float32, 4, 1})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [32] convert
#     @ ./abstractarray.jl:17 [inlined]
#  [33] (::Nothing)(none::typeof(convert), none::Type{AbstractArray{…}}, none::ConcretePJRTArray{Float32, 4, 1})
#     @ Reactant ./<missing>:0
#  [34] call_with_reactant
#     @ ./none:-1 [inlined]
#  [35] call_with_reactant(::Reactant.MustThrowError, ::typeof(convert), ::Type{…}, ::ConcretePJRTArray{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [36] _match_eltype
#     @ ~/.julia/packages/Flux/WMUyh/src/layers/stateless.jl:77 [inlined]
#  [37] (::Nothing)(none::typeof(Flux._match_eltype), none::Conv{…}, none::Type{…}, none::ConcretePJRTArray{…})
#     @ Reactant ./<missing>:0
#  [38] call_with_reactant
#     @ ./none:-1 [inlined]
#  [39] call_with_reactant(::Reactant.MustThrowError, ::typeof(Flux._match_eltype), ::Conv{…}, ::Type{…}, ::ConcretePJRTArray{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [40] _match_eltype
#     @ ~/.julia/packages/Flux/WMUyh/src/layers/stateless.jl:85 [inlined]
#  [41] (::Nothing)(none::typeof(Flux._match_eltype), none::Conv{…}, none::ConcretePJRTArray{…})
#     @ Reactant ./<missing>:0
#  [42] call_with_reactant
#     @ ./none:-1 [inlined]
#  [43] call_with_reactant(::Reactant.MustThrowError, ::typeof(Flux._match_eltype), ::Conv{…}, ::ConcretePJRTArray{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [44] Conv
#     @ ~/.julia/packages/Flux/WMUyh/src/layers/conv.jl:200 [inlined]
#  [45] (::Nothing)(none::Conv{…}, none::ConcretePJRTArray{…})
#     @ Reactant ./<missing>:0
#  [46] call_with_reactant
#     @ ./none:-1 [inlined]
#  [47] call_with_reactant(::Reactant.MustThrowError, ::Conv{…}, ::ConcretePJRTArray{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [48] loss
#     @ ~/juliadev/Flux/prova.jl:159 [inlined]
#  [49] (::Nothing)(none::typeof(loss), none::Conv{…}, none::ConcretePJRTArray{…})
#     @ Reactant ./<missing>:0
#  [50] call_with_reactant
#     @ ./none:-1 [inlined]
#  [51] call_with_reactant(::typeof(loss), ::Conv{…}, ::ConcretePJRTArray{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [52] #126
#     @ ~/juliadev/Flux/prova.jl:170 [inlined]
#  [53] (::Nothing)(none::var"#126#127", none::Conv{2, 4, typeof(identity), Reactant.TracedRArray{…}, Reactant.TracedRArray{…}})
#     @ Reactant ./<missing>:0
#  [54] call_with_reactant
#     @ ./none:-1 [inlined]
#  [55] call_with_reactant(::var"#126#127", ::Conv{2, 4, typeof(identity), Reactant.TracedRArray{…}, Reactant.TracedRArray{…}})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [56] make_mlir_fn(f::var"#126#127", args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Nothing, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
#     @ Reactant.TracedUtils ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:348
#  [57] make_mlir_fn
#     @ ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:277 [inlined]
#  [58] overload_autodiff(::ReverseMode{…}, f::Const{…}, ::Type{…}, args::Duplicated{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/Enzyme.jl:315
#  [59] autodiff
#     @ ~/.julia/packages/Reactant/UjybB/src/Overlay.jl:36 [inlined]
#  [60] (::Nothing)(none::typeof(autodiff), none::ReverseMode{…}, none::Const{…}, none::Type{…}, none::Tuple{…})
#     @ Reactant ./<missing>:0
#  [61] call_with_reactant
#     @ ./none:-1 [inlined]
#  [62] call_with_reactant(::typeof(autodiff), ::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [63] macro expansion
#     @ ~/.julia/packages/Reactant/UjybB/src/utils.jl:314 [inlined]
#  [64] applyiterate_with_reactant(::typeof(iterate), ::typeof(autodiff), ::Core.SimpleVector, ::Vector{Any})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:300
#  [65] enzyme_withgradient
#     @ ~/juliadev/Flux/test/test_utils.jl:32 [inlined]
#  [66] (::Nothing)(none::typeof(enzyme_withgradient), none::var"#126#127", none::Tuple{Conv{…}})
#     @ Reactant ./<missing>:0
#  [67] call_with_reactant
#     @ ./none:-1 [inlined]
#  [68] call_with_reactant(::typeof(enzyme_withgradient), ::var"#126#127", ::Conv{…})
#     @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
#  [69] make_mlir_fn(f::typeof(enzyme_withgradient), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
#     @ Reactant.TracedUtils ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:348
#  [70] make_mlir_fn
#     @ ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:277 [inlined]
#  [71] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(enzyme_withgradient), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
#     @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:1679
#  [72] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
#     @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3630
#  [73] compile_xla
#     @ ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3602 [inlined]
#  [74] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
#     @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3706
# Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses
Contributor

wsmoses commented Dec 26, 2025

This all looks to have a similar root cause. Can the utility function be written without the closure?

@CarloLucibello
Member Author

CarloLucibello commented Dec 26, 2025

It can be done, although since we would generally like to test closures as well, we would have to special-case Reactant.

What about the if-branch problem: can it be fixed on the Reactant side, or would it go away once the closures are removed?

@wsmoses
Contributor

wsmoses commented Dec 26, 2025

The problem in your recent snippet, I believe, is the use of a global:

@jit enzyme_withgradient(r_m -> loss(r_m, r_x), r_m)

Here you are capturing the global r_x.
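
A capture-free spelling of the same call, in the spirit of the Fix1/Fix2 approach below: Base.Fix2(loss, r_x) holds r_x in an explicit field, so Base.Fix2(loss, r_x)(r_m) computes loss(r_m, r_x) without closing over anything in scope.

@jit enzyme_withgradient(Base.Fix2(loss, r_x), r_m)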

@wsmoses
Contributor

wsmoses commented Dec 29, 2025

> @wsmoses is that something that can be fixed or worked around?

I think I see the issue here: it's the closure being passed into reactant_withgradient, which is accidentally capturing into the functor all the other variables in scope (including ones that aren't relevant).
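
Whatever a closure captures is stored as fields of its hidden struct type and gets traced along with it; the fields can be inspected with fieldnames. A small standalone illustration:

f = let a = 1.0, b = 2.0
    x -> a * x         # uses only a
end
fieldnames(typeof(f))  # (:a,) — a was captured as a field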

Locally this seems to work:

using Flux
using Enzyme
# # using CUDA, cuDNN
using Reactant
using MLDataDevices
using Test 
using Random, Statistics, LinearAlgebra
using Zygote
using Functors

Reactant.set_default_backend("cpu")

# Compute (loss, gradients) with Enzyme: scalars are marked Active, while
# arrays/structs are Duplicated with a zeroed shadow to accumulate gradients.
function enzyme_withgradient(f, x...)
    args = []
    for xi in x
        if xi isa Number
            push!(args, Enzyme.Active(xi))
        else
            push!(args, Enzyme.Duplicated(xi, Enzyme.make_zero(xi)))
        end
    end
    ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
    ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return ret[2], g
end

reactant_withgradient(f, x...) = Reactant.@jit enzyme_withgradient(f, x...)

reactant_loss(loss, x...) = Reactant.@jit loss(x...)

function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
    fmapstructure_with_path(a, b) do kp, x, y
        # @show kp
        if x isa AbstractArray
            @test x ≈ y rtol=rtol atol=atol
        elseif x isa Number
            @test x ≈ y rtol=rtol atol=atol
        end
    end
end

function test_reactant_gradients(
            f, 
            xs...;
            rtol=1e-4, atol=1e-4,
            test_reactant = true,
            test_grad_f = true,
            test_grad_x = true,
            loss = (f, xs...) -> mean(f(xs...)),
            )

    Flux.trainmode!(f)

    l = loss(f, xs...)
    @test l isa Number

    if test_reactant
        reactant_dev = MLDataDevices.reactant_device(force=true)
        cpu_dev = cpu_device()
        xs_re = xs |> reactant_dev
        f_re = f |> reactant_dev
        l_re = reactant_loss(loss, f_re, xs_re...)
        @test l ≈ l_re rtol=rtol atol=atol
    end

    if test_grad_x
        y, g = Zygote.withgradient((xs...) -> loss(f, xs...), xs...)

        if test_reactant
            y_re, g_re = reactant_withgradient(Base.Fix1(loss, f_re), xs_re...)
            @test y ≈ y_re rtol=rtol atol=atol
            check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
        end
    end

    if test_grad_f
        y, g = Zygote.withgradient(f -> loss(f, xs...), f)

        if test_reactant
            y_re, g_re = reactant_withgradient(Base.Fix2(loss, xs_re[1]), f_re)
            @test y ≈ y_re rtol=rtol atol=atol
            check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
        end
    end
    return true
end


m, x = Dense(2=>4), randn(Float32, 2)
loss(m, x) = mean(m(x))
r_dev = reactant_device(force=true)
r_m = m |> r_dev
r_x = x |> r_dev
@jit loss(r_m, r_x)
@jit enzyme_withgradient(loss, r_m, r_x)
@jit enzyme_withgradient(r_x -> loss(r_m, r_x), r_x)
@test test_reactant_gradients(m, x; loss) # now passes with the Fix1/Fix2 version

So the fix is just Base.Fix1 / Base.Fix2.
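
For reference, the equivalences relied on here: Base.Fix1 and Base.Fix2 are small Base structs holding just the function and the bound argument, so the bound value is an explicit, typed field and nothing else from the surrounding scope can leak into what Reactant traces.

Base.Fix1(loss, f_re)(x)   # calls loss(f_re, x): binds the first argument
Base.Fix2(loss, x_re)(f)   # calls loss(f, x_re): binds the second argument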

@wsmoses
Contributor

wsmoses commented Jan 2, 2026

This is now subsumed by #2600 and can presumably be closed.

