Skip to content

Commit

Permalink
Merge pull request #11 from JuliaDiffEq/if
Browse files Browse the repository at this point in the history
Better fix for #9
  • Loading branch information
YingboMa authored Jul 24, 2018
2 parents d877284 + 17a7b3e commit 1edd7f4
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 22 deletions.
73 changes: 51 additions & 22 deletions src/MuladdMacro.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module MuladdMacro

using MacroTools: postwalk, @capture
using MacroTools: postwalk

"""
@muladd
Expand Down Expand Up @@ -82,42 +82,56 @@ function sub_to_muladd(ex)
return ex
end

"""
iscall(ex, op)
Determine whether `ex` is a call of operation `op` with at least two arguments.
"""
iscall(ex::Expr, op) =
ex.head == :call && length(ex.args) > 2 && ex.args[1] == op
iscall(ex, op) = false

"""
isdotcall(ex)
Determine whether `ex` is a dot call.
"""
isdotcall(ex) =
(@capture(ex, (f_)(__)) && startswith(string(f), '.')) ||
@capture(ex, (_).(__))
isdotcall(ex::Expr) =
(ex.head == :. && length(ex.args) == 2 && Meta.isexpr(ex.args[2], :tuple)) ||
(ex.head == :call && !isempty(ex.args) && startswith(string(ex.args[1]), '.'))
isdotcall(ex) = false

"""
isdotcall(ex, op)
Determine whether `ex` is a dot call of operation `op` with at least two arguments.
"""
isdotcall(ex::Expr, op) =
(ex.head == :. && length(ex.args) == 2 && ex.args[1] == op && Meta.isexpr(ex.args[2], :tuple) && length(ex.args[2].args) > 1) ||
(ex.head == :call && length(ex.args) > 2 && ex.args[1] == Symbol('.', op))
isdotcall(ex, op) = false

"""
issum(ex)
Determine whether `ex` is a sum.
"""
issum(ex) = @capture(ex, +(__) | .+(__) | (+).(__))
issum(ex) = iscall(ex, :+) || isdotcall(ex, :+)

"""
issub(ex)
Determine whether `ex` is a subtraction.
"""
issub(ex) = @capture(ex, -(_, _) | .-(_, _) | (-).(_, _))
issub(ex) = iscall(ex, :-) || isdotcall(ex, :-)

"""
ismul(ex, dot::Bool)
Determine whether expression `ex` is a multiplication that is dotted if
`dot` is `true` and not dotted otherwise.
"""
function ismul(ex, dot::Bool)
if dot
return @capture(ex, .*(__) | (*).(__))
else
return @capture(ex, *(__))
end
end
ismul(ex, dot::Bool) = dot ? isdotcall(ex, :*) : iscall(ex, :*)

"""
newmuladd(x, y, z, dot::Bool)
Expand All @@ -139,10 +153,16 @@ end
Return arguments of function call in `ex`.
"""
function args(ex)
@capture(ex, (_)(x__) | (_).(x__)) || error("expression is not a function call")
function args(ex::Expr)
if ex.head == :call && length(ex.args) > 1
return ex.args[2:end]
end

return x
if ex.head == :. && length(ex.args) == 2 && Meta.isexpr(ex.args[2], :tuple) && !isempty(ex.args[2].args)
return ex.args[2].args
end

error("expression is not a function call with arguments")
end

"""
Expand All @@ -152,7 +172,15 @@ Split arguments of function call `ex` before last argument and combine first
arguments to one expression if possible.
"""
function splitargs(ex)
@capture(ex, (_)(x__, y_) | (_).(x__, y_)) || error("cannot split arguments")
if ex.head == :call && length(ex.args) > 2
x = ex.args[2:end-1]
y = ex.args[end]
elseif ex.head == :. && length(ex.args) == 2 && Meta.isexpr(ex.args[2], :tuple) && length(ex.args[2].args) > 1
x = ex.args[2].args[1:end-1]
y = ex.args[2].args[end]
else
error("cannot split arguments")
end

return newargs(ex, x...), y
end
Expand All @@ -165,15 +193,16 @@ Create new expression of function call `ex` with arguments `args`.
Unary function calls are not considered, i.e. if only one function
argument is provided it is returned.
"""
function newargs(ex, args...)
function newargs(ex::Expr, args...)
# Return single argument
length(args) == 1 && return args[1]

# Create function calls with new arguments
if @capture(ex, (f_)(__))
return :($f($(args...)))
elseif @capture(ex, (f_).(__))
return :(($f).($(args...)))
if ex.head == :call && !isempty(ex.args)
return Expr(:call, ex.args[1], args...)
end
if ex.head == :. && length(ex.args) == 2 && Meta.isexpr(ex.args[2], :tuple)
return Expr(:., ex.args[1], Expr(:tuple, args...))
end

error("expression is not a function call")
Expand Down
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@ end
:($(Base.muladd)(e, f, $(Base.muladd)(c, d, a*b)))
@test @macroexpand(@muladd a*(b*c+d)+e) ==
:($(Base.muladd)(a, $(Base.muladd)(b, c, d), e))

@test @macroexpand(@muladd +a) == :(+a)
end

@testset "Subtraction" begin
@test @macroexpand(@muladd a*b-c*d) == :($(Base.muladd)(-c, d, a*b))
@test @macroexpand(@muladd a*(b*c-d)-e) ==
:($(Base.muladd)(a, $(Base.muladd)(b, c, -d), -e))

@test @macroexpand(@muladd -a) == :(-a)
end
end

Expand All @@ -64,6 +68,8 @@ end
@test @macroexpand(@muladd f.(a)*b+c) == :($(Base.muladd)(f.(a), b, c))
@test @macroexpand(@muladd a*f.(b)+c) == :($(Base.muladd)(a, f.(b), c))
@test @macroexpand(@muladd a*b+f.(c)) == :($(Base.muladd)(a, b, f.(c)))

@test @macroexpand(@muladd .+a) == :(.+a)
end

@testset "Subtraction" begin
Expand All @@ -76,6 +82,8 @@ end
@test @macroexpand(@muladd @. a-b*c) == :($(Base.muladd).((-).(b), c, a))
@test @macroexpand(@muladd a-b.*c) == :(a-b.*c)
@test @macroexpand(@muladd a.-b*c) == :(a.-b*c)

@test @macroexpand(@muladd .-a) == :(.-a)
end
end

Expand Down

0 comments on commit 1edd7f4

Please sign in to comment.