-
-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Handle subtractions and use MacroTools #7
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
julia 0.7-beta | ||
MacroTools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,138 +1,184 @@ | ||
module MuladdMacro | ||
|
||
using MacroTools: postwalk, @capture | ||
|
||
""" | ||
@muladd ex | ||
@muladd | ||
|
||
Convert every combination of addition/subtraction and multiplication to a call of `muladd`. | ||
|
||
Convert every combined multiplication and addition in `ex` into a call of `muladd`. If both | ||
of the involved operators are dotted, `muladd` is applied as a "dot call". | ||
If both of the involved operators are dotted, `muladd` is applied as a dot call. | ||
The order of summation might be changed. | ||
""" | ||
macro muladd(ex) | ||
esc(to_muladd(ex)) | ||
esc(to_muladd(ex)) | ||
end | ||
|
||
function to_muladd(ex::Expr) | ||
if !isaddition(ex) | ||
if ex.head == :macrocall && length(ex.args)>=2 && ex.args[1] == Symbol("@__dot__") | ||
# expand @. macros first (enables use of @. inside of @muladd expression) | ||
return to_muladd(Base.Broadcast.__dot__(last(ex.args))) | ||
else | ||
# if expression is no sum apply the reduction to its arguments | ||
return Expr(ex.head, to_muladd.(ex.args)...) | ||
end | ||
""" | ||
to_muladd(ex) | ||
|
||
Convert every combination of addition/subtraction and multiplication in expression `ex` to a call of `muladd`. | ||
|
||
If both of the involved operators are dotted, `muladd` is applied as a dot call | ||
The order of summation might be changed. | ||
""" | ||
function to_muladd(ex) | ||
postwalk(ex) do x | ||
# Modify summations | ||
issum(x) && return sum_to_muladd(x) | ||
|
||
# Modify subtractions | ||
issub(x) && return sub_to_muladd(x) | ||
|
||
return x | ||
end | ||
end | ||
|
||
""" | ||
sum_to_muladd(ex) | ||
|
||
Replace sum `ex` by sequence of `muladd` if possible. Hereby the order of summation might be changed. | ||
""" | ||
function sum_to_muladd(ex) | ||
# Retrieve summands | ||
summands = args(ex) | ||
|
||
# retrieve summands of addition and split them into two groups, one with expressions | ||
# Skip further calculations if no summand is a multiplication | ||
dotcall = isdotcall(ex) | ||
any(x -> ismul(x, dotcall), summands) || return ex | ||
|
||
# Split summands into two groups, one with expressions | ||
# of multiplications and one with other expressions | ||
# if addition is a dot call multiplications must be dot calls as well; if the addition | ||
# is a regular operation only regular multiplications are filtered | ||
all_operands = to_muladd.(operands(ex)) | ||
if isdotcall(ex) | ||
mul_operands = filter(x->isdotcall(x, :*), all_operands) | ||
odd_operands = filter(x->!isdotcall(x, :*), all_operands) | ||
else | ||
mul_operands = filter(x->isoperation(x, :*), all_operands) | ||
odd_operands = filter(x->!isoperation(x, :*), all_operands) | ||
end | ||
mulsummands = filter(x -> ismul(x, dotcall), summands) | ||
oddsummands = filter(x -> !ismul(x, dotcall), summands) | ||
|
||
# define summands that are reduced with muladd and the initial element of the reduction | ||
if isempty(odd_operands) | ||
# if all summands are multiplications one of these summands is | ||
# the initial element of the reduction and evaluated first | ||
first_operation = mul_operands[1] | ||
to_be_muladded = mul_operands[2:end] | ||
else | ||
to_be_muladded = mul_operands | ||
|
||
# expressions that are no multiplications are summed up in a separate expression | ||
# that is the initial element of the reduction and evaluated first | ||
# if the original addition was a dot call this expression also is a dot call | ||
if length(odd_operands) == 1 | ||
first_operation = odd_operands[1] | ||
elseif isdotcall(ex) | ||
# make sure returned expression has same style as original expression | ||
if ex.head == :. | ||
first_operation = Expr(:., :+, Expr(:tuple, odd_operands...)) | ||
else | ||
first_operation = Expr(:call, :.+, odd_operands...) | ||
end | ||
else | ||
first_operation = Expr(:call, :+, odd_operands...) | ||
end | ||
# If all summands are multiplications the first one is not reduced | ||
isempty(oddsummands) && push!(oddsummands, popfirst!(mulsummands)) | ||
|
||
# Reduce sum to a composition of muladd | ||
foldl(mulsummands; init = newargs(ex, oddsummands...)) do s₁, s₂ | ||
newmuladd(splitargs(s₂)..., s₁, dotcall) | ||
end | ||
end | ||
|
||
# reduce sum to a composition of muladd | ||
foldl(to_be_muladded; init=first_operation) do last_expr, next_expr | ||
# retrieve factors of multiplication that will be reduced next | ||
next_operands = operands(next_expr) | ||
|
||
# second factor is always last operand | ||
next_factor2 = next_operands[end] | ||
|
||
# first factor is an expression of a multiplication if there are more than | ||
# two operands | ||
# if the original multiplication was a dot call this expression also is a dot call | ||
if length(next_operands) == 2 | ||
next_factor1 = next_operands[1] | ||
elseif isdotcall(next_expr) | ||
next_factor1 = Expr(:., :*, Expr(:tuple, next_operands[1:end-1]...)) | ||
else | ||
next_factor1 = Expr(:call, :*, next_operands[1:end-1]...) | ||
end | ||
|
||
# create a dot call if both involved operators are dot calls | ||
if isdotcall(ex) | ||
Expr(:., Base.muladd, Expr(:tuple, next_factor1, next_factor2, last_expr)) | ||
else | ||
Expr(:call, Base.muladd, next_factor1, next_factor2, last_expr) | ||
end | ||
""" | ||
sub_to_muladd(ex) | ||
|
||
Replace subtraction `ex` by `muladd` if possible. | ||
""" | ||
function sub_to_muladd(ex) | ||
# Retrieve operands | ||
x, y = args(ex) | ||
|
||
# Modify subtraction if possible | ||
dotcall = isdotcall(ex) | ||
if ismul(y, dotcall) | ||
y₁, y₂ = splitargs(y) | ||
return newmuladd(:(-$y₁), y₂, x, dotcall) | ||
elseif ismul(x, dotcall) | ||
return newmuladd(splitargs(x)..., :(-$y), dotcall) | ||
end | ||
|
||
return ex | ||
end | ||
to_muladd(ex) = ex | ||
|
||
""" | ||
isoperation(ex, op::Symbol) | ||
isdotcall(ex) | ||
|
||
Determine whether `ex` is a call of operation `op`. | ||
Determine whether `ex` is a dot call. | ||
""" | ||
isoperation(ex::Expr, op::Symbol) = | ||
ex.head == :call && !isempty(ex.args) && ex.args[1] == op | ||
isoperation(ex, op::Symbol) = false | ||
isdotcall(ex) = | ||
(@capture(ex, (f_)(__)) && startswith(string(f), '.')) || | ||
@capture(ex, (_).(__)) | ||
|
||
""" | ||
isdotcall(ex[, op]) | ||
issum(ex) | ||
|
||
Determine whether `ex` is a dot call and, in case `op` is specified, whether it calls | ||
operator `op`. | ||
Determine whether `ex` is a sum. | ||
""" | ||
isdotcall(ex::Expr) = !isempty(ex.args) && | ||
(ex.head == :. || | ||
(ex.head == :call && !isempty(ex.args) && first(string(ex.args[1])) == '.')) | ||
isdotcall(ex) = false | ||
issum(ex) = @capture(ex, +(__) | .+(__) | (+).(__)) | ||
|
||
isdotcall(ex::Expr, op::Symbol) = isdotcall(ex) && | ||
(ex.args[1] == op || ex.args[1] == Symbol('.', op)) | ||
isdotcall(ex, op::Symbol) = false | ||
""" | ||
issub(ex) | ||
|
||
Determine whether `ex` is a subtraction. | ||
""" | ||
isaddition(ex) | ||
issub(ex) = @capture(ex, -(_, _) | .-(_, _) | (-).(_, _)) | ||
|
||
Determine whether `ex` is an expression of an addition. | ||
""" | ||
isaddition(ex) = isoperation(ex, :+) || isdotcall(ex, :+) | ||
ismul(ex, dot::Bool) | ||
|
||
Determine whether expression `ex` is a multiplication that is dotted if | ||
`dot` is `true` and not dotted otherwise. | ||
""" | ||
operands(ex) | ||
function ismul(ex, dot::Bool) | ||
if dot | ||
return @capture(ex, .*(__) | (*).(__)) | ||
else | ||
return @capture(ex, *(__)) | ||
end | ||
end | ||
|
||
Return arguments of function call in `ex`. | ||
""" | ||
function operands(ex::Expr) | ||
if ex.head == :. && length(ex.args) == 2 && typeof(ex.args[2]) <: Expr | ||
ex.args[2].args | ||
newmuladd(x, y, z, dot::Bool) | ||
|
||
Return expression `(muladd).(x, y, z)` if `dot` is `true` and | ||
`muladd(x, y, z)` otherwise. | ||
""" | ||
function newmuladd(x, y, z, dot::Bool) | ||
# Quoting seems to be required for the @. macro to work | ||
if dot | ||
:(($(Meta.quot(Base.muladd))).($x, $y, $z)) | ||
else | ||
ex.args[2:end] | ||
:($(Meta.quot(Base.muladd))($x, $y, $z)) | ||
end | ||
end | ||
|
||
""" | ||
args(ex) | ||
|
||
Return arguments of function call in `ex`. | ||
""" | ||
function args(ex) | ||
@capture(ex, (_)(x__) | (_).(x__)) || error("expression is not a function call") | ||
|
||
return x | ||
end | ||
|
||
""" | ||
splitargs(ex) | ||
|
||
Split arguments of function call `ex` before last argument and combine first | ||
arguments to one expression if possible. | ||
""" | ||
function splitargs(ex) | ||
@capture(ex, (_)(x__, y_) | (_).(x__, y_)) || error("cannot split arguments") | ||
|
||
return newargs(ex, x...), y | ||
end | ||
|
||
""" | ||
newargs(ex, args) | ||
|
||
Create new expression of function call `ex` with arguments `args`. | ||
|
||
Unary function calls are not considered, i.e. if only one function | ||
argument is provided it is returned. | ||
""" | ||
function newargs(ex, args...) | ||
# Return single argument | ||
length(args) == 1 && return args[1] | ||
|
||
# Create function calls with new arguments | ||
if @capture(ex, (f_)(__)) | ||
return :($f($(args...))) | ||
elseif @capture(ex, (f_).(__)) | ||
return :(($f).($(args...))) | ||
end | ||
|
||
error("expression is not a function call") | ||
end | ||
|
||
export @muladd | ||
|
||
end # module |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure why quoting is needed here and whether that's the correct way to handle the application of
@.
macros. But it was the only way I could get it to work besides expanding@.
macros immediately which in my opinion should not be required (or at least feels a bit suboptimal).