use adjoint not transpose, + doc tweaks
Michael Abbott committed Oct 24, 2020
1 parent c8a1fee commit 676d166
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/batched/batchedmul.jl
@@ -19,8 +19,8 @@ and similarly `batched_adjoint`. Other permutations are also handled by BLAS,
provided that the batch index `k` is not the first dimension of the underlying array.
Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine.
-However `PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS,
-and will thus be copied as this is faster than the fallback method `batched_mul_generic!`.
+However `A = PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS,
+since `stride(A,3) == 1`. This will be copied, as doing so is faster than `batched_mul_generic!`.
Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
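The stride condition described in this hunk can be checked with Base Julia alone. The following is a minimal sketch, not code from the commit: the array `A` and the helper `batch_stride` are hypothetical names introduced for illustration, and NNlib is not loaded.

```julia
# Base-only sketch; `A` and `batch_stride` are illustrative names, not part of NNlib.
A = reshape(collect(1.0:24.0), 2, 3, 4)   # dense Array with strides (1, 2, 6)

# Stride of the third (batch) dimension after a lazy permutation of `A`.
batch_stride(perm) = stride(PermutedDimsArray(A, perm), 3)

for perm in [(1, 3, 2), (3, 1, 2), (3, 2, 1)]
    P = PermutedDimsArray(A, perm)
    println(perm, " => strides ", strides(P), ", batch stride ", batch_stride(perm))
end
```

Only the `(3,2,1)` permutation places the unit stride on the batch dimension, which is the case the docstring says is copied rather than passed to BLAS.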
@@ -74,8 +74,9 @@ for `X ∈ [A,B,C]`, either `strides(X,1)==1` or `strides(X,2)==1`, the latter m
be caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, (3,1,2))`.
Unlike `batched_mul` this will never make a copy.
-For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen,
-and in this case `stride(A::BatchedAdjoint,2) == 1` is not optional.
+For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen.
+In this case the strides accepted by BLAS are more restricted: if `stride(C,1)==1` then
+only `stride(AorB::BatchedAdjoint,2) == 1` is accepted.
"""
function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3},
α::Number=one(T), β::Number=zero(T)) where {T}
@@ -97,8 +98,8 @@ function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {

     if Base.stride(C,1) == 1
     elseif Base.stride(C,2) == 1
-        @debug "transposing C = A * B into Cᵀ = Bᵀ * Aᵀ" size(C) strides(C)
-        return batched_mul!(batched_transpose(C), batched_transpose(B), batched_transpose(A), α, β)
+        @debug "transforming C = A * B into C' = B' * A'" size(C) strides(C)
+        return batched_mul!(batched_adjoint(C), batched_adjoint(B), batched_adjoint(A), α, β)
     else
         return batched_mul_generic!(C, A, B, α, β)
     end
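The rewritten branch relies on the identity `(A * B)' == B' * A'` holding for every batch slice, including complex ones. The sketch below checks that identity with plain matrices. It assumes nothing from NNlib: `naive_batched_mul` is a hypothetical reference loop written for this illustration, not the library's implementation.

```julia
using LinearAlgebra

# Hypothetical reference loop: multiply matching slices one batch at a time.
function naive_batched_mul(A::AbstractArray{T,3}, B::AbstractArray{T,3}) where {T}
    C = similar(A, size(A, 1), size(B, 2), size(A, 3))
    for k in axes(A, 3)
        C[:, :, k] = A[:, :, k] * B[:, :, k]
    end
    return C
end

# Small complex integer inputs so that the comparison below is exact.
A = reshape(Complex{Int}[1+2im, 0, 3, 1im, 2, -1im, 1, 1+1im], 2, 2, 2)
B = reshape(Complex{Int}[2, 1im, -1, 1-1im, 0, 3, 2im, 1], 2, 2, 2)
C = naive_batched_mul(A, B)

# For each batch index k, the adjoint of the product equals the reversed product of adjoints.
adjoint_ok = all(C[:, :, k]' == B[:, :, k]' * A[:, :, k]' for k in axes(C, 3))
println("adjoint identity holds for every slice: ", adjoint_ok)
```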