Skip to content

Commit

Permalink
Add rot! to BLAS in stdlib/LinearAlgebra, add generic rotate! and ref…
Browse files Browse the repository at this point in the history
…lect! (#35124)
  • Loading branch information
amontoison authored Mar 31, 2020
1 parent bf8aae8 commit ddf79a8
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ Standard library changes
* `normalize` now supports multidimensional arrays ([#34239])
* `lq` factorizations can now be used to compute the minimum-norm solution to under-determined systems ([#34350]).
* The BLAS submodule now supports the level-2 BLAS subroutine `spmv!` ([#34320]).
* The BLAS submodule now supports the level-1 BLAS subroutine `rot!` ([#35124]).
* New generic `rotate!(x, y, c, s)` and `reflect!(x, y, c, s)` functions ([#35124]).

#### Markdown

Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ export
opnorm,
rank,
rdiv!,
reflect!,
rotate!,
schur,
schur!,
svd,
Expand Down
32 changes: 32 additions & 0 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export
blascopy!,
dotc,
dotu,
rot!,
scal!,
scal,
nrm2,
Expand Down Expand Up @@ -198,6 +199,37 @@ for (fname, elty) in ((:dcopy_,:Float64),
end
end


## rot

"""
rot!(n, X, incx, Y, incy, c, s)
Overwrite `X` with `c*X + s*Y` and `Y` with `-conj(s)*X + c*Y` for the first `n` elements of array `X` with stride `incx` and
first `n` elements of array `Y` with stride `incy`. Returns `X` and `Y`.
!!! compat "Julia 1.5"
`rot!` requires at least Julia 1.5.
"""
function rot! end

for (fname, elty, cty, sty, lib) in ((:drot_, :Float64, :Float64, :Float64, libblas),
(:srot_, :Float32, :Float32, :Float32, libblas),
(:zdrot_, :ComplexF64, :Float64, :Float64, libblas),
(:csrot_, :ComplexF32, :Float32, :Float32, libblas),
(:zrot_, :ComplexF64, :Float64, :ComplexF64, liblapack),
(:crot_, :ComplexF32, :Float32, :ComplexF32, liblapack))
@eval begin
# SUBROUTINE DROT(N,DX,INCX,DY,INCY,C,S)
function rot!(n::Integer, DX::Union{Ptr{$elty},AbstractArray{$elty}}, incx::Integer, DY::Union{Ptr{$elty},AbstractArray{$elty}}, incy::Integer, C::$cty, S::$sty)
ccall((@blasfunc($fname), $lib), Cvoid,
(Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ref{$cty}, Ref{$sty}),
n, DX, incx, DY, incy, C, S)
DX, DY
end
end
end

## scal

"""
Expand Down
45 changes: 45 additions & 0 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,51 @@ function axpby!(α, x::AbstractArray, β, y::AbstractArray)
y
end

"""
rotate!(x, y, c, s)
Overwrite `x` with `c*x + s*y` and `y` with `-conj(s)*x + c*y`.
Returns `x` and `y`.
!!! compat "Julia 1.5"
`rotate!` requires at least Julia 1.5.
"""
function rotate!(x::AbstractVector, y::AbstractVector, c, s)
require_one_based_indexing(x, y)
n = length(x)
if n != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
@inbounds for i = 1:n
xi, yi = x[i], y[i]
x[i] = c *xi + s*yi
y[i] = -conj(s)*xi + c*yi
end
return x, y
end

"""
reflect!(x, y, c, s)
Overwrite `x` with `c*x + s*y` and `y` with `conj(s)*x - c*y`.
Returns `x` and `y`.
!!! compat "Julia 1.5"
`reflect!` requires at least Julia 1.5.
"""
function reflect!(x::AbstractVector, y::AbstractVector, c, s)
require_one_based_indexing(x, y)
n = length(x)
if n != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
@inbounds for i = 1:n
xi, yi = x[i], y[i]
x[i] = c *xi + s*yi
y[i] = conj(s)*xi - c*yi
end
return x, y
end

# Elementary reflection similar to LAPACK. The reflector is not Hermitian but
# ensures that tridiagonalization of Hermitian matrices become real. See lawn72
Expand Down
26 changes: 26 additions & 0 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,32 @@ Random.seed!(100)
@test BLAS.iamax(z) == argmax(map(x -> abs(real(x)) + abs(imag(x)), z))
end
end
@testset "rot!" begin
if elty <: Real
x = convert(Vector{elty}, randn(n))
y = convert(Vector{elty}, randn(n))
c = rand(elty)
s = rand(elty)
x2 = copy(x)
y2 = copy(y)
BLAS.rot!(n, x, 1, y, 1, c, s)
@test x c*x2 + s*y2
@test y -s*x2 + c*y2
else
x = convert(Vector{elty}, complex.(randn(n),rand(n)))
y = convert(Vector{elty}, complex.(randn(n),rand(n)))
cty = (elty == ComplexF32) ? Float32 : Float64
c = rand(cty)
for sty in [cty, elty]
s = rand(sty)
x2 = copy(x)
y2 = copy(y)
BLAS.rot!(n, x, 1, y, 1, c, s)
@test x c*x2 + s*y2
@test y -conj(s)*x2 + c*y2
end
end
end
@testset "axp(b)y" begin
if elty <: Real
x1 = convert(Vector{elty}, randn(n))
Expand Down
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,25 @@ end
@test norm(x, 3) cbrt(5^3 +sqrt(5)^3)
end

@testset "rotate! and reflect!" begin
x = rand(ComplexF64, 10)
y = rand(ComplexF64, 10)
c = rand(Float64)
s = rand(ComplexF64)

x2 = copy(x)
y2 = copy(y)
rotate!(x, y, c, s)
@test x c*x2 + s*y2
@test y -conj(s)*x2 + c*y2

x3 = copy(x)
y3 = copy(y)
reflect!(x, y, c, s)
@test x c*x3 + s*y3
@test y conj(s)*x3 - c*y3
end

@testset "LinearAlgebra.axp(b)y! for element type without commutative multiplication" begin
α = [1 2; 3 4]
β = [5 6; 7 8]
Expand Down

0 comments on commit ddf79a8

Please sign in to comment.