@@ -364,28 +364,32 @@ lmul!(A, B)
364364#  aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly
365365Base. @constprop  :aggressive  function  generic_matmatmul! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
366366                                    _add:: MulAddMul = MulAddMul ()) where  {T<: BlasFloat }
367-     if  all (in ((' N'  , ' T'  , ' C'  )), (tA, tB))
368-         if  tA ==  ' T'   &&  tB ==  ' N'   &&  A ===  B
367+     #  We convert the chars to uppercase to potentially unwrap a WrapperChar,
368+     #  and extract the char corresponding to the wrapper type
369+     tA_uc, tB_uc =  uppercase (tA), uppercase (tB)
370+     #  the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
371+     if  all (map (in ((' N'  , ' T'  , ' C'  )), (tA_uc, tB_uc)))
372+         if  tA_uc ==  ' T'   &&  tB_uc ==  ' N'   &&  A ===  B
369373            return  syrk_wrapper! (C, ' T'  , A, _add)
370-         elseif  tA  ==  ' N'   &&  tB  ==  ' T'   &&  A ===  B
374+         elseif  tA_uc  ==  ' N'   &&  tB_uc  ==  ' T'   &&  A ===  B
371375            return  syrk_wrapper! (C, ' N'  , A, _add)
372-         elseif  tA  ==  ' C'   &&  tB  ==  ' N'   &&  A ===  B
376+         elseif  tA_uc  ==  ' C'   &&  tB_uc  ==  ' N'   &&  A ===  B
373377            return  herk_wrapper! (C, ' C'  , A, _add)
374-         elseif  tA  ==  ' N'   &&  tB  ==  ' C'   &&  A ===  B
378+         elseif  tA_uc  ==  ' N'   &&  tB_uc  ==  ' C'   &&  A ===  B
375379            return  herk_wrapper! (C, ' N'  , A, _add)
376380        else 
377381            return  gemm_wrapper! (C, tA, tB, A, B, _add)
378382        end 
379383    end 
380384    alpha, beta =  promote (_add. alpha, _add. beta, zero (T))
381385    if  alpha isa  Union{Bool,T} &&  beta isa  Union{Bool,T}
382-         if  (tA  ==  ' S'   ||  tA  ==   ' s ' )  &&  tB  ==  ' N' 
386+         if  tA_uc  ==  ' S'   &&  tB_uc  ==  ' N' 
383387            return  BLAS. symm! (' L'  , tA ==  ' S'   ?  ' U'   :  ' L'  , alpha, A, B, beta, C)
384-         elseif  (tB  ==  ' S '   ||  tB  ==   ' s ' )  &&  tA  ==  ' N ' 
388+         elseif  tA_uc  ==  ' N '   &&  tB_uc  ==  ' S ' 
385389            return  BLAS. symm! (' R'  , tB ==  ' S'   ?  ' U'   :  ' L'  , alpha, B, A, beta, C)
386-         elseif  (tA  ==  ' H'   ||  tA  ==   ' h ' )  &&  tB  ==  ' N' 
390+         elseif  tA_uc  ==  ' H'   &&  tB_uc  ==  ' N' 
387391            return  BLAS. hemm! (' L'  , tA ==  ' H'   ?  ' U'   :  ' L'  , alpha, A, B, beta, C)
388-         elseif  (tB  ==  ' H '   ||  tB  ==   ' h ' )  &&  tA  ==  ' N ' 
392+         elseif  tA_uc  ==  ' N '   &&  tB_uc  ==  ' H ' 
389393            return  BLAS. hemm! (' R'  , tB ==  ' H'   ?  ' U'   :  ' L'  , alpha, B, A, beta, C)
390394        end 
391395    end 
395399#  Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
396400Base. @constprop  :aggressive  function  generic_matmatmul! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
397401                    _add:: MulAddMul = MulAddMul ()) where  {T<: BlasReal }
398-     if  all (in ((' N'  , ' T'  , ' C'  )), (tA, tB))
402+     #  We convert the chars to uppercase to potentially unwrap a WrapperChar,
403+     #  and extract the char corresponding to the wrapper type
404+     tA_uc, tB_uc =  uppercase (tA), uppercase (tB)
405+     #  the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
406+     if  all (map (in ((' N'  , ' T'  , ' C'  )), (tA_uc, tB_uc)))
399407        gemm_wrapper! (C, tA, tB, A, B, _add)
400408    else 
401409        _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
@@ -434,18 +442,19 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
434442    mA ==  0  &&  return  y
435443    nA ==  0  &&  return  _rmul_or_fill! (y, β)
436444    alpha, beta =  promote (α, β, zero (T))
445+     tA_uc =  uppercase (tA) #  potentially convert a WrapperChar to a Char
437446    if  alpha isa  Union{Bool,T} &&  beta isa  Union{Bool,T} && 
438447        stride (A, 1 ) ==  1  &&  abs (stride (A, 2 )) >=  size (A, 1 ) && 
439448        ! iszero (stride (x, 1 )) &&  #  We only check input's stride here.
440-         if  tA  in  (' N'  , ' T'  , ' C'  )
449+         if  tA_uc  in  (' N'  , ' T'  , ' C'  )
441450            return  BLAS. gemv! (tA, alpha, A, x, beta, y)
442-         elseif  tA  in  ( ' S' ,  ' s ' ) 
451+         elseif  tA_uc  ==   ' S' 
443452            return  BLAS. symv! (tA ==  ' S'   ?  ' U'   :  ' L'  , alpha, A, x, beta, y)
444-         elseif  tA  in  ( ' H' ,  ' h ' ) 
453+         elseif  tA_uc  ==   ' H' 
445454            return  BLAS. hemv! (tA ==  ' H'   ?  ' U'   :  ' L'  , alpha, A, x, beta, y)
446455        end 
447456    end 
448-     if  tA  in  (' S'  , ' s ' ,  ' H ' ,  ' h '  )
457+     if  tA_uc  in  (' S'  , ' H '  )
449458        #  re-wrap again and use plain ('N') matvec mul algorithm,
450459        #  because _generic_matvecmul! can't handle the HermOrSym cases specifically
451460        return  _generic_matvecmul! (y, ' N'  , wrap (A, tA), x, MulAddMul (α, β))
@@ -464,14 +473,15 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
464473    mA ==  0  &&  return  y
465474    nA ==  0  &&  return  _rmul_or_fill! (y, β)
466475    alpha, beta =  promote (α, β, zero (T))
476+     tA_uc =  uppercase (tA) #  potentially convert a WrapperChar to a Char
467477    if  alpha isa  Union{Bool,T} &&  beta isa  Union{Bool,T} && 
468478        stride (A, 1 ) ==  1  &&  abs (stride (A, 2 )) >=  size (A, 1 ) && 
469-         stride (y, 1 ) ==  1  &&  tA  ==  ' N'   &&  #  reinterpret-based optimization is valid only for contiguous `y`
479+         stride (y, 1 ) ==  1  &&  tA_uc  ==  ' N'   &&  #  reinterpret-based optimization is valid only for contiguous `y`
470480        ! iszero (stride (x, 1 ))
471481        BLAS. gemv! (tA, alpha, reinterpret (T, A), x, beta, reinterpret (T, y))
472482        return  y
473483    else 
474-         Anew, ta =  tA  in  (' S'  , ' s ' ,  ' H ' ,  ' h '  ) ?  (wrap (A, tA), ' N'  ) :  (A, tA)
484+         Anew, ta =  tA_uc  in  (' S'  , ' H '  ) ?  (wrap (A, tA), oftype (tA,  ' N' ) ) :  (A, tA)
475485        return  _generic_matvecmul! (y, ta, Anew, x, MulAddMul (α, β))
476486    end 
477487end 
@@ -487,15 +497,16 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
487497    mA ==  0  &&  return  y
488498    nA ==  0  &&  return  _rmul_or_fill! (y, β)
489499    alpha, beta =  promote (α, β, zero (T))
500+     tA_uc =  uppercase (tA) #  potentially convert a WrapperChar to a Char
490501    @views  if  alpha isa  Union{Bool,T} &&  beta isa  Union{Bool,T} && 
491502        stride (A, 1 ) ==  1  &&  abs (stride (A, 2 )) >=  size (A, 1 ) && 
492-         ! iszero (stride (x, 1 )) &&  tA  in  (' N'  , ' T'  , ' C'  )
503+         ! iszero (stride (x, 1 )) &&  tA_uc  in  (' N'  , ' T'  , ' C'  )
493504        xfl =  reinterpret (reshape, T, x) #  Use reshape here.
494505        yfl =  reinterpret (reshape, T, y)
495506        BLAS. gemv! (tA, alpha, A, xfl[1 , :], beta, yfl[1 , :])
496507        BLAS. gemv! (tA, alpha, A, xfl[2 , :], beta, yfl[2 , :])
497508        return  y
498-     elseif  tA  in  (' S'  , ' s ' ,  ' H ' ,  ' h '  )
509+     elseif  tA_uc  in  (' S'  , ' H '  )
499510        #  re-wrap again and use plain ('N') matvec mul algorithm,
500511        #  because _generic_matvecmul! can't handle the HermOrSym cases specifically
501512        return  _generic_matvecmul! (y, ' N'  , wrap (A, tA), x, MulAddMul (α, β))
@@ -504,10 +515,13 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
504515    end 
505516end 
506517
507- function  syrk_wrapper! (C:: StridedMatrix{T} , tA:: AbstractChar , A:: StridedVecOrMat{T} ,
518+ #  the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
519+ #  to be concretely inferred
520+ Base. @constprop  :aggressive  function  syrk_wrapper! (C:: StridedMatrix{T} , tA:: AbstractChar , A:: StridedVecOrMat{T} ,
508521        _add =  MulAddMul ()) where  {T<: BlasFloat }
509522    nC =  checksquare (C)
510-     if  tA ==  ' T' 
523+     tA_uc =  uppercase (tA) #  potentially convert a WrapperChar to a Char
524+     if  tA_uc ==  ' T' 
511525        (nA, mA) =  size (A,1 ), size (A,2 )
512526        tAt =  ' N' 
513527    else 
@@ -542,10 +556,13 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat
542556    return  gemm_wrapper! (C, tA, tAt, A, A, _add)
543557end 
544558
545- function  herk_wrapper! (C:: Union{StridedMatrix{T}, StridedMatrix{Complex{T}}} , tA:: AbstractChar , A:: Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}} ,
559+ #  the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
560+ #  to be concretely inferred
561+ Base. @constprop  :aggressive  function  herk_wrapper! (C:: Union{StridedMatrix{T}, StridedMatrix{Complex{T}}} , tA:: AbstractChar , A:: Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}} ,
546562        _add =  MulAddMul ()) where  {T<: BlasReal }
547563    nC =  checksquare (C)
548-     if  tA ==  ' C' 
564+     tA_uc =  uppercase (tA) #  potentially convert a WrapperChar to a Char
565+     if  tA_uc ==  ' C' 
549566        (nA, mA) =  size (A,1 ), size (A,2 )
550567        tAt =  ' N' 
551568    else 
@@ -581,20 +598,28 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA
581598    return  gemm_wrapper! (C, tA, tAt, A, A, _add)
582599end 
583600
584- function  gemm_wrapper (tA:: AbstractChar , tB:: AbstractChar ,
601+ #  Aggressive constprop helps propagate the values of tA and tB into wrap, which
602+ #  makes the calls concretely inferred
603+ Base. @constprop  :aggressive  function  gemm_wrapper (tA:: AbstractChar , tB:: AbstractChar ,
585604                      A:: StridedVecOrMat{T} ,
586605                      B:: StridedVecOrMat{T} ) where  {T<: BlasFloat }
587606    mA, nA =  lapack_size (tA, A)
588607    mB, nB =  lapack_size (tB, B)
589608    C =  similar (B, T, mA, nB)
590-     if  all (in ((' N'  , ' T'  , ' C'  )), (tA, tB))
609+     #  We convert the chars to uppercase to potentially unwrap a WrapperChar,
610+     #  and extract the char corresponding to the wrapper type
611+     tA_uc, tB_uc =  uppercase (tA), uppercase (tB)
612+     #  the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
613+     if  all (map (in ((' N'  , ' T'  , ' C'  )), (tA_uc, tB_uc)))
591614        gemm_wrapper! (C, tA, tB, A, B)
592615    else 
593616        _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
594617    end 
595618end 
596619
597- function  gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
620+ #  Aggressive constprop helps propagate the values of tA and tB into wrap, which
621+ #  makes the calls concretely inferred
622+ Base. @constprop  :aggressive  function  gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
598623                       A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
599624                       _add =  MulAddMul ()) where  {T<: BlasFloat }
600625    mA, nA =  lapack_size (tA, A)
@@ -634,7 +659,9 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
634659    _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
635660end 
636661
637- function  gemm_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA:: AbstractChar , tB:: AbstractChar ,
662+ #  Aggressive constprop helps propagate the values of tA and tB into wrap, which
663+ #  makes the calls concretely inferred
664+ Base. @constprop  :aggressive  function  gemm_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA:: AbstractChar , tB:: AbstractChar ,
638665                       A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
639666                       _add =  MulAddMul ()) where  {T<: BlasReal }
640667    mA, nA =  lapack_size (tA, A)
@@ -664,13 +691,15 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
664691
665692    alpha, beta =  promote (_add. alpha, _add. beta, zero (T))
666693
694+     tA_uc =  uppercase (tA) #  potentially convert a WrapperChar to a Char
695+ 
667696    #  Make-sure reinterpret-based optimization is BLAS-compatible.
668697    if  (alpha isa  Union{Bool,T} && 
669698        beta isa  Union{Bool,T} && 
670699        stride (A, 1 ) ==  stride (B, 1 ) ==  stride (C, 1 ) ==  1  && 
671700        stride (A, 2 ) >=  size (A, 1 ) && 
672701        stride (B, 2 ) >=  size (B, 1 ) && 
673-         stride (C, 2 ) >=  size (C, 1 ) &&  tA  ==  ' N'  )
702+         stride (C, 2 ) >=  size (C, 1 ) &&  tA_uc  ==  ' N'  )
674703        BLAS. gemm! (tA, tB, alpha, reinterpret (T, A), B, beta, reinterpret (T, C))
675704        return  C
676705    end 
@@ -703,9 +732,10 @@ parameters must satisfy `length(ir_dest) == length(ir_src)` and
703732See also [`copy_transpose!`](@ref) and [`copy_adjoint!`](@ref). 
704733""" 
705734function  copyto! (B:: AbstractVecOrMat , ir_dest:: AbstractUnitRange{Int} , jr_dest:: AbstractUnitRange{Int} , tM:: AbstractChar , M:: AbstractVecOrMat , ir_src:: AbstractUnitRange{Int} , jr_src:: AbstractUnitRange{Int} )
706-     if  tM ==  ' N' 
735+     tM_uc =  uppercase (tM) #  potentially convert a WrapperChar to a Char
736+     if  tM_uc ==  ' N' 
707737        copyto! (B, ir_dest, jr_dest, M, ir_src, jr_src)
708-     elseif  tM  ==  ' T' 
738+     elseif  tM_uc  ==  ' T' 
709739        copy_transpose! (B, ir_dest, jr_dest, M, jr_src, ir_src)
710740    else 
711741        copy_adjoint! (B, ir_dest, jr_dest, M, jr_src, ir_src)
@@ -734,11 +764,12 @@ range parameters must satisfy `length(ir_dest) == length(jr_src)` and
734764See also [`copyto!`](@ref) and [`copy_adjoint!`](@ref). 
735765""" 
736766function  copy_transpose! (B:: AbstractMatrix , ir_dest:: AbstractUnitRange{Int} , jr_dest:: AbstractUnitRange{Int} , tM:: AbstractChar , M:: AbstractVecOrMat , ir_src:: AbstractUnitRange{Int} , jr_src:: AbstractUnitRange{Int} )
737-     if  tM ==  ' N' 
767+     tM_uc =  uppercase (tM) #  potentially convert a WrapperChar to a Char
768+     if  tM_uc ==  ' N' 
738769        copy_transpose! (B, ir_dest, jr_dest, M, ir_src, jr_src)
739770    else 
740771        copyto! (B, ir_dest, jr_dest, M, jr_src, ir_src)
741-         tM  ==  ' C'   &&  conj! (@view  B[ir_dest, jr_dest])
772+         tM_uc  ==  ' C'   &&  conj! (@view  B[ir_dest, jr_dest])
742773    end 
743774    B
744775end 
751782
752783@inline  function  generic_matvecmul! (C:: AbstractVector , tA, A:: AbstractVecOrMat , B:: AbstractVector ,
753784                                    _add:: MulAddMul  =  MulAddMul ())
754-     Anew, ta =  tA in  (' S'  , ' s'  , ' H'  , ' h'  ) ?  (wrap (A, tA), ' N'  ) :  (A, tA)
785+     tA_uc =  uppercase (tA) #  potentially convert a WrapperChar to a Char
786+     Anew, ta =  tA_uc in  (' S'  , ' H'  ) ?  (wrap (A, tA), oftype (tA, ' N'  )) :  (A, tA)
755787    return  _generic_matvecmul! (C, ta, Anew, B, _add)
756788end 
757789
0 commit comments