Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rot to BLAS in stdlib/LinearAlgebra #35124

Merged
merged 11 commits into from
Mar 31, 2020
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Standard library changes
* `normalize` now supports multidimensional arrays ([#34239])
* `lq` factorizations can now be used to compute the minimum-norm solution to under-determined systems ([#34350]).
* The BLAS submodule now supports the level-2 BLAS subroutine `spmv!` ([#34320]).
* The BLAS submodule now supports the level-1 BLAS subroutine `rot!` ([#35124]).
StefanKarpinski marked this conversation as resolved.
Show resolved Hide resolved

dkarrasch marked this conversation as resolved.
Show resolved Hide resolved
#### Markdown

Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ export
opnorm,
rank,
rdiv!,
ref!,
rot!,
schur,
schur!,
svd,
Expand Down
32 changes: 32 additions & 0 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export
blascopy!,
dotc,
dotu,
rot!,
scal!,
scal,
nrm2,
Expand Down Expand Up @@ -198,6 +199,37 @@ for (fname, elty) in ((:dcopy_,:Float64),
end
end


## rot

"""
rot!(n, X, incx, Y, incy, c, s)

Overwrite `X` with `c*X + s*Y` and `Y` with `-conj(s)*X + c*Y` for the first `n` elements of array `X` with stride `incx` and
first `n` elements of array `Y` with stride `incy`. Returns `X` and `Y`.

!!! compat "Julia 1.5"
`rot!` requires at least Julia 1.5.
"""
function rot! end

for (fname, elty, cty, sty, lib) in ((:drot_, :Float64, :Float64, :Float64, libblas),
(:srot_, :Float32, :Float32, :Float32, libblas),
(:zdrot_, :ComplexF64, :Float64, :Float64, libblas),
(:csrot_, :ComplexF32, :Float32, :Float32, libblas),
(:zrot_, :ComplexF64, :Float64, :ComplexF64, liblapack),
(:crot_, :ComplexF32, :Float32, :ComplexF32, liblapack))
@eval begin
# SUBROUTINE DROT(N,DX,INCX,DY,INCY,C,S)
function rot!(n::Integer, DX::Union{Ptr{$elty},AbstractArray{$elty}}, incx::Integer, DY::Union{Ptr{$elty},AbstractArray{$elty}}, incy::Integer, C::$cty, S::$sty)
ccall((@blasfunc($fname), $lib), Cvoid,
(Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ref{$cty}, Ref{$sty}),
n, DX, incx, DY, incy, C, S)
DX, DY
end
end
end

## scal

"""
Expand Down
43 changes: 43 additions & 0 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,49 @@ function axpby!(α, x::AbstractArray, β, y::AbstractArray)
y
end

"""
rot!(x, y, c, s)

Overwrite `x` with `c*x + s*y` and `y` with `-conj(s)*x + c*y`.
Returns `x` and `y`.

!!! compat "Julia 1.5"
`rot!` requires at least Julia 1.5.
"""
function rot!(x::AbstractVector, y::AbstractVector, c, s)
n = length(x)
if n != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
@inbounds for i = 1:n
xi, yi = x[i], y[i]
x[i] = c *xi + s*yi
y[i] = -conj(s)*xi + c*yi
end
return x, y
end

"""
ref!(x, y, c, s)

Overwrite `x` with `c*x + s*y` and `y` with `conj(s)*x - c*y`.
Returns `x` and `y`.

!!! compat "Julia 1.5"
`rot!` requires at least Julia 1.5.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rot! -> ref!

"""
function ref!(x::AbstractVector, y::AbstractVector, c, s)
n = length(x)
if n != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
@inbounds for i = 1:n
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved
xi, yi = x[i], y[i]
x[i] = c *xi + s*yi
y[i] = conj(s)*xi - c*yi
end
return x, y
end

# Elementary reflection similar to LAPACK. The reflector is not Hermitian but
# ensures that tridiagonalization of Hermitian matrices become real. See lawn72
Expand Down
26 changes: 26 additions & 0 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,32 @@ Random.seed!(100)
@test BLAS.iamax(z) == argmax(map(x -> abs(real(x)) + abs(imag(x)), z))
end
end
@testset "rot!" begin
if elty <: Real
x = convert(Vector{elty}, randn(n))
y = convert(Vector{elty}, randn(n))
c = rand(elty)
s = rand(elty)
x2 = copy(x)
y2 = copy(y)
BLAS.rot!(n, x, 1, y, 1, c, s)
@test x ≈ c*x2 + s*y2
@test y ≈ -s*x2 + c*y2
else
x = convert(Vector{elty}, complex.(randn(n),rand(n)))
y = convert(Vector{elty}, complex.(randn(n),rand(n)))
cty = (elty == ComplexF32) ? Float32 : Float64
c = rand(cty)
for sty in [cty, elty]
s = rand(sty)
x2 = copy(x)
y2 = copy(y)
BLAS.rot!(n, x, 1, y, 1, c, s)
@test x ≈ c*x2 + s*y2
@test y ≈ -conj(s)*x2 + c*y2
end
end
end
@testset "axp(b)y" begin
if elty <: Real
x1 = convert(Vector{elty}, randn(n))
Expand Down
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,25 @@ end
@test norm(x, 3) ≈ cbrt(5^3 +sqrt(5)^3)
end

@testset "rot! and ref!" begin
x = rand(1000)
y = rand(1000)
c = rand()
s = rand(ComplexF64)

x2 = copy(x)
y2 = copy(y)
rot!(n, x, y, c, s)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rot!(n, x, y, c, s)
rot!(x, y, c, s)

@test x ≈ c*x2 + s*y2
@test y ≈ -conj(s)*x2 + c*y2

x3 = copy(x)
y3 = copy(y)
ref!(n, x, y, c, s)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ref!(n, x, y, c, s)
ref!(x, y, c, s)

@test x ≈ c*x3 + s*y3
@test y ≈ conj(s)*x3 - c*y3
end

@testset "LinearAlgebra.axp(b)y! for element type without commutative multiplication" begin
α = [1 2; 3 4]
β = [5 6; 7 8]
Expand Down