diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 641b57421..c6c1bf2de 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -19,8 +19,8 @@ and similarly `batched_adjoint`. Other permutations are also handled by BLAS, provided that the batch index `k` is not the first dimension of the underlying array. Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine. -However `PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS, -and will thus be copied as this is faster than the fallback method `batched_mul_generic!`. +However `A = PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS, +since `stride(A,3) == 1`. This be copied, as doing so is faster than `batched_mul_generic!`. Both this `copy` and `batched_mul_generic!` produce `@debug` messages, and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them. @@ -74,8 +74,9 @@ for `X ∈ [A,B,C]`, either `strides(X,1)==1` or `strides(X,2)==1`, the latter m be caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, (3,1,2))`. Unlike `batched_mul` this will never make a copy. -For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen, -and in this case `stride(A::BatchedAdjoint,2) == 1` is not optional. +For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen. +In this case the strided accepted by BLAS are more restricted, if `stride(C,1)==1` then +only `stride(AorB::BatchedAdjoint,2) == 1` is accepted. """ function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3}, α::Number=one(T), β::Number=zero(T)) where {T} @@ -97,8 +98,8 @@ function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where { if Base.stride(C,1) == 1 elseif Base.stride(C,2) == 1 - @debug "transposing C = A * B into Cᵀ = Bᵀ * Aᵀ" size(C) strides(C) - return batched_mul!(batched_transpose(C), batched_transpose(B), batched_transpose(A), α, β) + @debug "transforming C = A * B into C' = B' * A'" size(C) strides(C) + return batched_mul!(batched_adjoint(C), batched_adjoint(B), batched_adjoint(A), α, β) else return batched_mul_generic!(C, A, B, α, β) end