Skip to content
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ CategoricalArraysStructTypesExt = "StructTypes"

[compat]
Arrow = "2"
Compat = "3.37, 4"
Compat = "3.47, 4.10"
DataAPI = "1.6"
JSON = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21"
JSON3 = "1.1.2"
Expand Down
4 changes: 3 additions & 1 deletion src/CategoricalArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ module CategoricalArrays
import DataAPI: unwrap
export unwrap

using Compat
@compat public default_formatter, numbered_formatter

using DataAPI
using Missings
using Printf
import Compat

# JuliaLang/julia#36810
if VERSION < v"1.5.2"
Expand Down
182 changes: 146 additions & 36 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,67 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray,
end
end

# Shared `%g` format whose precision is supplied at call time (the `*` in "%.*g").
# Only defined on Julia 1.10 and later; elsewhere in this file, older versions build
# a `Printf.Format` with the precision interpolated into the format string instead.
@static if VERSION >= v"1.10"
    const CUT_FMT = Printf.Format("%.*g")
end

"""
default_formatter(from, to, i; leftclosed, rightclosed)
CategoricalArrays.default_formatter(from, to, i::Integer;
leftclosed::Bool, rightclosed::Bool,
sigdigits::Integer)

Provide the default label format for the `cut(x, breaks)` method.
Provide the default label format for the `cut(x, breaks)` method,
which is `"[from, to)"` if `leftclosed` is `true` and `"(from, to)"` otherwise.

If they are floating point values, `from` and `to` are turned into strings using
`@sprintf("%.*g", sigdigits, value)`
(this applies to every break, including the last one).
"""
function default_formatter(from, to, i; leftclosed, rightclosed)
    # The bracket on each side reflects whether that bound belongs to the interval;
    # the group index `i` is accepted but not used in the label.
    opening = leftclosed ? "[" : "("
    closing = rightclosed ? "]" : ")"
    return string(opening, from, ", ", to, closing)
end
function default_formatter(from, to, i::Integer;
                           leftclosed::Bool, rightclosed::Bool,
                           sigdigits::Integer)
    # Both bounds go through the same rendering rule; the brackets only depend on
    # which sides of the interval are closed. The group index `i` is not used.
    lower = _cut_bound_string(from, sigdigits)
    upper = _cut_bound_string(to, sigdigits)
    opening = leftclosed ? "[" : "("
    closing = rightclosed ? "]" : ")"
    return string(opening, lower, ", ", upper, closing)
end

# Render one interval bound: floating point values are formatted with "%g" using
# `sigdigits` as precision, any other value is converted with `string`.
function _cut_bound_string(bound, sigdigits::Integer)
    bound isa AbstractFloat || return string(bound)
    @static if VERSION >= v"1.10"
        # Reuse the shared format object, passing the precision as an argument
        return Printf.format(CUT_FMT, sigdigits, bound)
    else
        # No shared format here: interpolate the precision into the format string
        return Printf.format(Printf.Format("%.$(sigdigits)g"), bound)
    end
end

"""
CategoricalArrays.numbered_formatter(from, to, i::Integer;
leftclosed::Bool, rightclosed::Bool,
sigdigits::Integer)

Provide the default label format for the `cut(x, ngroups)` method
when `allowempty=true`, which is `"i: [from, to)"` if `leftclosed`
is `true` and `"i: (from, to)"` otherwise.

If they are floating point values, `from` and `to` are turned into strings using
`@sprintf("%.*g", sigdigits, value)`
(this applies to every break, including the last one).
"""
function numbered_formatter(from, to, i::Integer;
                            leftclosed::Bool, rightclosed::Bool,
                            sigdigits::Integer)
    # Build the plain interval label first, then prefix it with the group index
    interval = default_formatter(from, to, i;
                                 leftclosed=leftclosed, rightclosed=rightclosed,
                                 sigdigits=sigdigits)
    return string(i, ": ", interval)
end

