Skip to content

Commit

Permalink
Fix corner cases of broadcast binary arithmetic operations between sp…
Browse files Browse the repository at this point in the history
…arse vectors and scalars (#21515). (#22715)
  • Loading branch information
Sacha0 authored and andreasnoack committed Sep 22, 2017
1 parent 5f68e10 commit 5ad2246
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
4 changes: 0 additions & 4 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1443,13 +1443,9 @@ scale!(x::AbstractSparseVector, a::Complex) = (scale!(nonzeros(x), a); x)
scale!(a::Real, x::AbstractSparseVector) = (scale!(nonzeros(x), a); x)
scale!(a::Complex, x::AbstractSparseVector) = (scale!(nonzeros(x), a); x)


(*)(x::AbstractSparseVector, a::Number) = SparseVector(length(x), copy(nonzeroinds(x)), nonzeros(x) * a)
(*)(a::Number, x::AbstractSparseVector) = SparseVector(length(x), copy(nonzeroinds(x)), a * nonzeros(x))
(/)(x::AbstractSparseVector, a::Number) = SparseVector(length(x), copy(nonzeroinds(x)), nonzeros(x) / a)
broadcast(::typeof(*), x::AbstractSparseVector, a::Number) = x * a
broadcast(::typeof(*), a::Number, x::AbstractSparseVector) = a * x
broadcast(::typeof(/), x::AbstractSparseVector, a::Number) = x / a

# dot
function dot(x::StridedVector{Tx}, y::SparseVectorUnion{Ty}) where {Tx<:Number,Ty<:Number}
Expand Down
10 changes: 10 additions & 0 deletions test/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1164,3 +1164,13 @@ end
@testset "spzeros with index type" begin
@test typeof(spzeros(Float32, Int16, 3)) == SparseVector{Float32,Int16}
end

@testset "corner cases of broadcast arithmetic operations with scalars (#21515)" begin
# test both scalar literals and variables
areequal(a, b, c) = isequal(a, b) && isequal(b, c)
inf, zeroh, zv, spzv = Inf, 0.0, zeros(1), spzeros(1)
@test areequal(spzv .* Inf, spzv .* inf, sparsevec(zv .* Inf))
@test areequal(Inf .* spzv, inf .* spzv, sparsevec(Inf .* zv))
@test areequal(spzv ./ 0.0, spzv ./ zeroh, sparsevec(zv ./ 0.0))
@test areequal(0.0 .\ spzv, zeroh .\ spzv, sparsevec(0.0 .\ zv))
end

0 comments on commit 5ad2246

Please sign in to comment.