Skip to content

Commit

Permalink
simplify thanks to Compat.jl, and julia 1.3+
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Oct 24, 2020
1 parent 18deacd commit b51cc49
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Compat = "3.13"
Compat = "3.14"
Requires = "0.5, 1.0"
julia = "1.3"

Expand Down
31 changes: 6 additions & 25 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,28 +141,18 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST

@eval function batched_mul_generic!(C::AbstractArray{T, 3}, A::$TA, B::$TB,
α::Number=one(T), β::Number=zero(T)) where {T}

size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C"))
size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C"))
@debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C)

Abase, Bbase = _unbatch(A), _unbatch(B)
sA, oA = size(A,3) == 1 ? (0,1) : (1,0)
sB, oB = size(B,3) == 1 ? (0,1) : (1,0)

if VERSION >= v"1.3"
@inbounds for k in 1:size(C,3)
@views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), α, β)
end
elseif α==1 && β==0
@inbounds for k in 1:size(C,3)
@views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]))
end
else
@debug "since there is no 5-arg mul!, calling C1 .= α .* (A1 * B1) .+ β .* C" α β
@inbounds for k in 1:size(C,3)
@views C[:,:,k] .= α .* $fA(Abase[:,:,k*sA+oA]) * $fB(Bbase[:,:,k*sB+oB]) .+ β .* C[:,:,k]
end
@inbounds for k in 1:size(C,3)
@views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), α, β)
end

C
end

Expand Down Expand Up @@ -216,9 +206,6 @@ strided-ness, and hence also return `is_strided(parent(A))`.
This correctly handles things like `NamedDimsArray` wihch don't alter indexing.
However, it's a little pessimistic in that e.g. a `view` of such a container will return
`false`, even in cases where the same `view` of `parent(A)` would be a `StridedArray`.
`A::Transpose` doesn't currently define `strides`, until that's fixed this returns `false`.
The PR to fix that only defines `strides(::Adjoint{T})` for `T<:Real`, so this will follow.
"""
is_strided(A::StridedArray) = true
is_strided(A) = false
Expand All @@ -237,13 +224,7 @@ end
is_strided(A::BatchedAdjoint) = eltype(A) <: Real && is_strided(parent(A))
is_strided(A::BatchedTranspose) = is_strided(parent(A))

if hasmethod(Base.strides, Tuple{LinearAlgebra.Transpose})
# https://github.com/JuliaLang/julia/pull/29135
is_strided(A::LinearAlgebra.Transpose) = is_strided(parent(A))
is_strided(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A))
else
is_strided(A::LinearAlgebra.Transpose) = false
is_strided(A::LinearAlgebra.Adjoint) = false
end
is_strided(A::LinearAlgebra.Transpose) = is_strided(parent(A))
is_strided(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A))

are_strided(As...) = mapfoldl(is_strided, &, As; init=true)

0 comments on commit b51cc49

Please sign in to comment.