
Preventing StackOverflowError automatically with a @safe_turbo? #430

Closed
MilesCranmer opened this issue Sep 14, 2022 · 6 comments

@MilesCranmer
Contributor

Hey all,

I am eager to switch from @inbounds @simd to @turbo in the evaluation loops of SymbolicRegression.jl, which is the backend for PySR. You can see my initial pull request, MilesCranmer/SymbolicRegression.jl#132, which shows the loops I am trying to speed up.

In my package, the user can pass any binary or unary operator (e.g., +, -, *, /, cos, exp, or any function they define). A genetic algorithm arranges these operators into expressions until it finds a combination that matches a relationship in a dataset. Evaluating each expression needs to be fast, so it is done by a loop over an array annotated with @inbounds @simd. This loop sits behind a function barrier, so the performance is the same as if I had hard-coded each operator the user passed.
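A minimal sketch of the function-barrier pattern described above, assuming each expression node operates on a flat vector. The names apply_unary!, apply_binary! and eval_chain are hypothetical and are not taken from SymbolicRegression.jl:

```julia
# Apply a user-supplied unary operator elementwise. Declaring `op::F` with `F` as a
# type parameter makes Julia compile a separate, specialized method for every
# concrete operator, which is what the function barrier relies on.
function apply_unary!(out::AbstractVector{T}, op::F, x::AbstractVector{T}) where {T,F}
    @inbounds @simd for i in eachindex(out, x)
        out[i] = op(x[i])
    end
    return out
end

# Same idea for binary operators.
function apply_binary!(out::AbstractVector{T}, op::F, x::AbstractVector{T},
                       y::AbstractVector{T}) where {T,F}
    @inbounds @simd for i in eachindex(out, x, y)
        out[i] = op(x[i], y[i])
    end
    return out
end

# Type-unstable caller: `ops` is a Vector{Function}, so the concrete type of each
# `op` is unknown here and the call below is dispatched dynamically, but the loop
# inside the barrier still runs fully specialized code.
function eval_chain(x::Vector{Float64}, ops::Vector{Function})
    out = copy(x)
    for op in ops
        apply_unary!(out, op, out)   # in place: element i only depends on element i
    end
    return out
end
```

For example, with x = [1.0, 2.0, 3.0], apply_binary!(similar(x), +, x, x) returns [2.0, 4.0, 6.0]. The dynamic dispatch cost is paid once per node rather than once per element, which is why the barrier keeps the loop as fast as hard-coded code.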

I tried switching to @turbo; however, I started seeing a StackOverflowError when testing SpecialFunctions.gamma as an operator. Judging from issues #233 and #232, it seems that certain functions raise this error if they do not have a SIMD implementation.

This is a problem for my package, since the user is allowed (and encouraged) to specify any Julia function they wish as an operator, even complex, branching functions (including any from SpecialFunctions), and those operators will be used inside these loops. I understand that each operator would need a SIMD implementation for @turbo to give a speedup, but I would still like to get that speed for more standard operators like +, -, *, /, etc.

I am therefore wondering whether @turbo could be made robust against StackOverflowErrors by falling back to non-SIMD operations. Or could a @safe_turbo macro be implemented that does this? (This assumes I do not know in advance what code the user will run inside each loop.)

Thanks!
Miles

@timholy
Contributor

timholy commented Sep 14, 2022

A likely culprit is the assumption that constructors/conversions return an object of the given type; see JuliaLang/julia#42372.

@chriselrod
Member

Adding a safe keyword argument to @turbo, or a @safe_turbo option, wouldn't be difficult.

If you want to try:

  1. Add it to the kwargs here:
    function check_macro_kwarg(
        arg,
        inline::Bool,
        check_empty::Bool,
        u₁::Int8,
        u₂::Int8,
        v::Int8,
        threads::Int,
        warncheckarg::Int,
    )
        ((arg.head === :(=)) && (length(arg.args) == 2)) ||
            throw(ArgumentError("macro kwarg should be of the form `argname = value`."))
        kw = (arg.args[1])::Symbol
        value = (arg.args[2])
        if kw === :inline
            inline = value::Bool
        elseif kw === :unroll
            if value isa Integer
                u₁ = convert(Int8, value)::Int8
            elseif Meta.isexpr(value, :tuple, 2)
                u₁ = convert(Int8, value.args[1])::Int8
                u₂ = convert(Int8, value.args[2])::Int8
            else
                throw(ArgumentError("Don't know how to process argument in `unroll=$value`."))
            end
        elseif kw === :vectorize
            v = convert(Int8, value)
        elseif kw === :check_empty
            check_empty = value::Bool
        elseif kw === :thread
            if value isa Bool
                threads = Core.ifelse(value::Bool, -1, 1)
            elseif value isa Integer
                threads = max(1, convert(Int, value)::Int)
            else
                throw(ArgumentError("Don't know how to process argument in `thread=$value`."))
            end
        elseif kw === :warn_check_args
            warncheckarg = convert(Int, value)::Int
        else
            throw(
                ArgumentError(
                    "Received unrecognized keyword argument $kw. Recognized arguments include:\n`inline`, `unroll`, `check_empty`, and `thread`.",
                ),
            )
        end
        inline, check_empty, u₁, u₂, v, threads, warncheckarg
    end
    function process_args(
        args;
        inline::Bool = false,
        check_empty::Bool = false,
        u₁::Int8 = zero(Int8),
        u₂::Int8 = zero(Int8),
        v::Int8 = zero(Int8),
        threads::Int = 1,
        warncheckarg::Int = 1,
    )
        for arg ∈ args
            inline, check_empty, u₁, u₂, v, threads, warncheckarg =
                check_macro_kwarg(arg, inline, check_empty, u₁, u₂, v, threads, warncheckarg)
        end
        inline, check_empty, u₁, u₂, v, threads, warncheckarg
    end
    # check whether the loop body is a block and, if not, wrap it in one (issue #395),
    # and check whether the loop range is an `enumerate`; if so, replace it (issue #393)
    function check_inputs!(q, prepreamble)
        if Meta.isexpr(q, :for)
            if !Meta.isexpr(q.args[2], :block)
                q.args[2] = Expr(:block, q.args[2])
                replace_enumerate!(q, prepreamble) # must run after wrapping in a block
            else # the block may contain inner loops
                replace_enumerate!(q, prepreamble)
                for arg in q.args[2].args
                    check_inputs!(arg, prepreamble) # check recursively for inner loops
                end
            end
        end
        return q
    end
    function replace_enumerate!(q, prepreamble)
        looprange = q.args[1]
        if Meta.isexpr(looprange, :block)
            for i = 1:length(looprange.args)
                replace_single_enumerate!(q, prepreamble, i)
            end
        else
            replace_single_enumerate!(q, prepreamble)
        end
        return q
    end
    function replace_single_enumerate!(q, prepreamble, i = nothing)
        if isnothing(i) # not a nested loop
            looprange, body = q.args[1], q.args[2]
        else # nested loop
            looprange, body = q.args[1].args[i], q.args[2]
        end
        @assert Meta.isexpr(looprange, :(=), 2)
        itersyms, r = looprange.args
        if Meta.isexpr(r, :call, 2) && r.args[1] == :enumerate
            _iter = r.args[2]
            if _iter isa Symbol
                iter = _iter
            else # bind a complex expression to a name
                iter = gensym(:iter)
                push!(prepreamble.args, :($iter = $_iter))
            end
            if Meta.isexpr(itersyms, :tuple, 2)
                indsym, varsym = itersyms.args[1]::Symbol, itersyms.args[2]::Symbol
                _replace_looprange!(q, i, indsym, iter)
                pushfirst!(body.args, :($varsym = $iter[$indsym+firstindex($iter)-1]))
            elseif Meta.isexpr(itersyms, :tuple, 1) # like `for (i,) in enumerate(...)`
                indsym = itersyms.args[1]::Symbol
                _replace_looprange!(q, i, indsym, iter)
            elseif itersyms isa Symbol # itersyms are not destructured in the loop range
                throw(ArgumentError("`for $itersyms in enumerate($r)` is not supported,
                    please use `for ($(itersyms)_i, $(itersyms)_v) in enumerate($r)` instead."))
            else
                throw(ArgumentError("Don't know how to handle expression `$itersyms`."))
            end
        end
        return q
    end
    _replace_looprange!(q, ::Nothing, indsym, iter) =
        q.args[1] = :($indsym = Base.OneTo(length($iter)))
    _replace_looprange!(q, i::Int, indsym, iter) =
        q.args[1].args[i] = :($indsym = Base.OneTo(length($iter)))
    function turbo_macro(mod, src, q, args...)
        q = macroexpand(mod, q)
        if q.head === :for
            ls = LoopSet(q, mod)
            inline, check_empty, u₁, u₂, v, threads, warncheckarg = process_args(args)
            esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg))
        else
            inline, check_empty, u₁, u₂, v, threads, warncheckarg =
                process_args(args, inline = true)
            substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, v, threads, warncheckarg)
        end
    end
    end

    so that process_args recognizes it and sets it to true when present.
  2. Forward the option to setup_call.
  3. Have setup_call forward it to check_args_call:
    pushprepreamble!(ls, Expr(:if, check_args_call(ls), call, argfailure))
    Alternatively, write a separate function and, when safe=true, combine its result with && in the if condition there.
  4. The new function/check should iterate over the instructions and chain && over ArrayInterface.can_avx for these functions. Basically, iterate over op in operations(ls), skip non-compute operations with iscompute(op) || continue, and then do something like:
    c = callexpr(op.instruction)
    pushfirst!(c.args, ArrayInterface.can_avx)

Chain (intersect) all of these in the if condition; then, if any of these can_avx calls returns false, the fallback @inbounds @fastmath loop runs instead of the @turbo loop. For reference, here is what can_avx returns for a few functions:

julia> using ArrayInterface, LoopVectorization, SpecialFunctions

julia> ArrayInterface.can_avx(+)
true

julia> ArrayInterface.can_avx(exp)
true

julia> ArrayInterface.can_avx(gamma)
false

julia> ArrayInterface.can_avx(beta)
false
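The steps above can be sketched on Base Julia alone. This is not LoopVectorization code: process_safe, simd_check, guard and user_op are hypothetical names, can_simd is a stand-in for ArrayInterface.can_avx, and the two branch bodies are placeholder strings where the real macro would splice in the @turbo loop and the @inbounds @fastmath fallback.

```julia
# Hypothetical stand-in for ArrayInterface.can_avx so the sketch runs on Base alone:
# only operators explicitly marked here count as SIMD-friendly.
can_simd(::Any) = false
can_simd(::typeof(+)) = true
can_simd(::typeof(*)) = true
can_simd(::typeof(exp)) = true

# Step 1 analogue (hypothetical): recognize a `safe = true/false` macro kwarg.
function process_safe(args...)
    safe = false
    for arg in args
        Meta.isexpr(arg, :(=), 2) ||
            throw(ArgumentError("macro kwarg should be of the form `argname = value`."))
        kw, value = arg.args
        kw === :safe || throw(ArgumentError("this sketch only recognizes `safe`, got `$kw`."))
        safe = value::Bool
    end
    return safe
end

# Step 4 analogue: build `can_simd(op1) && can_simd(op2) && ...` from the
# operators used by the loop's compute instructions.
function simd_check(ops::Vector{Symbol})
    isempty(ops) && return true
    return foldl((a, b) -> Expr(:&&, a, b), [Expr(:call, can_simd, op) for op in ops])
end

# Step 3 analogue: when `safe` is set, wrap the fast loop in a runtime `if` that
# selects the fallback whenever any check fails; otherwise return `fast` unchanged.
function guard(safe::Bool, ops::Vector{Symbol}, fast, fallback)
    safe || return fast
    return Expr(:if, simd_check(ops), fast, fallback)
end

user_op(x) = x < 0 ? -x : x   # hypothetical user operator, not marked SIMD-friendly
```

With safe = process_safe(:(safe = true)), evaluating guard(safe, [:+, :exp], "turbo loop", "fallback loop") yields "turbo loop", while [:+, :user_op] yields "fallback loop" because can_simd(user_op) hits the ::Any method. In a real macro the condition would be spliced into the generated code rather than passed to eval.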

@chriselrod
Member

Some special functions are also written in Julia, so you may be able to get SIMD-compatible versions of them.
You can look at gcd if you want an idea of how to handle while loops, for example: https://github.com/JuliaSIMD/VectorizationBase.jl/blob/46bce4794e71bf06364729e61bffbadc7f48a666/src/special/misc.jl#L160-L184
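The general idea for while loops can be illustrated with Euclid's algorithm run over several lanes at once. This is a sketch, not VectorizationBase's implementation: an NTuple stands in for a SIMD vector, and lanewise_gcd is a hypothetical name. Because SIMD lanes cannot branch independently, the loop keeps going until every lane has finished, and branch-free ifelse selects leave finished lanes unchanged.

```julia
# Euclid's algorithm on N lanes at once. A scalar `while b != 0` loop would need a
# per-lane branch, so instead we iterate while *any* lane is active and use
# `ifelse` to freeze lanes that have already converged.
function lanewise_gcd(a::NTuple{N,Int}, b::NTuple{N,Int}) where {N}
    a = map(abs, a)
    b = map(abs, b)
    while any(!iszero, b)
        active = map(!iszero, b)                             # lanes still iterating
        safe_b = map((m, y) -> ifelse(m, y, 1), active, b)   # avoid rem by zero in finished lanes
        r = map(rem, a, safe_b)
        a = map(ifelse, active, b, a)   # active lanes: (a, b) -> (b, a rem b)
        b = map(ifelse, active, r, b)   # finished lanes keep b == 0
    end
    return a
end
```

For example, lanewise_gcd((12, 17, 0, -8), (18, 5, 7, 12)) returns (6, 1, 7, 4), the same as map(gcd, ...) on those tuples. The cost of this design is that every lane runs as many iterations as the slowest one.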

@chriselrod
Member

Perhaps we should do this automatically instead, with Base.promote_op(f, somesimdtypes...) !== Union{}
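A sketch of what such an automatic check could look like, assuming a toy Lanes type in place of a real SIMD vector type (Lanes, vec_friendly, scalar_only and supports_lanes are all hypothetical). Base.promote_op is an internal, inference-based helper, so its answer can vary across Julia versions, and a result of Any (inference gave up) would pass the check. Whether this flags every case that currently ends in a StackOverflowError has not been verified here.

```julia
# Hypothetical stand-in for a SIMD vector type.
struct Lanes
    x::NTuple{4,Float64}
end

# An operator that has been extended to the vector type...
vec_friendly(v::Float64) = 2v
vec_friendly(v::Lanes) = Lanes(map(vec_friendly, v.x))

# ...and one with only a scalar method, like a special function without a SIMD port.
scalar_only(v::Float64) = v^2 + 1

# `Union{}` means inference proved the call cannot return normally
# (here: no method matches `Lanes`), so the SIMD path should be skipped.
supports_lanes(f) = Base.promote_op(f, Lanes) !== Union{}
```

Here supports_lanes(vec_friendly) is true and supports_lanes(scalar_only) is false, so a macro could use this result in place of a hand-maintained can_avx table.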

@MilesCranmer
Contributor Author

MilesCranmer commented Sep 18, 2022

Thanks, adding a kwarg and using can_avx sounds like a good option! Will try it out.

@MilesCranmer
Contributor Author

Made a PR over in #431

This issue was closed.