Skip to content

Commit

Permalink
Fix #9
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Jul 22, 2018
1 parent d877284 commit 2ebae3e
Showing 1 changed file with 50 additions and 20 deletions.
70 changes: 50 additions & 20 deletions src/MuladdMacro.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__precompile__(true)
module MuladdMacro

using MacroTools: postwalk, @capture
using MacroTools: postwalk

"""
@muladd
Expand Down Expand Up @@ -68,7 +69,9 @@ Replace subtraction `ex` by `muladd` if possible.
"""
function sub_to_muladd(ex)
# Retrieve operands
x, y = args(ex)
_x = args(ex)
length(_x) == 1 && return ex
x, y = _x

# Modify subtraction if possible
dotcall = isdotcall(ex)
Expand All @@ -87,23 +90,33 @@ end
Determine whether `ex` is a dot call.
"""
isdotcall(ex) =
(@capture(ex, (f_)(__)) && startswith(string(f), '.')) ||
@capture(ex, (_).(__))
isdotcall(ex) = typeof(ex) <: Expr ? ex.head == :. || string(ex.args[1])[1] == '.' : false

"""
iscall(ex, op)
Determine whether `ex` is a call to `op`.
"""
function iscall(ex, op)
!(typeof(ex) <: Expr) && return false
length(ex.args) == 0 && return false
_op = ex.args[1]
return _op == op || _op == Symbol(:., op)
end

"""
issum(ex)
Determine whether `ex` is a sum.
"""
issum(ex) = @capture(ex, +(__) | .+(__) | (+).(__))
issum(ex) = iscall(ex, :+)

"""
issub(ex)
Determine whether `ex` is a subtraction.
"""
issub(ex) = @capture(ex, -(_, _) | .-(_, _) | (-).(_, _))
issub(ex) = iscall(ex, :-)

"""
ismul(ex, dot::Bool)
Expand All @@ -112,10 +125,13 @@ Determine whether expression `ex` is a multiplication that is dotted if
`dot` is `true` and not dotted otherwise.
"""
function ismul(ex, dot::Bool)
!(typeof(ex) <: Expr) && return false
length(ex.args) == 0 && return false
if dot
return @capture(ex, .*(__) | (*).(__))
return typeof(ex) <: Expr &&
( ex.args[1] == :.* || (ex.head == :. && ex.args[1] == :*) )
else
return @capture(ex, *(__))
return typeof(ex) <: Expr && ex.args[1] == :*
end
end

Expand All @@ -140,9 +156,17 @@ end
Return arguments of function call in `ex`.
"""
function args(ex)
@capture(ex, (_)(x__) | (_).(x__)) || error("expression is not a function call")

return x
if typeof(ex) <: Expr
if ex.head == :call
x = ex.args[2:end]
return x
end
if ex.head == :.
x = ex.args[2].args[1:end]
return x
end
end
error("expression is not a function call")
end

"""
Expand All @@ -152,9 +176,17 @@ Split arguments of function call `ex` before last argument and combine first
arguments to one expression if possible.
"""
function splitargs(ex)
@capture(ex, (_)(x__, y_) | (_).(x__, y_)) || error("cannot split arguments")

return newargs(ex, x...), y
if ex.head == :call
x = ex.args[2:end-1]
y = ex.args[end]
return newargs(ex, x...), y
end
if ex.head == :.
x = ex.args[2].args[1:end-1]
y = ex.args[2].args[end]
return newargs(ex, x...), y
end
error("cannot split arguments")
end

"""
Expand All @@ -168,13 +200,11 @@ argument is provided it is returned.
function newargs(ex, args...)
# Return single argument
length(args) == 1 && return args[1]
f = ex.args[1]

# Create function calls with new arguments
if @capture(ex, (f_)(__))
return :($f($(args...)))
elseif @capture(ex, (f_).(__))
return :(($f).($(args...)))
end
ex.head == :. && return :(($f).($(args...)))
ex.head == :call && return :(($f)($(args...)))

error("expression is not a function call")
end
Expand Down

0 comments on commit 2ebae3e

Please sign in to comment.