@doc raw"""
cut(x::AbstractArray, breaks::AbstractVector;
labels::Union{AbstractVector,Function},
sigdigits::Integer=3,
extend::Union{Bool,Missing}=false, allowempty::Bool=false)

Cut a numeric array into intervals at values `breaks`
Expand All @@ -49,15 +99,25 @@ the last interval, which is closed on both ends, i.e. `[lower, upper]`.
If `x` accepts missing values (i.e. `eltype(x) >: Missing`) the returned array will
also accept them.

!!! note
For floating point data, breaks may be rounded to `sigdigits` significant digits
when generating interval labels, meaning that they may not reflect exactly the cutpoints
used.

# Keyword arguments
* `extend::Union{Bool, Missing}=false`: when `false`, an error is raised if some values
in `x` fall outside of the breaks; when `true`, breaks are automatically added to include
all values in `x`; when `missing`, values outside of the breaks generate `missing` entries.
* `labels::Union{AbstractVector, Function}`: a vector of strings, characters
or numbers giving the names to use for
the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates
or numbers giving the names to use for the intervals; or a function
`f(from, to, i::Integer; leftclosed::Bool, rightclosed::Bool, sigdigits::Integer)` that generates
the labels from the left and right interval boundaries and the group index. Defaults to
`"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`).
[`CategoricalArrays.default_formatter`](@ref), giving `"[from, to)"` (or `"[from, to]"`
for the rightmost interval if `extend == true`).
* `sigdigits::Integer=3`: the minimum number of significant digits to use in labels.
This value is increased automatically if necessary so that rounded breaks are unique.
Only used for floating point types and when `labels` is a function, in which case it
is passed to it as a keyword argument.
* `allowempty::Bool=false`: when `false`, an error is raised if some breaks other than
the last one appear multiple times, generating empty intervals; when `true`,
duplicate breaks are allowed and the intervals they generate are kept as
Expand All @@ -69,19 +129,19 @@ julia> using CategoricalArrays

julia> cut(-1:0.5:1, [0, 1], extend=true)
5-element CategoricalArray{String,1,UInt32}:
"[-1.0, 0.0)"
"[-1.0, 0.0)"
"[0.0, 1.0]"
"[0.0, 1.0]"
"[0.0, 1.0]"
"[-1, 0)"
"[-1, 0)"
"[0, 1]"
"[0, 1]"
"[0, 1]"

julia> cut(-1:0.5:1, 2)
5-element CategoricalArray{String,1,UInt32}:
"Q1: [-1.0, 0.0)"
"Q1: [-1.0, 0.0)"
"Q2: [0.0, 1.0]"
"Q2: [0.0, 1.0]"
"Q2: [0.0, 1.0]"
"[-1, 0)"
"[-1, 0)"
"[0, 1]"
"[0, 1]"
"[0, 1]"

julia> cut(-1:0.5:1, 2, labels=["A", "B"])
5-element CategoricalArray{String,1,UInt32}:
Expand Down Expand Up @@ -114,15 +174,17 @@ julia> cut(-1:0.5:1, 3, labels=fmt)
@inline function cut(x::AbstractArray, breaks::AbstractVector;
extend::Union{Bool, Missing}=false,
labels::Union{AbstractVector{<:SupportedTypes},Function}=default_formatter,
sigdigits::Integer=3,
allowempty::Bool=false)
return _cut(x, breaks, extend, labels, allowempty)
return _cut(x, breaks, extend, labels, sigdigits, allowempty)
end

# Separate function for inferability (thanks to inlining of cut)
function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
extend::Union{Bool, Missing},
labels::Union{AbstractVector{<:SupportedTypes},Function},
allowempty::Bool=false) where {T, N}
sigdigits::Integer,
allowempty::Bool) where {T, N}
if !issorted(breaks)
breaks = sort(breaks)
end
Expand Down Expand Up @@ -179,21 +241,60 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
end
end

# Find minimal number of digits so that distinct breaks remain so
if eltype(breaks) <: AbstractFloat
    # Repeat full passes over consecutive breaks, raising `sigdigits` by one each
    # time two distinct breaks format to the same string. Stop only after a pass
    # that found no collision, so a collision on the last pair is rechecked too.
    # With fewer than two breaks the pass is empty and the loop exits at once,
    # leaving the length check below to report the error.
    collision = true
    while collision
        collision = false
        for i in 2:lastindex(breaks)
            b1 = breaks[i-1]
            b2 = breaks[i]
            isequal(b1, b2) && continue

            # `CUT_FMT` is only defined on Julia >= 1.10, so the version gate
            # must match the one guarding its definition
            @static if VERSION >= v"1.10"
                b1_str = Printf.format(CUT_FMT, sigdigits, b1)
                b2_str = Printf.format(CUT_FMT, sigdigits, b2)
            else
                fmt = Printf.Format("%.$(sigdigits)g")
                b1_str = Printf.format(fmt, b1)
                b2_str = Printf.format(fmt, b2)
            end
            if b1_str == b2_str
                sigdigits += 1
                collision = true
                break
            end
        end
    end
end
n = length(breaks)
n >= 2 || throw(ArgumentError("at least two breaks must be provided when extend is not true"))
if labels isa Function
from = breaks[1:n-1]
to = breaks[2:n]
firstlevel = labels(from[1], to[1], 1,
leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false)
local firstlevel
try
    firstlevel = labels(from[1], to[1], 1,
                        leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false,
                        sigdigits=sigdigits)
catch err
    # An unsupported keyword argument surfaces as a `MethodError`; any other
    # exception is a genuine failure of the `labels` function and must not be
    # hidden behind a deprecation warning and a second call.
    err isa MethodError || rethrow()
    # Support functions defined before v1.0, where sigdigits did not exist
    Base.depwarn("`labels` function is now required to accept a `sigdigits` keyword argument",
                 :cut)
    labels_orig = labels
    # Wrap the legacy function so that later calls can pass `sigdigits`, which
    # is simply dropped. Keywords are forwarded explicitly (no `; kw` shorthand).
    labels = (from, to, i; leftclosed, rightclosed, sigdigits) ->
        labels_orig(from, to, i; leftclosed=leftclosed, rightclosed=rightclosed)
    firstlevel = labels_orig(from[1], to[1], 1,
                             leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false)
end
levs = Vector{typeof(firstlevel)}(undef, n-1)
levs[1] = firstlevel
for i in 2:n-2
levs[i] = labels(from[i], to[i], i,
leftclosed=!isequal(breaks[i], breaks[i+1]), rightclosed=false)
leftclosed=!isequal(breaks[i], breaks[i+1]), rightclosed=false,
sigdigits=sigdigits)
end
levs[end] = labels(from[end], to[end], n-1,
leftclosed=true, rightclosed=true)
leftclosed=true, rightclosed=true,
sigdigits=sigdigits)
else
length(labels) == n-1 ||
throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
Expand All @@ -213,14 +314,6 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
CategoricalArray{S, N}(refs, pool)
end

"""
quantile_formatter(from, to, i; leftclosed, rightclosed)

