Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle subtractions and use MacroTools #7

Merged
merged 2 commits into from
Jul 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,24 @@
[![codecov.io](http://codecov.io/github/JuliaDiffEq/MuladdMacro.jl/coverage.svg?branch=master)](http://codecov.io/github/JuliaDiffEq/MuladdMacro.jl?branch=master)

This package provides the `@muladd` macro. It automatically converts expressions
with multiplications and additions to calls with `muladd` which then fuse via
with multiplications and additions or subtractions to calls with `muladd` which then fuse via
FMA when it would increase the performance of the code. The `@muladd` macro
can be placed on code blocks and it will automatically find the appropriate
expressions and nest muladd expressions when necessary. In mixed expressions summands without multiplication will be grouped together and evaluated first but otherwise the order of evaluation of multiplications and additions is not changed.

## Examples

```julia
julia> macroexpand(:(@muladd k3 = f(t + c3*dt, @. uprev+dt*(a031*k1+a032*k2))))
:(k3 = f((muladd)(c3, dt, t), (muladd).(dt, (muladd).(a032, k2, *.(a031, k1)), uprev)))
julia> macroexpand(:(@muladd integrator.EEst = integrator.opts.internalnorm((update - dt*(bhat1*k1 + bhat4*k4 + bhat5*k5 + bhat6*k6 + bhat7*k7 + bhat10*k10))./ @. (integrator.opts.abstol+max(abs(uprev),abs(u))*integrator.opts.reltol))))
:(integrator.EEst = integrator.opts.internalnorm((update - dt * (muladd)(bhat10, k10, (muladd)(bhat7, k7, (muladd)(bhat6, k6, (muladd)(bhat5, k5, (muladd)(bhat4, k4, bhat1 * k1)))))) ./ (muladd).(max.(abs.(uprev), abs.(u)), integrator.opts.reltol, integrator.opts.abstol)))
julia> @macroexpand(@muladd k3 = f(t + c3*dt, @. uprev+dt*(a031*k1+a032*k2)))
:(k3 = f((muladd)(c3, dt, t), (muladd).(dt, (muladd).(a032, k2, (*).(a031, k1)), uprev)))

julia> @macroexpand(@muladd integrator.EEst = integrator.opts.internalnorm((update - dt*(bhat1*k1 + bhat4*k4 + bhat5*k5 + bhat6*k6 + bhat7*k7 + bhat10*k10))./ @. (integrator.opts.abstol+max(abs(uprev),abs(u))*integrator.opts.reltol)))
:(integrator.EEst = (integrator.opts).internalnorm((muladd)(-dt, (muladd)(bhat10, k10, (muladd)(bhat7, k7, (muladd)(bhat6, k6, (muladd)(bhat5, k5, (muladd)(bhat4, k4, bhat1 * k1))))), update) ./ (muladd).(max.(abs.(uprev), abs.(u)), (integrator.opts).reltol, (integrator.opts).abstol)))
```

## Broadcasting

A `muladd` call will be broadcasted if both the `*` and the `+` are broadcasted.
A `muladd` call will be broadcasted if both the `*` and the `+` or `-` are broadcasted.
If either one is not broadcasted, then the expression will be converted to a
non-dotted `muladd`.

Expand Down
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
julia 0.7-beta
MacroTools
242 changes: 144 additions & 98 deletions src/MuladdMacro.jl
Original file line number Diff line number Diff line change
@@ -1,138 +1,184 @@
module MuladdMacro

using MacroTools: postwalk, @capture

"""
@muladd ex
@muladd

Convert every combination of addition/subtraction and multiplication to a call of `muladd`.

Convert every combined multiplication and addition in `ex` into a call of `muladd`. If both
of the involved operators are dotted, `muladd` is applied as a "dot call".
If both of the involved operators are dotted, `muladd` is applied as a dot call.
The order of summation might be changed.
"""
macro muladd(ex)
esc(to_muladd(ex))
esc(to_muladd(ex))
end

function to_muladd(ex::Expr)
if !isaddition(ex)
if ex.head == :macrocall && length(ex.args)>=2 && ex.args[1] == Symbol("@__dot__")
# expand @. macros first (enables use of @. inside of @muladd expression)
return to_muladd(Base.Broadcast.__dot__(last(ex.args)))
else
# if expression is no sum apply the reduction to its arguments
return Expr(ex.head, to_muladd.(ex.args)...)
end
"""
to_muladd(ex)

Convert every combination of addition/subtraction and multiplication in expression `ex` to a call of `muladd`.

If both of the involved operators are dotted, `muladd` is applied as a dot call
The order of summation might be changed.
"""
function to_muladd(ex)
postwalk(ex) do x
# Modify summations
issum(x) && return sum_to_muladd(x)

# Modify subtractions
issub(x) && return sub_to_muladd(x)

return x
end
end

"""
sum_to_muladd(ex)

Replace sum `ex` by sequence of `muladd` if possible. Hereby the order of summation might be changed.
"""
function sum_to_muladd(ex)
# Retrieve summands
summands = args(ex)

# retrieve summands of addition and split them into two groups, one with expressions
# Skip further calculations if no summand is a multiplication
dotcall = isdotcall(ex)
any(x -> ismul(x, dotcall), summands) || return ex

# Split summands into two groups, one with expressions
# of multiplications and one with other expressions
# if addition is a dot call multiplications must be dot calls as well; if the addition
# is a regular operation only regular multiplications are filtered
all_operands = to_muladd.(operands(ex))
if isdotcall(ex)
mul_operands = filter(x->isdotcall(x, :*), all_operands)
odd_operands = filter(x->!isdotcall(x, :*), all_operands)
else
mul_operands = filter(x->isoperation(x, :*), all_operands)
odd_operands = filter(x->!isoperation(x, :*), all_operands)
end
mulsummands = filter(x -> ismul(x, dotcall), summands)
oddsummands = filter(x -> !ismul(x, dotcall), summands)

# define summands that are reduced with muladd and the initial element of the reduction
if isempty(odd_operands)
# if all summands are multiplications one of these summands is
# the initial element of the reduction and evaluated first
first_operation = mul_operands[1]
to_be_muladded = mul_operands[2:end]
else
to_be_muladded = mul_operands

# expressions that are no multiplications are summed up in a separate expression
# that is the initial element of the reduction and evaluated first
# if the original addition was a dot call this expression also is a dot call
if length(odd_operands) == 1
first_operation = odd_operands[1]
elseif isdotcall(ex)
# make sure returned expression has same style as original expression
if ex.head == :.
first_operation = Expr(:., :+, Expr(:tuple, odd_operands...))
else
first_operation = Expr(:call, :.+, odd_operands...)
end
else
first_operation = Expr(:call, :+, odd_operands...)
end
# If all summands are multiplications the first one is not reduced
isempty(oddsummands) && push!(oddsummands, popfirst!(mulsummands))

# Reduce sum to a composition of muladd
foldl(mulsummands; init = newargs(ex, oddsummands...)) do s₁, s₂
newmuladd(splitargs(s₂)..., s₁, dotcall)
end
end

# reduce sum to a composition of muladd
foldl(to_be_muladded; init=first_operation) do last_expr, next_expr
# retrieve factors of multiplication that will be reduced next
next_operands = operands(next_expr)

# second factor is always last operand
next_factor2 = next_operands[end]

# first factor is an expression of a multiplication if there are more than
# two operands
# if the original multiplication was a dot call this expression also is a dot call
if length(next_operands) == 2
next_factor1 = next_operands[1]
elseif isdotcall(next_expr)
next_factor1 = Expr(:., :*, Expr(:tuple, next_operands[1:end-1]...))
else
next_factor1 = Expr(:call, :*, next_operands[1:end-1]...)
end

# create a dot call if both involved operators are dot calls
if isdotcall(ex)
Expr(:., Base.muladd, Expr(:tuple, next_factor1, next_factor2, last_expr))
else
Expr(:call, Base.muladd, next_factor1, next_factor2, last_expr)
end
"""
sub_to_muladd(ex)

Replace subtraction `ex` by `muladd` if possible.
"""
function sub_to_muladd(ex)
# Retrieve operands
x, y = args(ex)

# Modify subtraction if possible
dotcall = isdotcall(ex)
if ismul(y, dotcall)
y₁, y₂ = splitargs(y)
return newmuladd(:(-$y₁), y₂, x, dotcall)
elseif ismul(x, dotcall)
return newmuladd(splitargs(x)..., :(-$y), dotcall)
end

return ex
end
to_muladd(ex) = ex

"""
isoperation(ex, op::Symbol)
isdotcall(ex)

Determine whether `ex` is a call of operation `op`.
Determine whether `ex` is a dot call.
"""
isoperation(ex::Expr, op::Symbol) =
ex.head == :call && !isempty(ex.args) && ex.args[1] == op
isoperation(ex, op::Symbol) = false
isdotcall(ex) =
(@capture(ex, (f_)(__)) && startswith(string(f), '.')) ||
@capture(ex, (_).(__))

"""
isdotcall(ex[, op])
issum(ex)

Determine whether `ex` is a dot call and, in case `op` is specified, whether it calls
operator `op`.
Determine whether `ex` is a sum.
"""
isdotcall(ex::Expr) = !isempty(ex.args) &&
(ex.head == :. ||
(ex.head == :call && !isempty(ex.args) && first(string(ex.args[1])) == '.'))
isdotcall(ex) = false
issum(ex) = @capture(ex, +(__) | .+(__) | (+).(__))

isdotcall(ex::Expr, op::Symbol) = isdotcall(ex) &&
(ex.args[1] == op || ex.args[1] == Symbol('.', op))
isdotcall(ex, op::Symbol) = false
"""
issub(ex)

Determine whether `ex` is a subtraction.
"""
isaddition(ex)
issub(ex) = @capture(ex, -(_, _) | .-(_, _) | (-).(_, _))

Determine whether `ex` is an expression of an addition.
"""
isaddition(ex) = isoperation(ex, :+) || isdotcall(ex, :+)
ismul(ex, dot::Bool)

Determine whether expression `ex` is a multiplication that is dotted if
`dot` is `true` and not dotted otherwise.
"""
operands(ex)
function ismul(ex, dot::Bool)
if dot
return @capture(ex, .*(__) | (*).(__))
else
return @capture(ex, *(__))
end
end

Return arguments of function call in `ex`.
"""
function operands(ex::Expr)
if ex.head == :. && length(ex.args) == 2 && typeof(ex.args[2]) <: Expr
ex.args[2].args
newmuladd(x, y, z, dot::Bool)

Return expression `(muladd).(x, y, z)` if `dot` is `true` and
`muladd(x, y, z)` otherwise.
"""
function newmuladd(x, y, z, dot::Bool)
# Quoting seems to be required for the @. macro to work
if dot
:(($(Meta.quot(Base.muladd))).($x, $y, $z))
else
ex.args[2:end]
:($(Meta.quot(Base.muladd))($x, $y, $z))
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why quoting is needed here and whether that's the correct way to handle the application of @. macros. But it was the only way I could get it to work besides expanding @. macros immediately which in my opinion should not be required (or at least feels a bit suboptimal).

end

"""
args(ex)

Return arguments of function call in `ex`.
"""
function args(ex)
@capture(ex, (_)(x__) | (_).(x__)) || error("expression is not a function call")

return x
end

"""
splitargs(ex)

Split arguments of function call `ex` before last argument and combine first
arguments to one expression if possible.
"""
function splitargs(ex)
@capture(ex, (_)(x__, y_) | (_).(x__, y_)) || error("cannot split arguments")

return newargs(ex, x...), y
end

"""
newargs(ex, args)

Create new expression of function call `ex` with arguments `args`.

Unary function calls are not considered, i.e. if only one function
argument is provided it is returned.
"""
function newargs(ex, args...)
# Return single argument
length(args) == 1 && return args[1]

# Create function calls with new arguments
if @capture(ex, (f_)(__))
return :($f($(args...)))
elseif @capture(ex, (f_).(__))
return :(($f).($(args...)))
end

error("expression is not a function call")
end

export @muladd

end # module
Loading