Skip to content

Commit

Permalink
use offsets instead of axes
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Sep 15, 2020
1 parent f34598c commit 0529fe8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
26 changes: 10 additions & 16 deletions src/OffsetArrays.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module OffsetArrays

using Base: Indices, tail, @propagate_inbounds
import Base: (*), convert, promote_rule

@static if !isdefined(Base, :IdentityUnitRange)
const IdentityUnitRange = Base.Slice
Expand Down Expand Up @@ -406,28 +405,23 @@ end

no_offset_view(A::OffsetArray) = no_offset_view(parent(A))


# Quick hack for matrix multiplication.
# Ideally, one would instead improve LinearAlgebra's support of custom indexing.
function (*)(A::OffsetMatrix, B::OffsetMatrix)
function Base.:(*)(A::OffsetMatrix, B::OffsetMatrix)
matmult_check_axes(A, B)
C = OffsetArray(parent(A) * parent(B), (axes(A,1), axes(B,2)))
C = parent(A) * parent(B)
OffsetArray{eltype(C), 2, typeof(C)}(C, (A.offsets[1], B.offsets[2]))
end

function (*)(A::OffsetMatrix, B::OffsetVector)
function Base.:(*)(A::OffsetMatrix, B::OffsetVector)
matmult_check_axes(A, B)
C = OffsetArray(parent(A) * parent(B), axes(A,1))
C = parent(A) * parent(B)
OffsetArray{eltype(C), 1, typeof(C)}(C, (A.offsets[1], ))
end
function matmult_check_axes(A, B)
axes(A, 2) === axes(B, 1) || axes(A, 2) == axes(B, 1) ||
error("axes(A,2) = $(UnitRange(axes(A,2))) does not equal axes(B,1) = $(UnitRange(axes(B,1)))")
end
matmult_check_axes(A, B) = axes(A, 2) == axes(B, 1) || error("axes(A,2) must equal axes(B,1)")

(*)(A::OffsetMatrix, B::AbstractMatrix) = A * OffsetArray(B)
(*)(A::OffsetMatrix, B::AbstractVector) = A * OffsetArray(B)
(*)(A::AbstractMatrix, B::OffsetArray) = OffsetArray(A) * B
(*)(A::AbstractVector, B::OffsetArray) = OffsetArray(A) * B

# An alternative to the above four methods would be to use promote_rule, but it doesn't get invoked
# promote_rule(::Type{A1}, ::Type{A2}) where A1<:AbstractArray{<:Any,N} where A2<:OffsetArray{<:Any,N,A3} where {N,A3} = OffsetArray{eltype(promote_type(A1, A3)), N, promote_type(A1, A3)}


####
# work around for segfault in searchsorted*
Expand Down
31 changes: 31 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -946,3 +946,34 @@ end
@test searchsorted(o, 5) == 2:2
@test searchsorted(o, 6) == 3:2
end

@testset "Matrix multiplication" begin
a = [1 2; 3 4]
v = [5, 6]
oa = OffsetArray(a, (2, 2))
ov = OffsetVector(v, (2,))

@test parent(oa * oa) == a * a
@test axes(oa * oa) == axes(oa)

@test parent(oa * ov) == a * v
@test axes(oa * ov) == (axes(oa, 1),)

@test parent(ones(2, 2:3) * ones(2:3, 3:5)) == ones(2, 2) * ones(2, 3)
@test axes(ones(2, 2:3) * ones(2:3, 3:5)) == (1:2, 3:5)

@test parent(ones(2, 2:3) * ones(2:3)) == ones(2, 2) * ones(2)
@test axes(ones(2, 2:3) * ones(2:3)) == (1:2,)

# One-based arrays
oa2 = OffsetArray(a, axes(a))
@test oa2 * a == a * a
@test a * oa2 == a * a

@test oa2 * v == a * v
@test v' * oa2 == v' * a

@test_throws Exception zeros(2, 2:3) * zeros(2:4, 2)
@test_throws Exception zeros(2, 2:3) * zeros(3:4, 2)
@test_throws Exception zeros(2, 2:3) * zeros(2:4)
end

0 comments on commit 0529fe8

Please sign in to comment.