Skip to content

Commit

Permalink
Fix several broadcast issues
Browse files Browse the repository at this point in the history
  • Loading branch information
wsshin committed Aug 8, 2017
1 parent da1a371 commit 3a49b3a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
6 changes: 4 additions & 2 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Base.Broadcast:
# This isn't the precise output type, just a placeholder to return from
# promote_containertype, which will control dispatch to our broadcast_c.
_containertype(::Type{<:StaticArray}) = StaticArray
_containertype(::Type{<:RowVector{<:Any,<:SVector}}) = StaticArray

# With the above, the default promote_containertype gives reasonable defaults:
# StaticArray, StaticArray -> StaticArray
Expand All @@ -32,6 +33,7 @@ broadcast_indices(::Type{StaticArray}, A) = indices(A)
_broadcast(f, broadcast_sizes(as...), as...)
end

@inline broadcast_sizes(a::RowVector{<:Any,<:SVector}, as...) = (Size(a), broadcast_sizes(as...)...)
@inline broadcast_sizes(a::StaticArray, as...) = (Size(a), broadcast_sizes(as...)...)
@inline broadcast_sizes(a, as...) = (Size(), broadcast_sizes(as...)...)
@inline broadcast_sizes() = ()
Expand Down Expand Up @@ -66,9 +68,9 @@ end
for i = 1:length(sizes)
s = sizes[i]
for j = 1:length(s)
if newsize[j] == 1 || newsize[j] == s[j]
if newsize[j] == 1
newsize[j] = s[j]
else
elseif newsize[j] s[j] && s[j] 1
throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
end
end
Expand Down
26 changes: 13 additions & 13 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,32 @@ end
@testset "2x2 StaticMatrix with 1x2 StaticMatrix" begin
m1 = @SMatrix [1 2; 3 4]
m2 = @SMatrix [1 4]
@test_broken @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8] #197
@test_broken @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8] #197
@test @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8] #197
@test @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8] #197
@test @inferred(m2 .+ m1) === @SMatrix [2 6; 4 8]
@test_broken @inferred(m1 .* m2) === @SMatrix [1 8; 3 16] #197
@test @inferred(m1 .* m2) === @SMatrix [1 8; 3 16] #197
@test @inferred(m2 .* m1) === @SMatrix [1 8; 3 16]
@test_broken @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1] #197
@test @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1] #197
@test @inferred(m2 ./ m1) === @SMatrix [1 2; 1/3 1]
@test_broken @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0] #197
@test @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0] #197
@test @inferred(m2 .- m1) === @SMatrix [0 2; -2 0]
@test_broken @inferred(m1 .^ m2) === @SMatrix [1 16; 1 256] #197
@test @inferred(m1 .^ m2) === @SMatrix [1 16; 3 256] #197
end

@testset "1x2 StaticMatrix with StaticVector" begin
m = @SMatrix [1 2]
v = SVector(1, 4)
@test @inferred(broadcast(+, m, v)) === @SMatrix [2 3; 5 6]
@test @inferred(m .+ v) === @SMatrix [2 3; 5 6]
@test_broken @inferred(v .+ m) === @SMatrix [2 3; 5 6] #197
@test @inferred(v .+ m) === @SMatrix [2 3; 5 6] #197
@test @inferred(m .* v) === @SMatrix [1 2; 4 8]
@test_broken @inferred(v .* m) === @SMatrix [1 2; 4 8] #197
@test @inferred(v .* m) === @SMatrix [1 2; 4 8] #197
@test @inferred(m ./ v) === @SMatrix [1 2; 1/4 1/2]
@test_broken @inferred(v ./ m) === @SMatrix [1 1/2; 4 2] #197
@test @inferred(v ./ m) === @SMatrix [1 1/2; 4 2] #197
@test @inferred(m .- v) === @SMatrix [0 1; -3 -2]
@test_broken @inferred(v .- m) === @SMatrix [0 -1; 3 2] #197
@test @inferred(v .- m) === @SMatrix [0 -1; 3 2] #197
@test @inferred(m .^ v) === @SMatrix [1 2; 1 16]
@test_broken @inferred(v .^ m) === @SMatrix [1 1; 4 16] #197
@test @inferred(v .^ m) === @SMatrix [1 1; 4 16] #197
end

@testset "StaticVector with StaticVector" begin
Expand All @@ -89,9 +89,9 @@ end
@test @inferred(v2 .^ v1) === SVector(1, 16)
# test case issue #199
@test @inferred(SVector(1) .+ SVector()) === SVector()
@test_broken @inferred(SVector() .+ SVector(1)) === SVector()
@test @inferred(SVector() .+ SVector(1)) === SVector()
# test case issue #200
@test_broken @inferred(v1 .+ v2') === @SMatrix [2 5; 3 5]
@test @inferred(v1 .+ v2') === @SMatrix [2 5; 3 6]
end

@testset "StaticVector with Scalar" begin
Expand Down

0 comments on commit 3a49b3a

Please sign in to comment.