Provide the default label format for the `cut(x, ngroups)` method.
"""
function quantile_formatter(from, to, i; leftclosed, rightclosed)
    # Label is "Q<i>: " followed by the interval, with brackets chosen from
    # which sides are closed
    opening = leftclosed ? "[" : "("
    closing = rightclosed ? "]" : ")"
    return string("Q", i, ": ", opening, from, ", ", to, closing)
end

"""
Find first value in (sorted) `v` which is greater than or equal to each quantile
in (sorted) `qs`.
Expand All @@ -247,6 +340,7 @@ end
"""
cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:AbstractString},Function},
sigdigits::Integer=3,
allowempty::Bool=false)

Cut a numeric array into `ngroups` quantiles.
Expand All @@ -257,19 +351,32 @@ but breaks are taken from actual data values instead of estimated quantiles.
If `x` contains `missing` values, they are automatically skipped when computing
quantiles.

!!! note
For floating point data, breaks may be rounded to `sigdigits` significant digits
when generating interval labels, meaning that they may not reflect exactly the cutpoints
used.

# Keyword arguments
* `labels::Union{AbstractVector, Function}`: a vector of strings, characters
or numbers giving the names to use for
the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates
or numbers giving the names to use for the intervals; or a function
`f(from, to, i::Integer; leftclosed::Bool, rightclosed::Bool, sigdigits::Integer)` that generates
the labels from the left and right interval boundaries and the group index. Defaults to
`"Qi: [from, to)"` (or `"Qi: [from, to]"` for the rightmost interval).
[`CategoricalArrays.default_formatter`](@ref), giving `"[from, to)"` (or `"[from, to]"`
for the rightmost interval) if `allowempty=false`, otherwise to
[`CategoricalArrays.numbered_formatter`](@ref), which prefixes the label with the quantile
number to ensure uniqueness.
* `sigdigits::Integer=3`: the minimum number of significant digits to use when rounding
breaks for inclusion in generated labels. This value is increased automatically if necessary
so that rounded breaks are unique. Only used for floating point types and when `labels` is a
function, in which case it is passed to it as a keyword argument.
* `allowempty::Bool=false`: when `false`, an error is raised if some quantile breakpoints
other than the last one are equal, generating empty intervals;
when `true`, duplicate breaks are allowed and the intervals they generate are kept as
unused levels (but duplicate labels are not allowed).
"""
function cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter,
labels::Union{AbstractVector{<:SupportedTypes},Function,Nothing}=nothing,
sigdigits::Integer=3,
allowempty::Bool=false)
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
Expand All @@ -286,5 +393,8 @@ function cut(x::AbstractArray, ngroups::Integer;
"Pass `allowempty=true` to allow empty quantiles or " *
"choose a lower value for `ngroups`."))
end
cut(x, breaks; labels=labels, allowempty=allowempty)
if labels === nothing
labels = allowempty ? numbered_formatter : default_formatter
end
return cut(x, breaks; labels=labels, sigdigits=sigdigits, allowempty=allowempty)
end
Loading
Loading