Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partition algorithms #82

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions src/SortingAlgorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -932,4 +932,131 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, ::PagedMergeSortAlg,
pagedmergesort!(v, lo, hi, o, scratch, pageLocations)
return v
end

# Partition

partition(by, v::AbstractVector; kws...) = partition!(by, similar(v), v; kws...)

function _checkbounds(v, lo, hi)
checkbounds(Bool, v, lo:hi) ||
error("$(typeof(v)) failed a boundscheck of the form `checkbounds(Bool, v, firstindex(v):lastindex(v))`")
end
function _unstable_partition!(by, v::AbstractVector)
lo, hi = firstindex(v), lastindex(v)
@boundscheck _checkbounds(v, lo, hi) # This should't happen for valid AbstractArrays
@inbounds while true
while lo < hi && !by(v[lo]) lo += 1 end
while lo < hi && by(v[hi]) hi -= 1 end
lo < hi || return lo
v[lo], v[hi] = v[hi], v[lo]
lo += 1
hi -= 1
end
end
function _branching_lo_stable_partition!(by, v::AbstractVector)
lo, hi = firstindex(v), lastindex(v)
@boundscheck _checkbounds(v, lo, hi) # This should't happen for valid AbstractArrays
while lo <= hi && !by(v[lo]) lo += 1 end
mi = lo + 1
@inbounds while mi <= hi
if !by(v[mi])
v[lo], v[mi] = v[mi], v[lo]
lo += 1
mi += 1
else
mi += 1
end
end
lo
end
function _branchless_lo_stable_partition!(by, v::AbstractVector)
i, hi = firstindex(v), lastindex(v)
@boundscheck _checkbounds(v, i, hi) # This should't happen for valid AbstractArrays
delta = 0
@inbounds while i <= hi
res = !(by(v[i])::Bool)
id = i - delta * res
v[i], v[id] = v[id], v[i]
delta += res
i += 1
end
delta + 1
end
function _branching_stable_reversestable_partition_to!(by, dest::AbstractVector, src::AbstractVector)
src_i = firstindex(src)
src_end = lastindex(src)
dest_lo = firstindex(dest)
dest_hi = lastindex(dest)
@boundscheck begin
_checkbounds(src, src_i, src_end)
_checkbounds(dest, dest_lo, dest_hi)
dest_hi - dest_lo == src_end - src_i ||
throw(DimensionMismatch("length mismatch: $(dest_hi - dest_lo) != $(src_end - src_i)"))
end
@inbounds while src_i <= src_end
if by(src[src_i])
dest[dest_hi] = src[src_i]
dest_hi -= 1
else
dest[dest_lo] = src[src_i]
dest_lo += 1
end
src_i += 1
end
dest_lo
end
function _branchless_stable_reversestable_partition_to!(by, dest::AbstractVector, src::AbstractVector)
i = firstindex(src)
src_end = lastindex(src)
dest_lo = firstindex(dest)
dest_hi = lastindex(dest)
@boundscheck begin
_checkbounds(src, i, src_end)
_checkbounds(dest, dest_lo, dest_hi)
dest_hi - dest_lo == src_end - i ||
throw(DimensionMismatch("length mismatch: $(dest_hi - dest_lo + 1) != $(src_end - i + 1)"))
end
transfer = dest_lo - i
delta = dest_hi - dest_lo + i
@inbounds while i <= src_end
res = by(src[i])::Bool
target = i + transfer + (delta - i)*res
dest[target] = src[i]
transfer -= res
i += 1
end
i + transfer
end
function _base_stable_reversestable_partition_to!(by, dest::AbstractVector, src::AbstractVector)
lo = firstindex(src)
hi = lastindex(src)
dest_lo = firstindex(dest)
dest_hi = lastindex(dest)
@boundscheck begin
_checkbounds(src, lo, hi)
_checkbounds(dest, dest_lo, dest_hi)
dest_hi - dest_lo == hi - lo ||
throw(DimensionMismatch("length mismatch: $(dest_hi - dest_lo + 1) != $(hi - lo + 1)"))
end
offset = lo - dest_lo
@inbounds while lo <= hi
x = src[lo]
fx = by(x)
dest[(fx ? hi : lo) - offset] = x
offset += fx
lo += 1
end
lo - offset
end

end # module

# x = rand(Int, 1000);
# y = similar(x);
# z = similar(x);
# @btime SortingAlgorithms._unstable_partition!(isodd, copyto!($y, $x)) # 816.176 ns
# @btime SortingAlgorithms._branching_lo_stable_partition!(isodd, copyto!($y, $x)) # 752.101 ns
# @btime SortingAlgorithms._branchless_lo_stable_partition!(isodd, copyto!($y, $x)) # 679.054 ns
# @btime SortingAlgorithms._branching_stable_reversestable_partition_to!(isodd, $z, copyto!($y, $x)) # 734.336 ns
# @btime SortingAlgorithms._branchless_stable_reversestable_partition_to!(isodd, $z, copyto!($y, $x)) # 829.861 ns
# @btime SortingAlgorithms._base_stable_reversestable_partition_to!(isodd, $z, copyto!($y, $x)) # 687.908 ns
Loading