From ddf79a84ca565d38fb684e879446f1ce0c387721 Mon Sep 17 00:00:00 2001 From: Alexis <35051714+amontoison@users.noreply.github.com> Date: Tue, 31 Mar 2020 04:45:43 -0400 Subject: [PATCH] Add rot! to BLAS in stdlib/LinearAlgebra, add generic rotate! and reflect! (#35124) --- NEWS.md | 2 + stdlib/LinearAlgebra/src/LinearAlgebra.jl | 2 + stdlib/LinearAlgebra/src/blas.jl | 32 ++++++++++++++++ stdlib/LinearAlgebra/src/generic.jl | 45 +++++++++++++++++++++++ stdlib/LinearAlgebra/test/blas.jl | 26 +++++++++++++ stdlib/LinearAlgebra/test/generic.jl | 19 ++++++++++ 6 files changed, 126 insertions(+) diff --git a/NEWS.md b/NEWS.md index e2862e6ff0704..c10dca7ad04ce 100644 --- a/NEWS.md +++ b/NEWS.md @@ -115,6 +115,8 @@ Standard library changes * `normalize` now supports multidimensional arrays ([#34239]) * `lq` factorizations can now be used to compute the minimum-norm solution to under-determined systems ([#34350]). * The BLAS submodule now supports the level-2 BLAS subroutine `spmv!` ([#34320]). +* The BLAS submodule now supports the level-1 BLAS subroutine `rot!` ([#35124]). +* New generic `rotate!(x, y, c, s)` and `reflect!(x, y, c, s)` functions ([#35124]). #### Markdown diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 30002837f8398..bb1dcb3c17ea7 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -124,6 +124,8 @@ export opnorm, rank, rdiv!, + reflect!, + rotate!, schur, schur!, svd, diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index b8f86a12c6ea5..eb55e7d5fd3ab 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -17,6 +17,7 @@ export blascopy!, dotc, dotu, + rot!, scal!, scal, nrm2, @@ -198,6 +199,37 @@ for (fname, elty) in ((:dcopy_,:Float64), end end + +## rot + +""" + rot!(n, X, incx, Y, incy, c, s) + +Overwrite `X` with `c*X + s*Y` and `Y` with `-conj(s)*X + c*Y` for the first `n` elements of array `X` with stride `incx` and +first `n` elements of array `Y` with stride `incy`. Returns `X` and `Y`. + +!!! compat "Julia 1.5" + `rot!` requires at least Julia 1.5. +""" +function rot! end + +for (fname, elty, cty, sty, lib) in ((:drot_, :Float64, :Float64, :Float64, libblas), + (:srot_, :Float32, :Float32, :Float32, libblas), + (:zdrot_, :ComplexF64, :Float64, :Float64, libblas), + (:csrot_, :ComplexF32, :Float32, :Float32, libblas), + (:zrot_, :ComplexF64, :Float64, :ComplexF64, liblapack), + (:crot_, :ComplexF32, :Float32, :ComplexF32, liblapack)) + @eval begin + # SUBROUTINE DROT(N,DX,INCX,DY,INCY,C,S) + function rot!(n::Integer, DX::Union{Ptr{$elty},AbstractArray{$elty}}, incx::Integer, DY::Union{Ptr{$elty},AbstractArray{$elty}}, incy::Integer, C::$cty, S::$sty) + ccall((@blasfunc($fname), $lib), Cvoid, + (Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ref{$cty}, Ref{$sty}), + n, DX, incx, DY, incy, C, S) + DX, DY + end + end +end + ## scal """ diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index b90928c967c81..e5555da909d10 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -1416,6 +1416,51 @@ function axpby!(α, x::AbstractArray, β, y::AbstractArray) y end +""" + rotate!(x, y, c, s) + +Overwrite `x` with `c*x + s*y` and `y` with `-conj(s)*x + c*y`. +Returns `x` and `y`. + +!!! compat "Julia 1.5" + `rotate!` requires at least Julia 1.5. +""" +function rotate!(x::AbstractVector, y::AbstractVector, c, s) + require_one_based_indexing(x, y) + n = length(x) + if n != length(y) + throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))")) + end + @inbounds for i = 1:n + xi, yi = x[i], y[i] + x[i] = c *xi + s*yi + y[i] = -conj(s)*xi + c*yi + end + return x, y +end + +""" + reflect!(x, y, c, s) + +Overwrite `x` with `c*x + s*y` and `y` with `conj(s)*x - c*y`. +Returns `x` and `y`. + +!!! compat "Julia 1.5" + `reflect!` requires at least Julia 1.5. +""" +function reflect!(x::AbstractVector, y::AbstractVector, c, s) + require_one_based_indexing(x, y) + n = length(x) + if n != length(y) + throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))")) + end + @inbounds for i = 1:n + xi, yi = x[i], y[i] + x[i] = c *xi + s*yi + y[i] = conj(s)*xi - c*yi + end + return x, y +end # Elementary reflection similar to LAPACK. The reflector is not Hermitian but # ensures that tridiagonalization of Hermitian matrices become real. See lawn72 diff --git a/stdlib/LinearAlgebra/test/blas.jl b/stdlib/LinearAlgebra/test/blas.jl index dc6bf1f054767..23c6d68cdc997 100644 --- a/stdlib/LinearAlgebra/test/blas.jl +++ b/stdlib/LinearAlgebra/test/blas.jl @@ -78,6 +78,32 @@ Random.seed!(100) @test BLAS.iamax(z) == argmax(map(x -> abs(real(x)) + abs(imag(x)), z)) end end + @testset "rot!" begin + if elty <: Real + x = convert(Vector{elty}, randn(n)) + y = convert(Vector{elty}, randn(n)) + c = rand(elty) + s = rand(elty) + x2 = copy(x) + y2 = copy(y) + BLAS.rot!(n, x, 1, y, 1, c, s) + @test x ≈ c*x2 + s*y2 + @test y ≈ -s*x2 + c*y2 + else + x = convert(Vector{elty}, complex.(randn(n),rand(n))) + y = convert(Vector{elty}, complex.(randn(n),rand(n))) + cty = (elty == ComplexF32) ? Float32 : Float64 + c = rand(cty) + for sty in [cty, elty] + s = rand(sty) + x2 = copy(x) + y2 = copy(y) + BLAS.rot!(n, x, 1, y, 1, c, s) + @test x ≈ c*x2 + s*y2 + @test y ≈ -conj(s)*x2 + c*y2 + end + end + end @testset "axp(b)y" begin if elty <: Real x1 = convert(Vector{elty}, randn(n)) diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index eb747b185afa3..9ca508d9a9908 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -217,6 +217,25 @@ end @test norm(x, 3) ≈ cbrt(5^3 +sqrt(5)^3) end +@testset "rotate! and reflect!" begin + x = rand(ComplexF64, 10) + y = rand(ComplexF64, 10) + c = rand(Float64) + s = rand(ComplexF64) + + x2 = copy(x) + y2 = copy(y) + rotate!(x, y, c, s) + @test x ≈ c*x2 + s*y2 + @test y ≈ -conj(s)*x2 + c*y2 + + x3 = copy(x) + y3 = copy(y) + reflect!(x, y, c, s) + @test x ≈ c*x3 + s*y3 + @test y ≈ conj(s)*x3 - c*y3 +end + @testset "LinearAlgebra.axp(b)y! for element type without commutative multiplication" begin α = [1 2; 3 4] β = [5 6; 7 8]