Skip to content

Commit

Permalink
simplify contract pass
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Jun 21, 2023
1 parent 2e06010 commit d03c179
Showing 1 changed file with 21 additions and 30 deletions.
51 changes: 21 additions & 30 deletions src/vectorizationbase_compat/contract_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,32 @@ function mulexprcost(@nospecialize(x::ProdArg))::Int
return base + length(ex.args)
end
end
function mul_fast_expr(args::SubArray{Any, 1, Vector{Any}, Tuple{UnitRange{Int64}}, true})::Expr
function mul_fast_expr(
args::SubArray{Any,1,Vector{Any},Tuple{UnitRange{Int64}},true}
)::Expr
b = Expr(:call, :mul_fast)
for i 2:length(args)
push!(b.args, args[i])
end
b
end
function mulexpr(mulexargs::SubArray{Any, 1, Vector{Any}, Tuple{UnitRange{Int64}}, true})::Tuple{ProdArg,ProdArg}
function mulexpr(
mulexargs::SubArray{Any,1,Vector{Any},Tuple{UnitRange{Int64}},true}
)::Tuple{ProdArg,ProdArg}
a = (mulexargs[1])::ProdArg
if length(mulexargs) == 2
return (a, mulexargs[2]::ProdArg)
elseif length(mulexargs) == 3
# We'll calc the product between the guesstimated cheaper two args first, for better out of order execution
b = (mulexargs[2])::ProdArg
c = (mulexargs[3])::ProdArg
ac = mulexprcost(a)
bc = mulexprcost(b)
cc = mulexprcost(c)
maxc = max(ac, bc, cc)
if ac == maxc
return (a, Expr(:call, :mul_fast, b, c))
elseif bc == maxc
return (b, Expr(:call, :mul_fast, a, c))
else
return (c, Expr(:call, :mul_fast, a, b))
end
else
return (a, mul_fast_expr(mulexargs))
end
a = (mulexargs[1])::Union{Symbol,Expr,Number}
b = if length(mulexargs) == 2 # two arg mul
(mulexargs[2])::Union{Symbol,Expr,Number}
else
mul_fast_expr(mulexargs)
end
a, b
Nexpr = length(mulexargs)
Nexpr == 2 && (a, mulexargs[2]::ProdArg)
Nexpr != 3 && (a, mul_fast_expr(mulexargs))
# We'll calc the product between the guesstimated cheaper two args first, for better out of order execution
b = (mulexargs[2])::ProdArg
c = (mulexargs[3])::ProdArg
ac = mulexprcost(a)
bc = mulexprcost(b)
cc = mulexprcost(c)
maxc = max(ac, bc, cc)
ac == maxc && return (a, Expr(:call, :mul_fast, b, c))
bc == maxc && return (b, Expr(:call, :mul_fast, c, a))
return (c, Expr(:call, :mul_fast, a, b))
end
function append_args_skip!(call, args, i, mod)
for j eachindex(args)
Expand Down Expand Up @@ -228,7 +218,8 @@ function capture_a_muladd(ex::Expr, mod)
end
true, call
end
capture_muladd(ex::Expr, mod) = while true
capture_muladd(ex::Expr, mod) =
while true
ex.head === :ref && return ex
if Meta.isexpr(ex, :call, 2)
if (ex.args[1] === :(-))
Expand Down

0 comments on commit d03c179

Please sign in to comment.