Skip to content

Commit

Permalink
combine broadcast statements (#457)
Browse files Browse the repository at this point in the history
* combine broadcast statements

* fix bc and add rrules
  • Loading branch information
chriselrod committed Jan 5, 2023
1 parent 8ba69ae commit a73a797
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LoopVectorization"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
authors = ["Chris Elrod <[email protected]>"]
version = "0.12.145"
version = "0.12.146"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
79 changes: 29 additions & 50 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -525,15 +525,15 @@ function add_broadcast_loops!(ls::LoopSet, loopsyms::Vector{Symbol}, destsym::Sy
end
end

# size of dest determines loops
# function vmaterialize!(
@generated function vmaterialize!(
dest::AbstractArray{T,N},
bc::BC,
::Val{Mod},
::Val{UNROLL},
::Val{dontbc},
) where {T<:NativeTypes,N,BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc}
function vmaterialize_fun(
sizeofT::Int,
N,
@nospecialize(_::Type{BC}),
Mod,
UNROLL,
dontbc,
transpose::Bool,
) where {BC}
# 2 + 1
# we have an N dimensional loop.
# need to construct the LoopSet
Expand All @@ -542,17 +542,20 @@ end
set_hw!(ls, rs, rc, cls)
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
loopsyms = [gensym!(ls, "n") for _ 1:N]
add_broadcast_loops!(ls, loopsyms, :dest)
elementbytes = sizeof(T)
transpose && pushprepreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
ret = transpose ? :dest′ : :dest
add_broadcast_loops!(ls, loopsyms, ret)
elementbytes = sizeofT
add_broadcast!(ls, :destination, :bc, loopsyms, BC, dontbc, elementbytes)
transpose && reverse!(loopsyms)
storeop =
add_simple_store!(ls, :destination, ArrayReference(:dest, loopsyms), elementbytes)
doaddref!(ls, storeop)
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
# return ls
sc = setup_call(
ls,
:(Base.Broadcast.materialize!(dest, bc)),
:(Base.Broadcast.materialize!($ret, bc)),
LineNumberNode(0),
inline,
false,
Expand All @@ -563,7 +566,19 @@ end
warncheckarg,
safe,
)
Expr(:block, Expr(:meta, :inline), sc, :dest)
Expr(:block, Expr(:meta, :inline), sc, ret)
end

# size of dest determines loops
# function vmaterialize!(
@generated function vmaterialize!(
dest::AbstractArray{T,N},
bc::BC,
::Val{Mod},
::Val{UNROLL},
::Val{dontbc},
) where {T<:NativeTypes,N,BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc}
vmaterialize_fun(sizeof(T), N, BC, Mod, UNROLL, dontbc, false)
end
@generated function vmaterialize!(
dest′::Union{Adjoint{T,A},Transpose{T,A}},
Expand All @@ -580,43 +595,7 @@ end
UNROLL,
dontbc,
}
# we have an N dimensional loop.
# need to construct the LoopSet
ls = LoopSet(Mod)
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL
set_hw!(ls, rs, rc, cls)
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
loopsyms = [gensym!(ls, "n") for _ 1:N]
pushprepreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
add_broadcast_loops!(ls, loopsyms, :dest′)
elementbytes = sizeof(T)
add_broadcast!(ls, :destination, :bc, loopsyms, BC, dontbc, elementbytes)
storeop = add_simple_store!(
ls,
:destination,
ArrayReference(:dest, reverse(loopsyms)),
elementbytes,
)
doaddref!(ls, storeop)
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
Expr(
:block,
Expr(:meta, :inline),
setup_call(
ls,
:(Base.Broadcast.materialize!(dest′, bc)),
LineNumberNode(0),
inline,
false,
u₁,
u₂,
v,
threads % Int,
warncheckarg,
safe,
),
:dest′,
)
vmaterialize_fun(sizeof(T), N, BC, Mod, UNROLL, dontbc, true)
end
# these are marked `@inline` so the `@turbo` itself can choose whether or not to inline.
@generated function vmaterialize!(
Expand Down
5 changes: 5 additions & 0 deletions src/simdfunctionals/vmap_grad_rrule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,8 @@ function ChainRulesCore.rrule(::typeof(vmap), f::F, args::Vararg{Any,K}) where {
∂vmap_singlethread!(f, jacs, out, args...)
out, SIMDMapBack(jacs)
end
for f in (:vmapt, :vmapnt, :vmapntt)
@eval function ChainRulesCore.rrule(::typeof($f), f::F, args::Vararg{Any,K}) where {F,K}
ChainRulesCore.rrule(typeof(vmap), f, args...)
end
end

2 comments on commit a73a797

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/75214

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.146 -m "<description of version>" a73a7979263c6f2853a703211b81f482184a32b2
git push origin v0.12.146

Please sign in to comment.