Conversation
|
@CarloLucibello do you know when this started failing? This implies that somewhere the function |
|
bump |
514fc56 to
73627f4
Compare
|
@wsmoses a first error is "Cannot trace concrete" |
|
My earlier comment still applies. All data needs to be passed in as an argument, not a captured global |
|
I can't see any global.... |
|
Removing the using Flux
using Enzyme
# # using CUDA, cuDNN
using Reactant
using MLDataDevices
using Test
using Random, Statistics, LinearAlgebra
using Zygote
using Functors
Reactant.set_default_backend("cpu")
function enzyme_withgradient(f, x...)
args = []
for x in x
if x isa Number
push!(args, Enzyme.Active(x))
else
push!(args, Enzyme.Duplicated(x, Enzyme.make_zero(x)))
end
end
ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)
g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
return ret[2], g
end
reactant_withgradient(f, x...) = Reactant.@jit enzyme_withgradient(f, x...)
reactant_loss(loss, x...) = Reactant.@jit loss(x...)
function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
fmapstructure_with_path(a, b) do kp, x, y
# @show kp
if x isa AbstractArray
@test x ≈ y rtol=rtol atol=atol
elseif x isa Number
@test x ≈ y rtol=rtol atol=atol
end
end
end
function test_reactant_gradients(
f,
xs...;
rtol=1e-4, atol=1e-4,
test_reactant = true,
test_grad_f = true,
test_grad_x = true,
loss = (f, xs...) -> mean(f(xs...)),
)
Flux.trainmode!(f)
l = loss(f, xs...)
@test l isa Number
if test_reactant
reactant_dev = MLDataDevices.reactant_device(force=true)
cpu_dev = cpu_device()
xs_re = xs |> reactant_dev
f_re = f |> reactant_dev
l_re = reactant_loss(loss, f_re, xs_re...)
@test l ≈ l_re rtol=rtol atol=atol
end
if test_grad_x
y, g = Zygote.withgradient((xs...) -> loss(f, xs...), xs...)
if test_reactant
y_re, g_re = reactant_withgradient((xs...) -> loss(f_re, xs...), xs_re...)
@test y ≈ y_re rtol=rtol atol=atol
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
end
end
if test_grad_f
y, g = Zygote.withgradient(f -> loss(f, xs...), f)
if test_reactant
y_re, g_re = reactant_withgradient(f -> loss(f, xs_re...), f_re)
@test y ≈ y_re rtol=rtol atol=atol
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
end
end
return true
end
m, x = Dense(2=>4), randn(Float32, 2)
loss(m, x) = mean(m(x))
r_dev = reactant_device(force=true)
r_m = m |> r_dev
r_x = x |> r_dev
@jit loss(r_m, r_x)
@jit enzyme_withgradient(loss, r_m, r_x)
@jit enzyme_withgradient(r_x -> loss(r_m, r_x), r_x)
@test test_reactant_gradients(m, x; loss) # ERROR@wsmoses is that something that can be fixed or worked around? |
|
Other two failures: m, x = Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1)
loss(m, x) = mean(m(x))
r_dev = reactant_device(force=true)
r_m = m |> r_dev
r_x = x |> r_dev
@jit loss(r_m, r_x)
y, gs = Zygote.withgradient(loss, m, x)
r_y, r_gs = @jit enzyme_withgradient(loss, r_m, r_x)
@test y ≈ r_y rtol=1e-4 atol=1e-4
check_equal_leaves(gs, r_gs; rtol=1e-4, atol=1e-4)
@jit enzyme_withgradient(r_x -> loss(r_m, r_x), r_x)
# ERROR: MethodError: no method matching overloaded_conv!(::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::ConcretePJRTArray{…}, ::DenseConvDims{…})
# The function `overloaded_conv!` exists, but no method is defined for this combination of argument types.
# Closest candidates are:
# overloaded_conv!(::AbstractArray{Reactant.TracedRNumber{T}, N}, ::AbstractArray{Reactant.TracedRNumber{T2}, N}, ::AbstractArray{Reactant.TracedRNumber{T3}, N}, ::DenseConvDims) where {T, T2, T3, N}
# @ ReactantNNlibExt ~/.julia/packages/Reactant/UjybB/ext/ReactantNNlibExt/Implementations.jl:40
# Stacktrace:
# [1] conv!(y::Reactant.TracedRArray{…}, x::Reactant.TracedRArray{…}, w::ConcretePJRTArray{…}, cdims::DenseConvDims{…}; kwargs::@Kwargs{})
# @ ReactantNNlibExt ~/.julia/packages/Reactant/UjybB/ext/ReactantNNlibExt/Overlay.jl:3
# [2] conv!
# @ ~/.julia/packages/Reactant/UjybB/ext/ReactantNNlibExt/Overlay.jl:1 [inlined]
# [3] (::Nothing)(none::typeof(conv!), none::Reactant.TracedRArray{…}, none::Reactant.TracedRArray{…}, none::ConcretePJRTArray{…}, none::DenseConvDims{…})
# @ Reactant ./<missing>:0
# [4] call_with_reactant
# @ ./none:-1 [inlined]
# [5] call_with_reactant(::Reactant.MustThrowError, ::typeof(conv!), ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::ConcretePJRTArray{…}, ::DenseConvDims{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [6] #conv#124
# @ ~/.julia/packages/NNlib/ytFya/src/conv.jl:88 [inlined]
# [7] (::Nothing)(none::NNlib.var"##conv#124", none::@Kwargs{}, none::typeof(conv), none::Reactant.TracedRArray{…}, none::ConcretePJRTArray{…}, none::DenseConvDims{…})
# @ Reactant ./<missing>:0
# [8] call_with_reactant
# @ ./none:-1 [inlined]
# [9] call_with_reactant(::Reactant.MustThrowError, ::NNlib.var"##conv#124", ::@Kwargs{}, ::typeof(conv), ::Reactant.TracedRArray{…}, ::ConcretePJRTArray{…}, ::DenseConvDims{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [10] conv
# @ ~/.julia/packages/NNlib/ytFya/src/conv.jl:83 [inlined]
# [11] (::Nothing)(none::typeof(conv), none::Reactant.TracedRArray{…}, none::ConcretePJRTArray{…}, none::DenseConvDims{…})
# @ Reactant ./<missing>:0
# [12] call_with_reactant
# @ ./none:-1 [inlined]
# [13] call_with_reactant(::Reactant.MustThrowError, ::typeof(conv), ::Reactant.TracedRArray{…}, ::ConcretePJRTArray{…}, ::DenseConvDims{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [14] Conv
# @ ~/.julia/packages/Flux/WMUyh/src/layers/conv.jl:201 [inlined]
# [15] (::Nothing)(none::Conv{…}, none::Reactant.TracedRArray{…})
# @ Reactant ./<missing>:0
# [16] call_with_reactant
# @ ./none:-1 [inlined]
# [17] call_with_reactant(::Reactant.MustThrowError, ::Conv{…}, ::Reactant.TracedRArray{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [18] loss
# @ ~/juliadev/Flux/prova.jl:159 [inlined]
# [19] (::Nothing)(none::typeof(loss), none::Conv{…}, none::Reactant.TracedRArray{…})
# @ Reactant ./<missing>:0
# [20] call_with_reactant
# @ ./none:-1 [inlined]
# [21] call_with_reactant(::typeof(loss), ::Conv{…}, ::Reactant.TracedRArray{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [22] #123
# @ ~/juliadev/Flux/prova.jl:169 [inlined]
# [23] (::Nothing)(none::var"#123#124", none::Reactant.TracedRArray{Float32, 4})
# @ Reactant ./<missing>:0
# [24] call_with_reactant
# @ ./none:-1 [inlined]
# [25] call_with_reactant(::var"#123#124", ::Reactant.TracedRArray{Float32, 4})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [26] make_mlir_fn(f::var"#123#124", args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Nothing, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
# @ Reactant.TracedUtils ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:348
# [27] make_mlir_fn
# @ ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:277 [inlined]
# [28] overload_autodiff(::ReverseMode{…}, f::Const{…}, ::Type{…}, args::Duplicated{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/Enzyme.jl:315
# [29] autodiff
# @ ~/.julia/packages/Reactant/UjybB/src/Overlay.jl:36 [inlined]
# [30] (::Nothing)(none::typeof(autodiff), none::ReverseMode{…}, none::Const{…}, none::Type{…}, none::Tuple{…})
# @ Reactant ./<missing>:0
# [31] call_with_reactant
# @ ./none:-1 [inlined]
# [32] call_with_reactant(::typeof(autodiff), ::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [33] macro expansion
# @ ~/.julia/packages/Reactant/UjybB/src/utils.jl:314 [inlined]
# [34] applyiterate_with_reactant(::typeof(iterate), ::typeof(autodiff), ::Core.SimpleVector, ::Vector{Any})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:300
# [35] enzyme_withgradient
# @ ~/juliadev/Flux/test/test_utils.jl:32 [inlined]
# [36] (::Nothing)(none::typeof(enzyme_withgradient), none::var"#123#124", none::Tuple{Reactant.TracedRArray{Float32, 4}})
# @ Reactant ./<missing>:0
# [37] call_with_reactant
# @ ./none:-1 [inlined]
# [38] call_with_reactant(::typeof(enzyme_withgradient), ::var"#123#124", ::Reactant.TracedRArray{Float32, 4})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [39] make_mlir_fn(f::typeof(enzyme_withgradient), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
# @ Reactant.TracedUtils ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:348
# [40] make_mlir_fn
# @ ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:277 [inlined]
# [41] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(enzyme_withgradient), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
# @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:1679
# [42] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
# @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3630
# [43] compile_xla
# @ ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3602 [inlined]
# [44] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
# @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3706
# [45] top-level scope
# @ ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:2775
# Some type information was truncated. Use `show(err)` to see complete types.
@jit enzyme_withgradient(r_m -> loss(r_m, r_x), r_m)
# ERROR: MethodError: no method matching primitive_type(::Type{Reactant.TracedRNumber{Float32}})
# The function `primitive_type` exists, but no method is defined for this combination of argument types.
# Closest candidates are:
# primitive_type(::Type{UInt16})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/xla/Utils.jl:44
# primitive_type(::Type{UInt8})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/xla/Utils.jl:44
# primitive_type(::Type{Reactant.F8E5M2})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/xla/Utils.jl:44
# ...
# Stacktrace:
# [1] macro expansion
# @ ~/.julia/packages/Reactant/UjybB/src/utils.jl:-1 [inlined]
# [2] call_with_reactant(::typeof(Reactant.XLA.primitive_type), ::Type{Reactant.TracedRNumber{Float32}})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:947
# [3] #similar#19
# @ ~/.julia/packages/Reactant/UjybB/src/xla/PJRT/Buffer.jl:102 [inlined]
# [4] (::Nothing)(none::Reactant.XLA.PJRT.var"##similar#19", none::Reactant.XLA.PJRT.Client, none::Nothing, none::Reactant.XLA.PJRT.Device, none::typeof(similar), none::Type{…}, none::Type, none::NTuple{…})
# @ Reactant ./<missing>:0
# [5] top-level scope
# @ none:0
# [6] similar
# @ ~/.julia/packages/Reactant/UjybB/src/xla/PJRT/Buffer.jl:76 [inlined]
# [7] similar
# @ ~/.julia/packages/Reactant/UjybB/src/xla/PJRT/Buffer.jl:114 [inlined]
# [8] similar
# @ ~/.julia/packages/Reactant/UjybB/src/xla/PJRT/AsyncBuffer.jl:16 [inlined]
# [9] (::Nothing)(none::typeof(similar), none::Reactant.XLA.PJRT.AsyncBuffer, none::Tuple{Type, NTuple{4, Int64}})
# @ Reactant ./<missing>:0
# [10] top-level scope
# @ none:0
# [11] #110
# @ ~/.julia/packages/Reactant/UjybB/src/ConcreteRArray.jl:433 [inlined]
# [12] (::Nothing)(none::Reactant.var"#110#111"{Reactant.TracedRNumber{…}, ConcretePJRTArray{…}, Tuple{…}}, none::Int64)
# @ Reactant ./<missing>:0
# [13] call_with_reactant
# @ ./none:-1 [inlined]
# [14] call_with_reactant(::Reactant.MustThrowError, ::Reactant.var"#110#111"{…}, ::Int64)
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [15] ntuple
# @ ./ntuple.jl:50 [inlined]
# [16] similar
# @ ~/.julia/packages/Reactant/UjybB/src/ConcreteRArray.jl:431 [inlined]
# [17] (::Nothing)(none::typeof(similar), none::ConcretePJRTArray{…}, none::Type{…}, none::NTuple{…})
# @ Reactant ./<missing>:0
# [18] call_with_reactant
# @ ./none:-1 [inlined]
# [19] call_with_reactant(::Reactant.MustThrowError, ::typeof(similar), ::ConcretePJRTArray{…}, ::Type{…}, ::NTuple{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [20] similar
# @ ~/.julia/packages/Reactant/UjybB/src/ConcreteRArray.jl:427 [inlined]
# [21] (::Nothing)(none::typeof(similar), none::ConcretePJRTArray{Float32, 4, 1}, none::Type{Reactant.TracedRNumber{Float32}})
# @ Reactant ./<missing>:0
# [22] call_with_reactant
# @ ./none:-1 [inlined]
# [23] call_with_reactant(::Reactant.MustThrowError, ::typeof(similar), ::ConcretePJRTArray{…}, ::Type{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [24] AbstractArray
# @ ./array.jl:622 [inlined]
# [25] (::Nothing)(none::Type{AbstractArray{Reactant.TracedRNumber{Float32}, 4}}, none::ConcretePJRTArray{Float32, 4, 1})
# @ Reactant ./<missing>:0
# [26] call_with_reactant
# @ ./none:-1 [inlined]
# [27] call_with_reactant(::Reactant.MustThrowError, ::Type{AbstractArray{…}}, ::ConcretePJRTArray{Float32, 4, 1})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [28] AbstractArray
# @ ./boot.jl:677 [inlined]
# [29] (::Nothing)(none::Type{AbstractArray{Reactant.TracedRNumber{Float32}}}, none::ConcretePJRTArray{Float32, 4, 1})
# @ Reactant ./<missing>:0
# [30] call_with_reactant
# @ ./none:-1 [inlined]
# [31] call_with_reactant(::Reactant.MustThrowError, ::Type{AbstractArray{…}}, ::ConcretePJRTArray{Float32, 4, 1})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [32] convert
# @ ./abstractarray.jl:17 [inlined]
# [33] (::Nothing)(none::typeof(convert), none::Type{AbstractArray{…}}, none::ConcretePJRTArray{Float32, 4, 1})
# @ Reactant ./<missing>:0
# [34] call_with_reactant
# @ ./none:-1 [inlined]
# [35] call_with_reactant(::Reactant.MustThrowError, ::typeof(convert), ::Type{…}, ::ConcretePJRTArray{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [36] _match_eltype
# @ ~/.julia/packages/Flux/WMUyh/src/layers/stateless.jl:77 [inlined]
# [37] (::Nothing)(none::typeof(Flux._match_eltype), none::Conv{…}, none::Type{…}, none::ConcretePJRTArray{…})
# @ Reactant ./<missing>:0
# [38] call_with_reactant
# @ ./none:-1 [inlined]
# [39] call_with_reactant(::Reactant.MustThrowError, ::typeof(Flux._match_eltype), ::Conv{…}, ::Type{…}, ::ConcretePJRTArray{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [40] _match_eltype
# @ ~/.julia/packages/Flux/WMUyh/src/layers/stateless.jl:85 [inlined]
# [41] (::Nothing)(none::typeof(Flux._match_eltype), none::Conv{…}, none::ConcretePJRTArray{…})
# @ Reactant ./<missing>:0
# [42] call_with_reactant
# @ ./none:-1 [inlined]
# [43] call_with_reactant(::Reactant.MustThrowError, ::typeof(Flux._match_eltype), ::Conv{…}, ::ConcretePJRTArray{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [44] Conv
# @ ~/.julia/packages/Flux/WMUyh/src/layers/conv.jl:200 [inlined]
# [45] (::Nothing)(none::Conv{…}, none::ConcretePJRTArray{…})
# @ Reactant ./<missing>:0
# [46] call_with_reactant
# @ ./none:-1 [inlined]
# [47] call_with_reactant(::Reactant.MustThrowError, ::Conv{…}, ::ConcretePJRTArray{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [48] loss
# @ ~/juliadev/Flux/prova.jl:159 [inlined]
# [49] (::Nothing)(none::typeof(loss), none::Conv{…}, none::ConcretePJRTArray{…})
# @ Reactant ./<missing>:0
# [50] call_with_reactant
# @ ./none:-1 [inlined]
# [51] call_with_reactant(::typeof(loss), ::Conv{…}, ::ConcretePJRTArray{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [52] #126
# @ ~/juliadev/Flux/prova.jl:170 [inlined]
# [53] (::Nothing)(none::var"#126#127", none::Conv{2, 4, typeof(identity), Reactant.TracedRArray{…}, Reactant.TracedRArray{…}})
# @ Reactant ./<missing>:0
# [54] call_with_reactant
# @ ./none:-1 [inlined]
# [55] call_with_reactant(::var"#126#127", ::Conv{2, 4, typeof(identity), Reactant.TracedRArray{…}, Reactant.TracedRArray{…}})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [56] make_mlir_fn(f::var"#126#127", args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Nothing, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
# @ Reactant.TracedUtils ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:348
# [57] make_mlir_fn
# @ ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:277 [inlined]
# [58] overload_autodiff(::ReverseMode{…}, f::Const{…}, ::Type{…}, args::Duplicated{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/Enzyme.jl:315
# [59] autodiff
# @ ~/.julia/packages/Reactant/UjybB/src/Overlay.jl:36 [inlined]
# [60] (::Nothing)(none::typeof(autodiff), none::ReverseMode{…}, none::Const{…}, none::Type{…}, none::Tuple{…})
# @ Reactant ./<missing>:0
# [61] call_with_reactant
# @ ./none:-1 [inlined]
# [62] call_with_reactant(::typeof(autodiff), ::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [63] macro expansion
# @ ~/.julia/packages/Reactant/UjybB/src/utils.jl:314 [inlined]
# [64] applyiterate_with_reactant(::typeof(iterate), ::typeof(autodiff), ::Core.SimpleVector, ::Vector{Any})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:300
# [65] enzyme_withgradient
# @ ~/juliadev/Flux/test/test_utils.jl:32 [inlined]
# [66] (::Nothing)(none::typeof(enzyme_withgradient), none::var"#126#127", none::Tuple{Conv{…}})
# @ Reactant ./<missing>:0
# [67] call_with_reactant
# @ ./none:-1 [inlined]
# [68] call_with_reactant(::typeof(enzyme_withgradient), ::var"#126#127", ::Conv{…})
# @ Reactant ~/.julia/packages/Reactant/UjybB/src/utils.jl:0
# [69] make_mlir_fn(f::typeof(enzyme_withgradient), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
# @ Reactant.TracedUtils ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:348
# [70] make_mlir_fn
# @ ~/.julia/packages/Reactant/UjybB/src/TracedUtils.jl:277 [inlined]
# [71] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(enzyme_withgradient), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
# @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:1679
# [72] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
# @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3630
# [73] compile_xla
# @ ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3602 [inlined]
# [74] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
# @ Reactant.Compiler ~/.julia/packages/Reactant/UjybB/src/Compiler.jl:3706
# Some type information was truncated. Use `show(err)` to see complete types.
|
|
This all looks to have a similar root cause, can the utility function be written to not have the closure? |
|
It can be done, although generally we would like to test closures as well, we have to special case reactant. What about the if branch problem, can it be fixed on the reactant side? Or would it go away removing closures? |
|
the problem in your recent snippet I believe is using a global here you are capuring the global r_x. |
I think I see the issue here, and its the closure being passed into reactant_withgradient, which is accidentally capturing into the fuctor all other variables in scope (including the ones that aren't relevant). Locally this seems to work: using Flux
using Enzyme
# # using CUDA, cuDNN
using Reactant
using MLDataDevices
using Test
using Random, Statistics, LinearAlgebra
using Zygote
using Functors
Reactant.set_default_backend("cpu")
function enzyme_withgradient(f, x...)
args = []
for x in x
if x isa Number
push!(args, Enzyme.Active(x))
else
push!(args, Enzyme.Duplicated(x, Enzyme.make_zero(x)))
end
end
ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)
g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
return ret[2], g
end
reactant_withgradient(f, x...) = Reactant.@jit enzyme_withgradient(f, x...)
reactant_loss(loss, x...) = Reactant.@jit loss(x...)
function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
fmapstructure_with_path(a, b) do kp, x, y
# @show kp
if x isa AbstractArray
@test x ≈ y rtol=rtol atol=atol
elseif x isa Number
@test x ≈ y rtol=rtol atol=atol
end
end
end
function test_reactant_gradients(
f,
xs...;
rtol=1e-4, atol=1e-4,
test_reactant = true,
test_grad_f = true,
test_grad_x = true,
loss = (f, xs...) -> mean(f(xs...)),
)
Flux.trainmode!(f)
l = loss(f, xs...)
@test l isa Number
if test_reactant
reactant_dev = MLDataDevices.reactant_device(force=true)
cpu_dev = cpu_device()
xs_re = xs |> reactant_dev
f_re = f |> reactant_dev
l_re = reactant_loss(loss, f_re, xs_re...)
@test l ≈ l_re rtol=rtol atol=atol
end
if test_grad_x
y, g = Zygote.withgradient((xs...) -> loss(f, xs...), xs...)
if test_reactant
y_re, g_re = reactant_withgradient(Base.Fix1(loss, f_re), xs_re...)
@test y ≈ y_re rtol=rtol atol=atol
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
end
end
if test_grad_f
y, g = Zygote.withgradient(f -> loss(f, xs...), f)
if test_reactant
y_re, g_re = reactant_withgradient(Base.Fix2(loss, xs_re[1]), f_re)
@test y ≈ y_re rtol=rtol atol=atol
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
end
end
return true
end
m, x = Dense(2=>4), randn(Float32, 2)
loss(m, x) = mean(m(x))
r_dev = reactant_device(force=true)
r_m = m |> r_dev
r_x = x |> r_dev
@jit loss(r_m, r_x)
@jit enzyme_withgradient(loss, r_m, r_x)
@jit enzyme_withgradient(r_x -> loss(r_m, r_x), r_x)
@test test_reactant_gradients(m, x; loss) # ERRORso just a fix1/fix2 |
|
this is now subsumed by #2600 and can presumably be closed |
No description provided.