diff --git a/src/MuladdMacro.jl b/src/MuladdMacro.jl index 41f0d0b..908c158 100644 --- a/src/MuladdMacro.jl +++ b/src/MuladdMacro.jl @@ -1,6 +1,6 @@ module MuladdMacro -using MacroTools: postwalk, @capture +using MacroTools: postwalk """ @muladd @@ -82,28 +82,48 @@ function sub_to_muladd(ex) return ex end +""" + iscall(ex, op) + +Determine whether `ex` is a call of operation `op` with at least two arguments. +""" +iscall(ex::Expr, op) = + ex.head == :call && length(ex.args) > 2 && ex.args[1] == op +iscall(ex, op) = false + """ isdotcall(ex) Determine whether `ex` is a dot call. """ -isdotcall(ex) = - (@capture(ex, (f_)(__)) && startswith(string(f), '.')) || - @capture(ex, (_).(__)) +isdotcall(ex::Expr) = + (ex.head == :. && length(ex.args) == 2 && Meta.isexpr(ex.args[2], :tuple)) || + (ex.head == :call && !isempty(ex.args) && startswith(string(ex.args[1]), '.')) +isdotcall(ex) = false + +""" + isdotcall(ex, op) + +Determine whether `ex` is a dot call of operation `op` with at least two arguments. +""" +isdotcall(ex::Expr, op) = + (ex.head == :. && length(ex.args) == 2 && ex.args[1] == op && Meta.isexpr(ex.args[2], :tuple) && length(ex.args[2].args) > 1) || + (ex.head == :call && length(ex.args) > 2 && ex.args[1] == Symbol('.', op)) +isdotcall(ex, op) = false """ issum(ex) Determine whether `ex` is a sum. """ -issum(ex) = @capture(ex, +(__) | .+(__) | (+).(__)) +issum(ex) = iscall(ex, :+) || isdotcall(ex, :+) """ issub(ex) Determine whether `ex` is a subtraction. """ -issub(ex) = @capture(ex, -(_, _) | .-(_, _) | (-).(_, _)) +issub(ex) = iscall(ex, :-) || isdotcall(ex, :-) """ ismul(ex, dot::Bool) @@ -111,13 +131,7 @@ issub(ex) = @capture(ex, -(_, _) | .-(_, _) | (-).(_, _)) Determine whether expression `ex` is a multiplication that is dotted if `dot` is `true` and not dotted otherwise. """ -function ismul(ex, dot::Bool) - if dot - return @capture(ex, .*(__) | (*).(__)) - else - return @capture(ex, *(__)) - end -end +ismul(ex, dot::Bool) = dot ? isdotcall(ex, :*) : iscall(ex, :*) """ newmuladd(x, y, z, dot::Bool) @@ -139,10 +153,16 @@ end Return arguments of function call in `ex`. """ -function args(ex) - @capture(ex, (_)(x__) | (_).(x__)) || error("expression is not a function call") +function args(ex::Expr) + if ex.head == :call && length(ex.args) > 1 + return ex.args[2:end] + end - return x + if ex.head == :. && length(ex.args) == 2 && Meta.isexpr(ex.args[2], :tuple) && !isempty(ex.args[2].args) + return ex.args[2].args + end + + error("expression is not a function call with arguments") end """ @@ -152,7 +172,15 @@ Split arguments of function call `ex` before last argument and combine first arguments to one expression if possible. """ function splitargs(ex) - @capture(ex, (_)(x__, y_) | (_).(x__, y_)) || error("cannot split arguments") + if ex.head == :call && length(ex.args) > 2 + x = ex.args[2:end-1] + y = ex.args[end] + elseif ex.head == :. && length(ex.args) == 2 && Meta.isexpr(ex.args[2], :tuple) && length(ex.args[2].args) > 1 + x = ex.args[2].args[1:end-1] + y = ex.args[2].args[end] + else + error("cannot split arguments") + end return newargs(ex, x...), y end @@ -165,15 +193,16 @@ Create new expression of function call `ex` with arguments `args`. Unary function calls are not considered, i.e. if only one function argument is provided it is returned. """ -function newargs(ex, args...) +function newargs(ex::Expr, args...) # Return single argument length(args) == 1 && return args[1] # Create function calls with new arguments - if @capture(ex, (f_)(__)) - return :($f($(args...))) - elseif @capture(ex, (f_).(__)) - return :(($f).($(args...))) + if ex.head == :call && !isempty(ex.args) + return Expr(:call, ex.args[1], args...) + end + if ex.head == :. && length(ex.args) == 2 && Meta.isexpr(ex.args[2], :tuple) + return Expr(:., ex.args[1], Expr(:tuple, args...)) end error("expression is not a function call") diff --git a/test/runtests.jl b/test/runtests.jl index d1a4544..d7551b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,12 +40,16 @@ end :($(Base.muladd)(e, f, $(Base.muladd)(c, d, a*b))) @test @macroexpand(@muladd a*(b*c+d)+e) == :($(Base.muladd)(a, $(Base.muladd)(b, c, d), e)) + + @test @macroexpand(@muladd +a) == :(+a) end @testset "Subtraction" begin @test @macroexpand(@muladd a*b-c*d) == :($(Base.muladd)(-c, d, a*b)) @test @macroexpand(@muladd a*(b*c-d)-e) == :($(Base.muladd)(a, $(Base.muladd)(b, c, -d), -e)) + + @test @macroexpand(@muladd -a) == :(-a) end end @@ -64,6 +68,8 @@ end @test @macroexpand(@muladd f.(a)*b+c) == :($(Base.muladd)(f.(a), b, c)) @test @macroexpand(@muladd a*f.(b)+c) == :($(Base.muladd)(a, f.(b), c)) @test @macroexpand(@muladd a*b+f.(c)) == :($(Base.muladd)(a, b, f.(c))) + + @test @macroexpand(@muladd .+a) == :(.+a) end @testset "Subtraction" begin @@ -76,6 +82,8 @@ end @test @macroexpand(@muladd @. a-b*c) == :($(Base.muladd).((-).(b), c, a)) @test @macroexpand(@muladd a-b.*c) == :(a-b.*c) @test @macroexpand(@muladd a.-b*c) == :(a.-b*c) + + @test @macroexpand(@muladd .-a) == :(.-a) end end