Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make BitSet store any Int #25029

Merged
merged 9 commits into from
Dec 19, 2017
14 changes: 6 additions & 8 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1302,25 +1302,23 @@ function empty!(a::Vector)
return a
end

_memcmp(a, b, len) = ccall(:memcmp, Int32, (Ptr{Void}, Ptr{Void}, Csize_t), a, b, len) % Int

# use memcmp for lexcmp on byte arrays
function lexcmp(a::Array{UInt8,1}, b::Array{UInt8,1})
c = ccall(:memcmp, Int32, (Ptr{UInt8}, Ptr{UInt8}, UInt),
a, b, min(length(a),length(b)))
c = _memcmp(a, b, min(length(a),length(b)))
return c < 0 ? -1 : c > 0 ? +1 : cmp(length(a),length(b))
end

const BitIntegerArray{N} = Union{map(T->Array{T,N}, BitInteger_types)...} where N
# use memcmp for == on bit integer types
function ==(a::Arr, b::Arr) where Arr <: BitIntegerArray
size(a) == size(b) && 0 == ccall(
:memcmp, Int32, (Ptr{Void}, Ptr{Void}, UInt), a, b, sizeof(eltype(Arr)) * length(a))
end
==(a::Arr, b::Arr) where {Arr <: BitIntegerArray} =
size(a) == size(b) && 0 == _memcmp(a, b, sizeof(eltype(Arr)) * length(a))

# this is ~20% faster than the generic implementation above for very small arrays
function ==(a::Arr, b::Arr) where Arr <: BitIntegerArray{1}
len = length(a)
len == length(b) && 0 == ccall(
:memcmp, Int32, (Ptr{Void}, Ptr{Void}, UInt), a, b, sizeof(eltype(Arr)) * len)
len == length(b) && 0 == _memcmp(a, b, sizeof(eltype(Arr)) * len)
end

"""
Expand Down
40 changes: 24 additions & 16 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ IndexStyle(::Type{<:BitArray}) = IndexLinear()
## aux functions ##

const _msk64 = ~UInt64(0)
@inline _div64(l) = l >>> 6
@inline _div64(l) = l >> 6
@inline _mod64(l) = l & 63
@inline _msk_end(l::Integer) = _msk64 >>> _mod64(-l)
@inline _msk_end(B::BitArray) = _msk_end(length(B))
Expand Down Expand Up @@ -636,6 +636,10 @@ end

@inline function unsafe_bitsetindex!(Bc::Array{UInt64}, x::Bool, i::Int)
i1, i2 = get_chunks_id(i)
_unsafe_bitsetindex!(Bc, x, i1, i2)
end

@inline function _unsafe_bitsetindex!(Bc::Array{UInt64}, x::Bool, i1::Int, i2::Int)
u = UInt64(1) << i2
@inbounds begin
c = Bc[i1]
Expand Down Expand Up @@ -1438,22 +1442,17 @@ circshift!(B::BitVector, i::Integer) = circshift!(B, B, i)

## count & find ##

function count(B::BitArray)
function bitcount(Bc::Vector{UInt64})
n = 0
Bc = B.chunks
@inbounds for i = 1:length(Bc)
n += count_ones(Bc[i])
end
return n
end

# returns the index of the next non-zero element, or 0 if all zeros
function findnext(B::BitArray, start::Integer)
start > 0 || throw(BoundsError(B, start))
start > length(B) && return 0

Bc = B.chunks
count(B::BitArray) = bitcount(B.chunks)

function unsafe_bitfindnext(Bc::Vector{UInt64}, start::Integer)
chunk_start = _div64(start-1)+1
within_chunk_start = _mod64(start-1)
mask = _msk64 << within_chunk_start
Expand All @@ -1471,6 +1470,14 @@ function findnext(B::BitArray, start::Integer)
end
return 0
end

# returns the index of the next non-zero element, or 0 if all zeros
function findnext(B::BitArray, start::Integer)
start > 0 || throw(BoundsError(B, start))
start > length(B) && return 0
unsafe_bitfindnext(B.chunks, start)
end

#findfirst(B::BitArray) = findnext(B, 1) ## defined in array.jl

# aux function: same as findnext(~B, start), but performed without temporaries
Expand Down Expand Up @@ -1527,13 +1534,7 @@ function findnext(testf::Function, B::BitArray, start::Integer)
end
#findfirst(testf::Function, B::BitArray) = findnext(testf, B, 1) ## defined in array.jl

# returns the index of the previous non-zero element, or 0 if all zeros
function findprev(B::BitArray, start::Integer)
start > 0 || return 0
start > length(B) && throw(BoundsError(B, start))

Bc = B.chunks

function unsafe_bitfindprev(Bc::Vector{UInt64}, start::Integer)
chunk_start = _div64(start-1)+1
mask = _msk_end(start)

Expand All @@ -1551,6 +1552,13 @@ function findprev(B::BitArray, start::Integer)
return 0
end

# returns the index of the previous non-zero element, or 0 if all zeros
function findprev(B::BitArray, start::Integer)
start > 0 || return 0
start > length(B) && throw(BoundsError(B, start))
unsafe_bitfindprev(B.chunks, start)
end

function findprevnot(B::BitArray, start::Integer)
start > 0 || return 0
start > length(B) && throw(BoundsError(B, start))
Expand Down
Loading