diff --git a/src/OffsetArrays.jl b/src/OffsetArrays.jl index cd8f6ac1..98097d83 100644 --- a/src/OffsetArrays.jl +++ b/src/OffsetArrays.jl @@ -1,7 +1,6 @@ module OffsetArrays using Base: Indices, tail, @propagate_inbounds -import Base: (*), convert, promote_rule @static if !isdefined(Base, :IdentityUnitRange) const IdentityUnitRange = Base.Slice @@ -406,28 +405,23 @@ end no_offset_view(A::OffsetArray) = no_offset_view(parent(A)) - # Quick hack for matrix multiplication. # Ideally, one would instead improve LinearAlgebra's support of custom indexing. -function (*)(A::OffsetMatrix, B::OffsetMatrix) +function Base.:(*)(A::OffsetMatrix, B::OffsetMatrix) matmult_check_axes(A, B) - C = OffsetArray(parent(A) * parent(B), (axes(A,1), axes(B,2))) + C = parent(A) * parent(B) + OffsetArray{eltype(C), 2, typeof(C)}(C, (A.offsets[1], B.offsets[2])) end -function (*)(A::OffsetMatrix, B::OffsetVector) +function Base.:(*)(A::OffsetMatrix, B::OffsetVector) matmult_check_axes(A, B) - C = OffsetArray(parent(A) * parent(B), axes(A,1)) + C = parent(A) * parent(B) + OffsetArray{eltype(C), 1, typeof(C)}(C, (A.offsets[1], )) +end +function matmult_check_axes(A, B) + axes(A, 2) === axes(B, 1) || axes(A, 2) == axes(B, 1) || + error("axes(A,2) = $(UnitRange(axes(A,2))) does not equal axes(B,1) = $(UnitRange(axes(B,1)))") end -matmult_check_axes(A, B) = axes(A, 2) == axes(B, 1) || error("axes(A,2) must equal axes(B,1)") - -(*)(A::OffsetMatrix, B::AbstractMatrix) = A * OffsetArray(B) -(*)(A::OffsetMatrix, B::AbstractVector) = A * OffsetArray(B) -(*)(A::AbstractMatrix, B::OffsetArray) = OffsetArray(A) * B -(*)(A::AbstractVector, B::OffsetArray) = OffsetArray(A) * B - -# An alternative to the above four methods would be to use promote_rule, but it doesn't get invoked -# promote_rule(::Type{A1}, ::Type{A2}) where A1<:AbstractArray{<:Any,N} where A2<:OffsetArray{<:Any,N,A3} where {N,A3} = OffsetArray{eltype(promote_type(A1, A3)), N, promote_type(A1, A3)} - #### # work around for segfault in searchsorted* diff --git a/test/runtests.jl b/test/runtests.jl index 1721621f..b2aca0b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -946,3 +946,34 @@ end @test searchsorted(o, 5) == 2:2 @test searchsorted(o, 6) == 3:2 end + +@testset "Matrix multiplication" begin + a = [1 2; 3 4] + v = [5, 6] + oa = OffsetArray(a, (2, 2)) + ov = OffsetVector(v, (2,)) + + @test parent(oa * oa) == a * a + @test axes(oa * oa) == axes(oa) + + @test parent(oa * ov) == a * v + @test axes(oa * ov) == (axes(oa, 1),) + + @test parent(ones(2, 2:3) * ones(2:3, 3:5)) == ones(2, 2) * ones(2, 3) + @test axes(ones(2, 2:3) * ones(2:3, 3:5)) == (1:2, 3:5) + + @test parent(ones(2, 2:3) * ones(2:3)) == ones(2, 2) * ones(2) + @test axes(ones(2, 2:3) * ones(2:3)) == (1:2,) + + # One-based arrays + oa2 = OffsetArray(a, axes(a)) + @test oa2 * a == a * a + @test a * oa2 == a * a + + @test oa2 * v == a * v + @test v' * oa2 == v' * a + + @test_throws Exception zeros(2, 2:3) * zeros(2:4, 2) + @test_throws Exception zeros(2, 2:3) * zeros(3:4, 2) + @test_throws Exception zeros(2, 2:3) * zeros(2:4) +end \ No newline at end of file