Skip to content

Commit

Permalink
Merge pull request #4 from devmotion/eval_order
Browse files Browse the repository at this point in the history
Correct order of evaluation
  • Loading branch information
devmotion authored Aug 2, 2017
2 parents 748243b + cc0402b commit bd28aba
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 30 deletions.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@ This package provides the `@muladd` macro. It automatically converts expressions
with multiplications and additions to calls with `muladd` which then fuse via
FMA when it would increase the performance of the code. The `@muladd` macro
can be placed on code blocks and it will automatically find the appropriate
expressions and nest muladd expressions when necessary.
expressions and nest muladd expressions when necessary. In mixed expressions summands without multiplication will be grouped together and evaluated first but otherwise the order of evaluation of multiplications and additions is not changed.

## Examples

```julia
julia> macroexpand(:(@muladd k3 = f(t + c3*dt, @. uprev+dt*(a031*k1+a032*k2))))
:(k3 = f((muladd)(c3, dt, t), (muladd).(dt, (muladd).(a031, k1, *.(a032, k2)), uprev)))

:(k3 = f((muladd)(c3, dt, t), (muladd).(dt, (muladd).(a032, k2, *.(a031, k1)), uprev)))
julia> macroexpand(:(@muladd integrator.EEst = integrator.opts.internalnorm((update - dt*(bhat1*k1 + bhat4*k4 + bhat5*k5 + bhat6*k6 + bhat7*k7 + bhat10*k10))./ @. (integrator.opts.abstol+max(abs(uprev),abs(u))*integrator.opts.reltol))))
:(integrator.EEst = integrator.opts.internalnorm((update - dt * (muladd)(bhat1, k1, (muladd)(bhat4, k4, (muladd)(bhat5, k5, (muladd)(bhat6, k6, (muladd)(bhat7, k7, bhat10 * k10)))))) ./ (muladd).(max.(abs.(uprev), abs.(u)), integrator.opts.reltol, integrator.opts.abstol)))
:(integrator.EEst = integrator.opts.internalnorm((update - dt * (muladd)(bhat10, k10, (muladd)(bhat7, k7, (muladd)(bhat6, k6, (muladd)(bhat5, k5, (muladd)(bhat4, k4, bhat1 * k1)))))) ./ (muladd).(max.(abs.(uprev), abs.(u)), integrator.opts.reltol, integrator.opts.abstol)))
```

## Broadcasting
Expand Down
40 changes: 20 additions & 20 deletions src/MuladdMacro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,53 +37,53 @@ function to_muladd(ex::Expr)
# define summands that are reduced with muladd and the initial element of the reduction
if isempty(odd_operands)
# if all summands are multiplications one of these summands is
# the initial element of the reduction
to_be_muladded = mul_operands[1:end-1]
last_operation = mul_operands[end]
# the initial element of the reduction and evaluated first
first_operation = mul_operands[1]
to_be_muladded = mul_operands[2:end]
else
to_be_muladded = mul_operands

# expressions that are no multiplications are summed up in a separate expression
# that is the initial element of the reduction
# that is the initial element of the reduction and evaluated first
# if the original addition was a dot call this expression also is a dot call
if length(odd_operands) == 1
last_operation = odd_operands[1]
first_operation = odd_operands[1]
elseif isdotcall(ex)
# make sure returned expression has same style as original expression
if ex.head == :.
last_operation = Expr(:., :+, Expr(:tuple, odd_operands...))
first_operation = Expr(:., :+, Expr(:tuple, odd_operands...))
else
last_operation = Expr(:call, :.+, odd_operands...)
first_operation = Expr(:call, :.+, odd_operands...)
end
else
last_operation = Expr(:call, :+, odd_operands...)
first_operation = Expr(:call, :+, odd_operands...)
end
end

# reduce sum to a composition of muladd
foldr(last_operation, to_be_muladded) do xs, r
foldl(first_operation, to_be_muladded) do last_expr, next_expr
# retrieve factors of multiplication that will be reduced next
xs_operands = operands(xs)
next_operands = operands(next_expr)

# first factor is always first operand
xs_factor1 = xs_operands[1]
# second factor is always last operand
next_factor2 = next_operands[end]

# second factor is an expression of a multiplication if there are more than
# first factor is an expression of a multiplication if there are more than
# two operands
# if the original multiplication was a dot call this expression also is a dot call
if length(xs_operands) == 2
xs_factor2 = xs_operands[2]
elseif isdotcall(xs)
xs_factor2 = Expr(:., :*, Expr(:tuple, xs_operands[2:end]...))
if length(next_operands) == 2
next_factor1 = next_operands[1]
elseif isdotcall(next_expr)
next_factor1 = Expr(:., :*, Expr(:tuple, next_operands[1:end-1]...))
else
xs_factor2 = Expr(:call, :*, xs_operands[2:end]...)
next_factor1 = Expr(:call, :*, next_operands[1:end-1]...)
end

# create a dot call if both involved operators are dot calls
if isdotcall(ex)
Expr(:., Base.muladd, Expr(:tuple, xs_factor1, xs_factor2, r))
Expr(:., Base.muladd, Expr(:tuple, next_factor1, next_factor2, last_expr))
else
Expr(:call, Base.muladd, xs_factor1, xs_factor2, r)
Expr(:call, Base.muladd, next_factor1, next_factor2, last_expr)
end
end
end
Expand Down
12 changes: 6 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ end

# Additional factors
@testset "Additional factors" begin
@test @macroexpand(@muladd a*b*c+d) == :($(Base.muladd)(a, b*c, d))
@test @macroexpand(@muladd a*b*c*d+e) == :($(Base.muladd)(a, b*c*d, e))
@test @macroexpand(@muladd a*b*c+d) == :($(Base.muladd)(a*b, c, d))
@test @macroexpand(@muladd a*b*c*d+e) == :($(Base.muladd)(a*b*c, d, e))
end

# Multiple multiplications
@testset "Multiple multiplications" begin
@test @macroexpand(@muladd a*b+c*d) == :($(Base.muladd)(a, b, c*d))
@test @macroexpand(@muladd a*b+c*d+e*f) == :($(Base.muladd)(a, b,
$(Base.muladd)(c, d, e*f)))
@test @macroexpand(@muladd a*b+c*d) == :($(Base.muladd)(c, d, a*b))
@test @macroexpand(@muladd a*b+c*d+e*f) == :($(Base.muladd)(e, f,
$(Base.muladd)(c, d, a*b)))
@test @macroexpand(@muladd a*(b*c+d)+e) == :($(Base.muladd)(a,
$(Base.muladd)(b, c, d), e))
end
Expand All @@ -32,7 +32,7 @@ end
@test @macroexpand(@muladd a*b.+c) == :(a*b.+c)
@test @macroexpand(@muladd .+(a.*b, c, d)) == :($(Base.muladd).(a, b, c.+d))
@test @macroexpand(@muladd @. a*b+c+d) == :($(Base.muladd).(a, b, (+).(c, d)))
@test @macroexpand(@muladd @. a*b*c+d) == :($(Base.muladd).(a, (*).(b, c), d))
@test @macroexpand(@muladd @. a*b*c+d) == :($(Base.muladd).((*).(a, b), c, d))
@test @macroexpand(@muladd f.(a)*b+c) == :($(Base.muladd)(f.(a), b, c))
@test @macroexpand(@muladd a*f.(b)+c) == :($(Base.muladd)(a, f.(b), c))
@test @macroexpand(@muladd a*b+f.(c)) == :($(Base.muladd)(a, b, f.(c)))
Expand Down

0 comments on commit bd28aba

Please sign in to comment.