diff --git a/Cargo.lock b/Cargo.lock index e0837ea4adf4..142306052eb6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3050,6 +3050,7 @@ dependencies = [ "num-traits", "polars-arrow", "polars-compute", + "polars-dtype", "polars-error", "polars-row", "polars-schema", @@ -3081,6 +3082,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "polars-dtype" +version = "0.49.1" +dependencies = [ + "boxcar", + "hashbrown 0.15.4", + "polars-arrow", + "polars-error", + "polars-utils", + "schemars", + "serde", +] + [[package]] name = "polars-dylib" version = "0.49.1" @@ -3429,6 +3443,7 @@ dependencies = [ "bytemuck", "polars-arrow", "polars-compute", + "polars-dtype", "polars-error", "polars-utils", "proptest", @@ -4327,6 +4342,7 @@ dependencies = [ "schemars_derive", "serde", "serde_json", + "uuid", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 65c9344ff571..f08568fea747 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -102,6 +102,7 @@ zstd = "0.13" polars = { version = "0.49.1", path = "crates/polars", default-features = false } polars-compute = { version = "0.49.1", path = "crates/polars-compute", default-features = false } polars-core = { version = "0.49.1", path = "crates/polars-core", default-features = false } +polars-dtype = { version = "0.49.1", path = "crates/polars-dtype", default-features = false } polars-dylib = { version = "0.49.1", path = "crates/polars-dylib", default-features = false } polars-error = { version = "0.49.1", path = "crates/polars-error", default-features = false } polars-expr = { version = "0.49.1", path = "crates/polars-expr", default-features = false } diff --git a/crates/polars-arrow/src/bitmap/builder.rs b/crates/polars-arrow/src/bitmap/builder.rs index c32cecdca502..a3ac316299f6 100644 --- a/crates/polars-arrow/src/bitmap/builder.rs +++ b/crates/polars-arrow/src/bitmap/builder.rs @@ -75,6 +75,13 @@ impl BitmapBuilder { self.bit_cap = words_available * 64; } + pub fn clear(&mut self) { + self.buf = 0; + self.bit_len = 0; + self.set_bits_in_bytes = 0; + self.bytes.clear(); + } + #[inline(always)] pub fn push(&mut self, x: bool) { self.reserve(1); diff --git a/crates/polars-arrow/src/datatypes/field.rs b/crates/polars-arrow/src/datatypes/field.rs index eb4e60b0d0f7..87e9dc34d1a0 100644 --- a/crates/polars-arrow/src/datatypes/field.rs +++ b/crates/polars-arrow/src/datatypes/field.rs @@ -6,8 +6,15 @@ use serde::{Deserialize, Serialize}; use super::{ArrowDataType, Metadata}; -pub static DTYPE_ENUM_VALUES: &str = "_PL_ENUM_VALUES"; -pub static DTYPE_CATEGORICAL: &str = "_PL_CATEGORICAL"; +// These two have the same encoding, but because older versions of Polars +// were unable to read non-u32-key arrow dictionaries while _PL_ENUM_VALUES +// is set we switched to a new version. +pub static DTYPE_ENUM_VALUES_LEGACY: &str = "_PL_ENUM_VALUES"; +pub static DTYPE_ENUM_VALUES_NEW: &str = "_PL_ENUM_VALUES2"; + +// These have different encodings. +pub static DTYPE_CATEGORICAL_LEGACY: &str = "_PL_CATEGORICAL"; +pub static DTYPE_CATEGORICAL_NEW: &str = "_PL_CATEGORICAL2"; /// Represents Arrow's metadata of a "column". /// @@ -71,7 +78,7 @@ impl Field { pub fn is_enum(&self) -> bool { if let Some(md) = &self.metadata { - md.get(DTYPE_ENUM_VALUES).is_some() + md.get(DTYPE_ENUM_VALUES_LEGACY).is_some() || md.get(DTYPE_ENUM_VALUES_NEW).is_some() } else { false } @@ -79,7 +86,7 @@ impl Field { pub fn is_categorical(&self) -> bool { if let Some(md) = &self.metadata { - md.get(DTYPE_CATEGORICAL).is_some() + md.get(DTYPE_CATEGORICAL_LEGACY).is_some() || md.get(DTYPE_CATEGORICAL_NEW).is_some() } else { false } diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index 79fbf3ad8d4f..be6b8ed3c390 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -8,7 +8,10 @@ mod schema; use std::collections::BTreeMap; use std::sync::Arc; -pub use field::{DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, Field}; +pub use field::{ + DTYPE_CATEGORICAL_LEGACY, DTYPE_CATEGORICAL_NEW, DTYPE_ENUM_VALUES_LEGACY, + DTYPE_ENUM_VALUES_NEW, Field, +}; pub use physical_type::*; use polars_utils::pl_str::PlSmallStr; pub use schema::{ArrowSchema, ArrowSchemaRef}; diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 4a1dcf606dba..2d6c5bd1c39a 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -10,6 +10,7 @@ description = "Core of the Polars DataFrame library" [dependencies] polars-compute = { workspace = true, features = ["gather"] } +polars-dtype = { workspace = true } polars-error = { workspace = true } polars-row = { workspace = true } polars-schema = { workspace = true } @@ -33,7 +34,7 @@ rand = { workspace = true, optional = true, features = ["small_rng", "std"] } rand_distr = { workspace = true, optional = true } rayon = { workspace = true } regex = { workspace = true, optional = true } -schemars = { workspace = true, optional = true } +schemars = { workspace = true, optional = true, features = ["uuid1"] } # activate if you want serde support for Series and DataFrames serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } @@ -132,6 +133,7 @@ serde = [ "polars-schema/serde", "polars-utils/serde", "polars-compute/serde", + "polars-dtype/serde", "arrow/io_ipc", "arrow/io_ipc_compression", "serde_json", @@ -140,9 +142,10 @@ serde-lazy = ["serde", "arrow/serde", "indexmap/serde", "chrono/serde"] dsl-schema = [ "serde", "dep:schemars", + "polars-compute/dsl-schema", + "polars-dtype/dsl-schema", "polars-schema/dsl-schema", "polars-utils/dsl-schema", - "polars-compute/dsl-schema", ] docs-selection = [ diff --git a/crates/polars-core/src/chunked_array/builder/categorical.rs b/crates/polars-core/src/chunked_array/builder/categorical.rs new file mode 100644 index 000000000000..91a6e94f2701 --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/categorical.rs @@ -0,0 +1,82 @@ +use arrow::bitmap::BitmapBuilder; + +use crate::prelude::*; + +pub struct CategoricalChunkedBuilder { + name: PlSmallStr, + dtype: DataType, + mapping: Arc, + is_enum: bool, + cats: Vec, + validity: BitmapBuilder, +} + +impl CategoricalChunkedBuilder { + pub fn new(name: PlSmallStr, dtype: DataType) -> Self { + let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = &dtype else { + panic!("non-Categorical/Enum dtype in CategoricalChunkedbuilder") + }; + Self { + name, + mapping: mapping.clone(), + is_enum: matches!(dtype, DataType::Enum(_, _)), + dtype, + cats: Vec::new(), + validity: BitmapBuilder::new(), + } + } + + pub fn dtype(&self) -> &DataType { + &self.dtype + } + + pub fn reserve(&mut self, len: usize) { + self.cats.reserve(len); + self.validity.reserve(len); + } + + pub fn append_cat( + &mut self, + cat: CatSize, + mapping: &Arc, + ) -> PolarsResult<()> { + if Arc::ptr_eq(&self.mapping, mapping) { + self.cats.push(T::Native::from_cat(cat)); + self.validity.push(true); + } else if let Some(s) = mapping.cat_to_str(cat) { + self.append_str(s)?; + } else { + self.append_null(); + } + Ok(()) + } + + pub fn append_str(&mut self, val: &str) -> PolarsResult<()> { + let cat = if self.is_enum { + self.mapping.get_cat(val).ok_or_else(|| { + polars_err!(ComputeError: "attempted to insert '{val}' into Enum which does not contain this string") + })? + } else { + self.mapping.insert_cat(val)? + }; + self.cats.push(T::Native::from_cat(cat)); + self.validity.push(true); + Ok(()) + } + + pub fn append_null(&mut self) { + self.cats.push(T::Native::default()); + self.validity.push(false); + } + + pub fn finish(self) -> CategoricalChunked { + unsafe { + let phys = ChunkedArray::from_vec_validity( + self.name, + self.cats, + self.validity.into_opt_validity(), + ); + CategoricalChunked::from_cats_and_dtype_unchecked(phys, self.dtype) + } + } +} diff --git a/crates/polars-core/src/chunked_array/builder/list/anonymous.rs b/crates/polars-core/src/chunked_array/builder/list/anonymous.rs index eab5df634be2..376b3e5efa73 100644 --- a/crates/polars-core/src/chunked_array/builder/list/anonymous.rs +++ b/crates/polars-core/src/chunked_array/builder/list/anonymous.rs @@ -4,7 +4,7 @@ pub struct AnonymousListBuilder<'a> { name: PlSmallStr, builder: AnonymousBuilder<'a>, fast_explode: bool, - inner_dtype: DtypeMerger, + inner_dtype: Option, } impl Default for AnonymousListBuilder<'_> { @@ -19,7 +19,7 @@ impl<'a> AnonymousListBuilder<'a> { name, builder: AnonymousBuilder::new(capacity), fast_explode: true, - inner_dtype: DtypeMerger::new(inner_dtype), + inner_dtype, } } @@ -59,13 +59,18 @@ impl<'a> AnonymousListBuilder<'a> { } pub fn append_series(&mut self, s: &'a Series) -> PolarsResult<()> { - match s.dtype() { - // Empty arrays tend to be null type and thus differ - // if we would push it the concat would fail. - DataType::Null if s.is_empty() => self.append_empty(), - dt => self.inner_dtype.update(dt)?, + match (s.dtype(), &self.inner_dtype) { + (DataType::Null, _) => {}, + (dt, None) => self.inner_dtype = Some(dt.clone()), + (dt, Some(set_dt)) => { + polars_bail!(ComputeError: "dtypes don't match, got {dt}, expected: {set_dt}") + }, + } + if s.is_empty() { + self.append_empty(); + } else { + self.builder.push_multiple(s.chunks()); } - self.builder.push_multiple(s.chunks()); Ok(()) } @@ -76,19 +81,18 @@ impl<'a> AnonymousListBuilder<'a> { ListChunked::full_null_with_dtype( slf.name.clone(), 0, - &slf.inner_dtype.materialize().unwrap_or(DataType::Null), + &slf.inner_dtype.unwrap_or(DataType::Null), ) } else { - let inner_dtype = slf.inner_dtype.materialize(); - - let inner_dtype_physical = inner_dtype + let inner_dtype_physical = self + .inner_dtype .as_ref() .map(|dt| dt.to_physical().to_arrow(CompatLevel::newest())); let arr = slf.builder.finish(inner_dtype_physical.as_ref()).unwrap(); - let list_dtype_logical = match inner_dtype { + let list_dtype_logical = match &self.inner_dtype { None => DataType::from_arrow_dtype(arr.dtype()), - Some(dt) => DataType::List(Box::new(dt)), + Some(dt) => DataType::List(Box::new(dt.clone())), }; let mut ca = ListChunked::with_chunk(PlSmallStr::EMPTY, arr); @@ -105,7 +109,7 @@ pub struct AnonymousOwnedListBuilder { name: PlSmallStr, builder: AnonymousBuilder<'static>, owned: Vec, - inner_dtype: DtypeMerger, + inner_dtype: Option, fast_explode: bool, } @@ -117,11 +121,17 @@ impl Default for AnonymousOwnedListBuilder { impl ListBuilderTrait for AnonymousOwnedListBuilder { fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + match (s.dtype(), &self.inner_dtype) { + (DataType::Null, _) => {}, + (dt, None) => self.inner_dtype = Some(dt.clone()), + (dt, Some(set_dt)) => { + polars_ensure!(dt == set_dt, ComputeError: "dtypes don't match, got {dt}, expected: {set_dt}") + }, + } if s.is_empty() { self.append_empty(); } else { unsafe { - self.inner_dtype.update(s.dtype())?; self.builder .push_multiple(&*(s.chunks().as_ref() as *const [ArrayRef])); } @@ -138,7 +148,7 @@ impl ListBuilderTrait for AnonymousOwnedListBuilder { } fn finish(&mut self) -> ListChunked { - let inner_dtype = std::mem::take(&mut self.inner_dtype).materialize(); + let inner_dtype = std::mem::take(&mut self.inner_dtype); // Don't use self from here on out. let slf = std::mem::take(self); let inner_dtype_physical = inner_dtype @@ -166,7 +176,7 @@ impl AnonymousOwnedListBuilder { name, builder: AnonymousBuilder::new(capacity), owned: Vec::with_capacity(capacity), - inner_dtype: DtypeMerger::new(inner_dtype), + inner_dtype, fast_explode: true, } } diff --git a/crates/polars-core/src/chunked_array/builder/list/categorical.rs b/crates/polars-core/src/chunked_array/builder/list/categorical.rs deleted file mode 100644 index 1175b669bf06..000000000000 --- a/crates/polars-core/src/chunked_array/builder/list/categorical.rs +++ /dev/null @@ -1,235 +0,0 @@ -use std::hash::BuildHasher; - -use hashbrown::HashTable; -use hashbrown::hash_table::Entry; - -use super::*; - -pub fn create_categorical_chunked_listbuilder( - name: PlSmallStr, - ordering: CategoricalOrdering, - capacity: usize, - values_capacity: usize, - rev_map: Arc, -) -> Box { - match &*rev_map { - RevMapping::Local(_, h) => Box::new(ListLocalCategoricalChunkedBuilder::new( - name, - ordering, - capacity, - values_capacity, - *h, - )), - RevMapping::Global(_, _, _) => Box::new(ListGlobalCategoricalChunkedBuilder::new( - name, - ordering, - capacity, - values_capacity, - rev_map, - )), - } -} - -pub struct ListEnumCategoricalChunkedBuilder { - inner: ListPrimitiveChunkedBuilder, - ordering: CategoricalOrdering, - rev_map: RevMapping, -} - -impl ListEnumCategoricalChunkedBuilder { - pub(super) fn new( - name: PlSmallStr, - ordering: CategoricalOrdering, - capacity: usize, - values_capacity: usize, - rev_map: RevMapping, - ) -> Self { - Self { - inner: ListPrimitiveChunkedBuilder::new( - name, - capacity, - values_capacity, - DataType::UInt32, - ), - ordering, - rev_map, - } - } -} - -impl ListBuilderTrait for ListEnumCategoricalChunkedBuilder { - fn append_series(&mut self, s: &Series) -> PolarsResult<()> { - let DataType::Enum(Some(rev_map), _) = s.dtype() else { - polars_bail!(ComputeError: "expected enum type") - }; - polars_ensure!(rev_map.same_src(&self.rev_map),ComputeError: "incompatible enum types"); - self.inner.append_series(s) - } - - fn append_null(&mut self) { - self.inner.append_null() - } - - fn finish(&mut self) -> ListChunked { - let inner_dtype = DataType::Enum(Some(Arc::new(self.rev_map.clone())), self.ordering); - let mut ca = self.inner.finish(); - unsafe { ca.set_dtype(DataType::List(Box::new(inner_dtype))) } - ca - } -} - -struct ListLocalCategoricalChunkedBuilder { - inner: ListPrimitiveChunkedBuilder, - idx_lookup: HashTable, - ordering: CategoricalOrdering, - categories: MutablePlString, - categories_hash: u128, -} - -impl ListLocalCategoricalChunkedBuilder { - #[inline] - pub fn get_hash_builder() -> PlFixedStateQuality { - PlFixedStateQuality::with_seed(0) - } - - pub(super) fn new( - name: PlSmallStr, - ordering: CategoricalOrdering, - capacity: usize, - values_capacity: usize, - hash: u128, - ) -> Self { - Self { - inner: ListPrimitiveChunkedBuilder::new( - name, - capacity, - values_capacity, - DataType::UInt32, - ), - idx_lookup: HashTable::with_capacity(capacity), - ordering, - categories: MutablePlString::with_capacity(capacity), - categories_hash: hash, - } - } -} - -impl ListBuilderTrait for ListLocalCategoricalChunkedBuilder { - fn append_series(&mut self, s: &Series) -> PolarsResult<()> { - let DataType::Categorical(Some(rev_map), _) = s.dtype() else { - polars_bail!(ComputeError: "expected categorical type") - }; - let RevMapping::Local(cats_right, new_hash) = &**rev_map else { - polars_bail!(string_cache_mismatch) - }; - let ca = s.categorical().unwrap(); - - // Fast path rev_maps are compatible & lookup is initialized - if self.categories_hash == *new_hash && !self.idx_lookup.is_empty() { - return self.inner.append_series(s); - } - - let hash_builder = ListLocalCategoricalChunkedBuilder::get_hash_builder(); - - // Map the physical of the appended series to be compatible with the existing rev map - let mut idx_mapping = PlHashMap::with_capacity(ca.len()); - - for (idx, cat) in cats_right.values_iter().enumerate() { - let hash_cat = hash_builder.hash_one(cat); - let len = self.idx_lookup.len(); - - // Custom hashing / equality functions for comparing the &str to the idx - // SAFETY: index in hashmap are within bounds of categories - unsafe { - let r = self.idx_lookup.entry( - hash_cat, - |k| self.categories.value_unchecked(*k as usize) == cat, - |k| hash_builder.hash_one(self.categories.value_unchecked(*k as usize)), - ); - - match r { - Entry::Occupied(v) => { - // SAFETY: bucket is initialized. - idx_mapping.insert_unique_unchecked(idx as u32, *v.get()); - }, - Entry::Vacant(slot) => { - idx_mapping.insert_unique_unchecked(idx as u32, len as u32); - self.categories.push(Some(cat)); - slot.insert(len as u32); - }, - } - } - } - - let op = |opt_v: Option<&u32>| opt_v.map(|v| *idx_mapping.get(v).unwrap()); - // SAFETY: length is correct as we do one-one mapping over ca. - let iter = unsafe { - ca.physical() - .downcast_iter() - .flat_map(|arr| arr.iter().map(op)) - .trust_my_length(ca.len()) - }; - self.inner.append_iter(iter); - - Ok(()) - } - - fn append_null(&mut self) { - self.inner.append_null() - } - - fn finish(&mut self) -> ListChunked { - let categories: Utf8ViewArray = std::mem::take(&mut self.categories).into(); - let rev_map = RevMapping::build_local(categories); - let inner_dtype = DataType::Categorical(Some(Arc::new(rev_map)), self.ordering); - let mut ca = self.inner.finish(); - unsafe { ca.set_dtype(DataType::List(Box::new(inner_dtype))) } - ca - } -} - -struct ListGlobalCategoricalChunkedBuilder { - inner: ListPrimitiveChunkedBuilder, - ordering: CategoricalOrdering, - map_merger: GlobalRevMapMerger, -} - -impl ListGlobalCategoricalChunkedBuilder { - pub(super) fn new( - name: PlSmallStr, - ordering: CategoricalOrdering, - capacity: usize, - values_capacity: usize, - rev_map: Arc, - ) -> Self { - let inner = - ListPrimitiveChunkedBuilder::new(name, capacity, values_capacity, DataType::UInt32); - Self { - inner, - ordering, - map_merger: GlobalRevMapMerger::new(rev_map), - } - } -} - -impl ListBuilderTrait for ListGlobalCategoricalChunkedBuilder { - fn append_series(&mut self, s: &Series) -> PolarsResult<()> { - let DataType::Categorical(Some(rev_map), _) = s.dtype() else { - polars_bail!(ComputeError: "expected categorical type") - }; - self.map_merger.merge_map(rev_map)?; - self.inner.append_series(s) - } - - fn append_null(&mut self) { - self.inner.append_null() - } - - fn finish(&mut self) -> ListChunked { - let rev_map = std::mem::take(&mut self.map_merger).finish(); - let inner_dtype = DataType::Categorical(Some(rev_map), self.ordering); - let mut ca = self.inner.finish(); - unsafe { ca.set_dtype(DataType::List(Box::new(inner_dtype))) } - ca - } -} diff --git a/crates/polars-core/src/chunked_array/builder/list/dtypes.rs b/crates/polars-core/src/chunked_array/builder/list/dtypes.rs deleted file mode 100644 index 9808ae0841ca..000000000000 --- a/crates/polars-core/src/chunked_array/builder/list/dtypes.rs +++ /dev/null @@ -1,56 +0,0 @@ -use super::*; - -// Allow large enum as this shouldn't be moved much -#[allow(clippy::large_enum_variant)] -pub(super) enum DtypeMerger { - #[cfg(feature = "dtype-categorical")] - Categorical(GlobalRevMapMerger, CategoricalOrdering), - Other(Option), -} - -impl Default for DtypeMerger { - fn default() -> Self { - DtypeMerger::Other(None) - } -} - -impl DtypeMerger { - pub(super) fn new(dtype: Option) -> Self { - match dtype { - #[cfg(feature = "dtype-categorical")] - Some(DataType::Categorical(Some(rev_map), ordering)) if rev_map.is_global() => { - DtypeMerger::Categorical(GlobalRevMapMerger::new(rev_map), ordering) - }, - _ => DtypeMerger::Other(dtype), - } - } - - #[inline] - pub(super) fn update(&mut self, dtype: &DataType) -> PolarsResult<()> { - match self { - #[cfg(feature = "dtype-categorical")] - DtypeMerger::Categorical(merger, _) => { - let DataType::Categorical(Some(rev_map), _) = dtype else { - polars_bail!(ComputeError: "expected categorical rev-map") - }; - polars_ensure!(rev_map.is_global(), string_cache_mismatch); - return merger.merge_map(rev_map); - }, - DtypeMerger::Other(Some(set_dtype)) => { - polars_ensure!(set_dtype == dtype, ComputeError: "dtypes don't match, got {}, expected: {}", dtype, set_dtype) - }, - _ => {}, - } - Ok(()) - } - - pub(super) fn materialize(self) -> Option { - match self { - #[cfg(feature = "dtype-categorical")] - DtypeMerger::Categorical(merger, ordering) => { - Some(DataType::Categorical(Some(merger.finish()), ordering)) - }, - DtypeMerger::Other(dtype) => dtype, - } - } -} diff --git a/crates/polars-core/src/chunked_array/builder/list/mod.rs b/crates/polars-core/src/chunked_array/builder/list/mod.rs index c5eb91834a53..8bc8c62b93d0 100644 --- a/crates/polars-core/src/chunked_array/builder/list/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/list/mod.rs @@ -1,9 +1,6 @@ mod anonymous; mod binary; mod boolean; -#[cfg(feature = "dtype-categorical")] -mod categorical; -mod dtypes; mod null; mod primitive; @@ -12,9 +9,6 @@ use arrow::legacy::array::list::AnonymousBuilder; use arrow::legacy::array::null::MutableNullArray; pub use binary::*; pub use boolean::*; -#[cfg(feature = "dtype-categorical")] -use categorical::*; -use dtypes::*; pub use null::*; pub use primitive::*; @@ -88,31 +82,6 @@ pub fn get_list_builder( list_capacity: usize, name: PlSmallStr, ) -> Box { - match inner_type_logical { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(Some(rev_map), ordering) => { - return create_categorical_chunked_listbuilder( - name, - *ordering, - list_capacity, - value_capacity, - rev_map.clone(), - ); - }, - #[cfg(feature = "dtype-categorical")] - DataType::Enum(Some(rev_map), ordering) => { - let list_builder = ListEnumCategoricalChunkedBuilder::new( - name, - *ordering, - list_capacity, - value_capacity, - (**rev_map).clone(), - ); - return Box::new(list_builder); - }, - _ => {}, - } - let physical_type = inner_type_logical.to_physical(); match &physical_type { diff --git a/crates/polars-core/src/chunked_array/builder/mod.rs b/crates/polars-core/src/chunked_array/builder/mod.rs index 36065b2ac5f8..98fb021d9ac1 100644 --- a/crates/polars-core/src/chunked_array/builder/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/mod.rs @@ -1,4 +1,6 @@ mod boolean; +#[cfg(feature = "dtype-categorical")] +mod categorical; #[cfg(feature = "dtype-array")] pub mod fixed_size_list; pub mod list; @@ -11,6 +13,8 @@ use std::sync::Arc; use arrow::array::*; use arrow::bitmap::Bitmap; pub use boolean::*; +#[cfg(feature = "dtype-categorical")] +pub use categorical::*; #[cfg(feature = "dtype-array")] pub(crate) use fixed_size_list::*; pub use list::*; diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index ac9bda1ec195..4976ff931693 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -10,6 +10,7 @@ use super::flags::StatisticsFlags; #[cfg(feature = "dtype-datetime")] use crate::prelude::DataType::Datetime; use crate::prelude::*; +use crate::utils::handle_casting_failures; #[derive(Copy, Clone, Debug, Default, PartialEq, Hash, Eq)] #[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] @@ -169,57 +170,33 @@ where return Ok(out); } match dtype { + // LEGACY + // TODO @ cat-rework: remove after exposing to/from physical functions. #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, ordering) => { - polars_ensure!( - self.dtype() == &DataType::UInt32, - ComputeError: "cannot cast numeric types to 'Categorical'" - ); - // SAFETY: - // we are guarded by the type system - let ca = unsafe { &*(self as *const ChunkedArray as *const UInt32Chunked) }; - - CategoricalChunked::from_global_indices(ca.clone(), *ordering) - .map(|ca| ca.into_series()) + DataType::Categorical(cats, _mapping) => { + let s = self.cast_with_options(&cats.physical().dtype(), options)?; + with_match_categorical_physical_type!(cats.physical(), |$C| { + // SAFETY: we are guarded by the type system. + type PhysCa = ChunkedArray<<$C as PolarsCategoricalType>::PolarsPhysical>; + let ca: &PhysCa = s.as_ref().as_ref(); + Ok(CategoricalChunked::<$C>::from_cats_and_dtype(ca.clone(), dtype.clone()) + .into_series()) + }) }, + + // LEGACY + // TODO @ cat-rework: remove after exposing to/from physical functions. #[cfg(feature = "dtype-categorical")] - DataType::Enum(rev_map, ordering) => { - let ca = match self.dtype() { - DataType::UInt32 => { - // SAFETY: we are guarded by the type system - unsafe { &*(self as *const ChunkedArray as *const UInt32Chunked) } - .clone() - }, - dt if dt.is_integer() => self - .cast_with_options(self.dtype(), options)? - .strict_cast(&DataType::UInt32)? - .u32()? - .clone(), - _ => { - polars_bail!(ComputeError: "cannot cast non integer types to 'Enum'") - }, - }; - let Some(rev_map) = rev_map else { - polars_bail!(ComputeError: "cannot cast to Enum without categories"); - }; - let categories = rev_map.get_categories(); - // Check if indices are in bounds - if let Some(m) = ChunkAgg::max(&ca) { - if m >= categories.len() as u32 { - polars_bail!(OutOfBounds: "index {} is bigger than the number of categories {}",m,categories.len()); - } - } - // SAFETY: indices are in bound - unsafe { - Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( - ca.clone(), - rev_map.clone(), - true, - *ordering, - ) - .into_series()) - } + DataType::Enum(fcats, _mapping) => { + let s = self.cast_with_options(&fcats.physical().dtype(), options)?; + with_match_categorical_physical_type!(fcats.physical(), |$C| { + // SAFETY: we are guarded by the type system. + type PhysCa = ChunkedArray<<$C as PolarsCategoricalType>::PolarsPhysical>; + let ca: &PhysCa = s.as_ref().as_ref(); + Ok(CategoricalChunked::<$C>::from_cats_and_dtype(ca.clone(), dtype.clone()).into_series()) + }) }, + #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => { cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) @@ -258,26 +235,33 @@ where unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { match dtype { + // LEGACY + // TODO @ cat-rework: remove after exposing to/from physical functions. #[cfg(feature = "dtype-categorical")] - DataType::Categorical(Some(rev_map), ordering) - | DataType::Enum(Some(rev_map), ordering) => { - if self.dtype() == &DataType::UInt32 { - // SAFETY: - // we are guarded by the type system. - let ca = unsafe { &*(self as *const ChunkedArray as *const UInt32Chunked) }; - Ok(unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - ca.clone(), - rev_map.clone(), - matches!(dtype, DataType::Enum(_, _)), - *ordering, - ) - } - .into_series()) - } else { - polars_bail!(ComputeError: "cannot cast numeric types to 'Categorical'"); - } + DataType::Categorical(cats, _mapping) => { + polars_ensure!(self.dtype() == &cats.physical().dtype(), ComputeError: "cannot cast numeric types to 'Categorical'"); + with_match_categorical_physical_type!(cats.physical(), |$C| { + // SAFETY: we are guarded by the type system. + type PhysCa = ChunkedArray<<$C as PolarsCategoricalType>::PolarsPhysical>; + let ca = unsafe { &*(self as *const ChunkedArray as *const PhysCa) }; + Ok(CategoricalChunked::<$C>::from_cats_and_dtype_unchecked(ca.clone(), dtype.clone()) + .into_series()) + }) + }, + + // LEGACY + // TODO @ cat-rework: remove after exposing to/from physical functions. + #[cfg(feature = "dtype-categorical")] + DataType::Enum(fcats, _mapping) => { + polars_ensure!(self.dtype() == &fcats.physical().dtype(), ComputeError: "cannot cast numeric types to 'Enum'"); + with_match_categorical_physical_type!(fcats.physical(), |$C| { + // SAFETY: we are guarded by the type system. + type PhysCa = ChunkedArray<<$C as PolarsCategoricalType>::PolarsPhysical>; + let ca = unsafe { &*(self as *const ChunkedArray as *const PhysCa) }; + Ok(CategoricalChunked::<$C>::from_cats_and_dtype_unchecked(ca.clone(), dtype.clone()).into_series()) + }) }, + _ => self.cast_impl(dtype, CastOptions::Overflowing), } } @@ -287,31 +271,24 @@ impl ChunkCast for StringChunked { fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { match dtype { #[cfg(feature = "dtype-categorical")] - DataType::Categorical(rev_map, ordering) => match rev_map { - None => { - // SAFETY: length is correct - let iter = - unsafe { self.downcast_iter().flatten().trust_my_length(self.len()) }; - let builder = - CategoricalChunkedBuilder::new(self.name().clone(), self.len(), *ordering); - let ca = builder.drain_iter_and_finish(iter); - Ok(ca.into_series()) - }, - Some(_) => { - polars_bail!(InvalidOperation: "casting to a categorical with rev map is not allowed"); - }, + DataType::Categorical(cats, _mapping) => { + with_match_categorical_physical_type!(cats.physical(), |$C| { + Ok(CategoricalChunked::<$C>::from_str_iter(self.name().clone(), dtype.clone(), self.iter())? + .into_series()) + }) }, #[cfg(feature = "dtype-categorical")] - DataType::Enum(rev_map, ordering) => { - let Some(rev_map) = rev_map else { - polars_bail!(InvalidOperation: "cannot cast / initialize Enum without categories present") - }; - CategoricalChunked::from_string_to_enum(self, rev_map.get_categories(), *ordering) - .map(|ca| { - let mut s = ca.into_series(); - s.rename(self.name().clone()); - s - }) + DataType::Enum(fcats, _mapping) => { + let ret = with_match_categorical_physical_type!(fcats.physical(), |$C| { + CategoricalChunked::<$C>::from_str_iter(self.name().clone(), dtype.clone(), self.iter())? + .into_series() + }); + + if options.is_strict() && self.null_count() != ret.null_count() { + handle_casting_failures(&self.clone().into_series(), &ret)?; + } + + Ok(ret) }, #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => { @@ -473,8 +450,9 @@ impl ChunkCast for ListChunked { List(child_type) => { match (ca.inner_dtype(), &**child_type) { (old, new) if old == new => Ok(ca.into_owned().into_series()), + // TODO @ cat-rework: can we implement this now? #[cfg(feature = "dtype-categorical")] - (dt, Categorical(None, _) | Enum(_, _)) + (dt, Categorical(_, _) | Enum(_, _)) if !matches!(dt, Categorical(_, _) | Enum(_, _) | String | Null) => { polars_bail!(InvalidOperation: "cannot cast List inner type: '{:?}' to Categorical", dt) @@ -498,6 +476,7 @@ impl ChunkCast for ListChunked { Array(child_type, width) => { let physical_type = dtype.to_physical(); + // TODO @ cat-rework: can we implement this now? // TODO!: properly implement this recursively. #[cfg(feature = "dtype-categorical")] polars_ensure!(!matches!(&**child_type, Categorical(_, _)), InvalidOperation: "array of categorical is not yet supported"); @@ -572,8 +551,9 @@ impl ChunkCast for ArrayChunked { match (ca.inner_dtype(), &**child_type) { (old, new) if old == new => Ok(ca.into_owned().into_series()), + // TODO @ cat-rework: can we implement this now? #[cfg(feature = "dtype-categorical")] - (dt, Categorical(None, _) | Enum(_, _)) if !matches!(dt, String) => { + (dt, Categorical(_, _) | Enum(_, _)) if !matches!(dt, String) => { polars_bail!(InvalidOperation: "cannot cast Array inner type: '{:?}' to dtype: {:?}", dt, child_type) }, _ => { @@ -748,15 +728,14 @@ mod test { fn test_cast_noop() { // check if we can cast categorical twice without panic let ca = StringChunked::new(PlSmallStr::from_static("foo"), &["bar", "ham"]); + let cats = Categories::global(); let out = ca .cast_with_options( - &DataType::Categorical(None, Default::default()), + &DataType::from_categories(cats.clone()), CastOptions::Strict, ) .unwrap(); - let out = out - .cast(&DataType::Categorical(None, Default::default())) - .unwrap(); + let out = out.cast(&DataType::from_categories(cats)).unwrap(); assert!(matches!(out.dtype(), &DataType::Categorical(_, _))) } } diff --git a/crates/polars-core/src/chunked_array/comparison/categorical.rs b/crates/polars-core/src/chunked_array/comparison/categorical.rs index 09573c5fbd32..b6d64e23d959 100644 --- a/crates/polars-core/src/chunked_array/comparison/categorical.rs +++ b/crates/polars-core/src/chunked_array/comparison/categorical.rs @@ -1,249 +1,339 @@ -use arrow::bitmap::Bitmap; -use arrow::legacy::utils::FromTrustedLenIterator; -use polars_compute::comparisons::TotalOrdKernel; - -use crate::chunked_array::cast::CastOptions; -use crate::prelude::nulls::replace_non_null; use crate::prelude::*; -#[cfg(feature = "dtype-categorical")] -fn cat_equality_helper<'a, Compare, Missing>( - lhs: &'a CategoricalChunked, - rhs: &'a CategoricalChunked, - missing_function: Missing, - compare_function: Compare, +fn str_to_cat_enum(map: &CategoricalMapping, s: &str) -> PolarsResult { + map.get_cat(s).ok_or_else(|| polars_err!(InvalidOperation: "conversion from `str` to `enum` failed for value \"{s}\"")) +} + +fn cat_equality_helper( + lhs: &CategoricalChunked, + rhs: &CategoricalChunked, + eq_phys: EqPhys, ) -> PolarsResult where - Compare: Fn(&'a UInt32Chunked, &'a UInt32Chunked) -> BooleanChunked, - Missing: Fn(&'a CategoricalChunked) -> BooleanChunked, + EqPhys: + Fn(&ChunkedArray, &ChunkedArray) -> BooleanChunked, { - let rev_map_l = lhs.get_rev_map(); - polars_ensure!(rev_map_l.same_src(rhs.get_rev_map()), string_cache_mismatch); - let rhs = rhs.physical(); - - // Fast path for globals - if rhs.len() == 1 && rhs.null_count() == 0 { - let rhs = rhs.get(0).unwrap(); - if rev_map_l.get_optional(rhs).is_none() { - return Ok(missing_function(lhs)); - } - } - Ok(compare_function(lhs.physical(), rhs)) + lhs.dtype().matches_schema_type(rhs.dtype())?; + Ok(eq_phys(lhs.physical(), rhs.physical())) } -fn cat_compare_helper<'a, Compare, CompareString>( - lhs: &'a CategoricalChunked, - rhs: &'a CategoricalChunked, - compare_function: Compare, - compare_str_function: CompareString, +fn cat_compare_helper( + lhs: &CategoricalChunked, + rhs: &CategoricalChunked, + cmp: Cmp, + cmp_phys: CmpPhys, ) -> PolarsResult where - Compare: Fn(&'a UInt32Chunked, &'a UInt32Chunked) -> BooleanChunked, - CompareString: Fn(&str, &str) -> bool, + Cmp: Fn(&str, &str) -> bool, + CmpPhys: + Fn(&ChunkedArray, &ChunkedArray) -> BooleanChunked, { - let rev_map_l = lhs.get_rev_map(); - let rev_map_r = rhs.get_rev_map(); - polars_ensure!(rev_map_l.same_src(rev_map_r), ComputeError: "can only compare categoricals of the same type with the same categories"); - - if lhs.is_enum() || !lhs.uses_lexical_ordering() { - Ok(compare_function(lhs.physical(), rhs.physical())) - } else { - match (lhs.len(), rhs.len()) { - (lhs_len, 1) => { - // SAFETY: physical is in range of revmap - let v = unsafe { - rhs.physical() - .get(0) - .map(|phys| rev_map_r.get_unchecked(phys)) - }; - let Some(v) = v else { - return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len)); - }; - - Ok(lhs - .iter_str() - .map(|opt_s| opt_s.map(|s| compare_str_function(s, v))) - .collect_ca_trusted(lhs.name().clone())) - }, - (1, rhs_len) => { - // SAFETY: physical is in range of revmap - let v = unsafe { - lhs.physical() - .get(0) - .map(|phys| rev_map_l.get_unchecked(phys)) - }; - let Some(v) = v else { - return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len)); - }; - Ok(rhs - .iter_str() - .map(|opt_s| opt_s.map(|s| compare_str_function(v, s))) - .collect_ca_trusted(lhs.name().clone())) - }, - (lhs_len, rhs_len) if lhs_len == rhs_len => Ok(lhs + lhs.dtype().matches_schema_type(rhs.dtype())?; + if lhs.is_enum() { + return Ok(cmp_phys(lhs.physical(), rhs.physical())); + } + let mapping = lhs.get_mapping(); + match (lhs.len(), rhs.len()) { + (lhs_len, 1) => { + let Some(cat) = rhs.physical().get(0) else { + return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len)); + }; + + // SAFETY: physical is in range of the mapping. + let v = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) }; + Ok(lhs + .iter_str() + .map(|opt_s| opt_s.map(|s| cmp(s, v))) + .collect_ca_trusted(lhs.name().clone())) + }, + (1, rhs_len) => { + let Some(cat) = lhs.physical().get(0) else { + return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len)); + }; + + // SAFETY: physical is in range of the mapping. + let v = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) }; + Ok(rhs + .iter_str() + .map(|opt_s| opt_s.map(|s| cmp(v, s))) + .collect_ca_trusted(lhs.name().clone())) + }, + (lhs_len, rhs_len) => { + assert!(lhs_len == rhs_len); + Ok(lhs .iter_str() .zip(rhs.iter_str()) .map(|(l, r)| match (l, r) { (None, _) => None, (_, None) => None, - (Some(l), Some(r)) => Some(compare_str_function(l, r)), + (Some(l), Some(r)) => Some(cmp(l, r)), }) - .collect_ca_trusted(lhs.name().clone())), - (lhs_len, rhs_len) => { - polars_bail!(ComputeError: "Columns are of unequal length: {} vs {}",lhs_len,rhs_len) - }, - } + .collect_ca_trusted(lhs.name().clone())) + }, } } -impl ChunkCompareEq<&CategoricalChunked> for CategoricalChunked { - type Item = PolarsResult; - - fn equal(&self, rhs: &CategoricalChunked) -> Self::Item { - cat_equality_helper( - self, - rhs, - |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false), - UInt32Chunked::equal, - ) +fn cat_str_equality_helper( + lhs: &CategoricalChunked, + rhs: &StringChunked, + eq: Eq, + eq_phys_scalar: EqPhysScalar, + eq_str_scalar: EqStrScalar, +) -> BooleanChunked +where + Eq: Fn(Option<&str>, Option<&str>) -> Option, + EqPhysScalar: Fn(&ChunkedArray, T::Native) -> BooleanChunked, + EqStrScalar: Fn(&StringChunked, &str) -> BooleanChunked, +{ + let mapping = lhs.get_mapping(); + let null_eq = eq(None, None); + match (lhs.len(), rhs.len()) { + (lhs_len, 1) => { + let Some(s) = rhs.get(0) else { + return match null_eq { + Some(true) => lhs.physical().is_null(), + Some(false) => lhs.physical().is_not_null(), + None => BooleanChunked::full_null(lhs.name().clone(), lhs_len), + }; + }; + + cat_str_scalar_equality_helper(lhs, s, null_eq, &eq_phys_scalar) + }, + (1, rhs_len) => { + let Some(cat) = lhs.physical().get(0) else { + return match null_eq { + Some(true) => rhs.is_null().with_name(lhs.name().clone()), + Some(false) => rhs.is_not_null().with_name(lhs.name().clone()), + None => BooleanChunked::full_null(lhs.name().clone(), rhs_len), + }; + }; + + // SAFETY: physical is in range of the mapping. + let s = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) }; + eq_str_scalar(rhs, s).with_name(lhs.name().clone()) + }, + (lhs_len, rhs_len) => { + assert!(lhs_len == rhs_len); + lhs.iter_str() + .zip(rhs.iter()) + .map(|(l, r)| eq(l, r)) + .collect_ca_trusted(lhs.name().clone()) + }, } +} - fn equal_missing(&self, rhs: &CategoricalChunked) -> Self::Item { - cat_equality_helper( - self, - rhs, - |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()), - UInt32Chunked::equal_missing, - ) +fn cat_str_compare_helper( + lhs: &CategoricalChunked, + rhs: &StringChunked, + cmp: Cmp, + cmp_str_scalar: CmpStrScalar, +) -> BooleanChunked +where + Cmp: Fn(&str, &str) -> bool, + CmpStrScalar: Fn(&str, &StringChunked) -> BooleanChunked, +{ + let mapping = lhs.get_mapping(); + match (lhs.len(), rhs.len()) { + (lhs_len, 1) => { + let Some(s) = rhs.get(0) else { + return BooleanChunked::full_null(lhs.name().clone(), lhs_len); + }; + cat_str_scalar_compare_helper(lhs, s, cmp) + }, + (1, rhs_len) => { + let Some(cat) = lhs.physical().get(0) else { + return BooleanChunked::full_null(lhs.name().clone(), rhs_len); + }; + + // SAFETY: physical is in range of the mapping. + let s = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) }; + cmp_str_scalar(s, rhs).with_name(lhs.name().clone()) + }, + (lhs_len, rhs_len) => { + assert!(lhs_len == rhs_len); + lhs.iter_str() + .zip(rhs.iter()) + .map(|(l, r)| match (l, r) { + (None, _) => None, + (_, None) => None, + (Some(l), Some(r)) => Some(cmp(l, r)), + }) + .collect_ca_trusted(lhs.name().clone()) + }, } +} - fn not_equal(&self, rhs: &CategoricalChunked) -> Self::Item { - cat_equality_helper( - self, - rhs, - |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true), - UInt32Chunked::not_equal, - ) +fn cat_str_phys_compare_helper( + lhs: &CategoricalChunked, + rhs: &StringChunked, + cmp: Cmp, +) -> PolarsResult +where + Cmp: Fn(T::Native, T::Native) -> bool, +{ + let mapping = lhs.get_mapping(); + match (lhs.len(), rhs.len()) { + (lhs_len, 1) => { + let Some(s) = rhs.get(0) else { + return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len)); + }; + cat_str_scalar_phys_compare_helper(lhs, s, cmp) + }, + (1, rhs_len) => { + let Some(cat) = lhs.physical().get(0) else { + return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len)); + }; + + rhs.iter() + .map(|opt_r| { + if let Some(r) = opt_r { + let r = T::Native::from_cat(str_to_cat_enum(mapping, r)?); + Ok(Some(cmp(cat, r))) + } else { + Ok(None) + } + }) + .try_collect_ca_trusted(lhs.name().clone()) + }, + (lhs_len, rhs_len) => { + assert!(lhs_len == rhs_len); + lhs.physical() + .iter() + .zip(rhs.iter()) + .map(|(l, r)| match (l, r) { + (None, _) => Ok(None), + (_, None) => Ok(None), + (Some(l), Some(r)) => { + let r = T::Native::from_cat(str_to_cat_enum(mapping, r)?); + Ok(Some(cmp(l, r))) + }, + }) + .try_collect_ca_trusted(lhs.name().clone()) + }, } +} - fn not_equal_missing(&self, rhs: &CategoricalChunked) -> Self::Item { - cat_equality_helper( - self, - rhs, - |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()), - UInt32Chunked::not_equal_missing, - ) - } +fn cat_str_scalar_equality_helper( + lhs: &CategoricalChunked, + rhs: &str, + null_eq: Option, + eq_phys_scalar: EqPhysScalar, +) -> BooleanChunked +where + EqPhysScalar: Fn(&ChunkedArray, T::Native) -> BooleanChunked, +{ + let mapping = lhs.get_mapping(); + let Some(cat) = mapping.get_cat(rhs) else { + return match null_eq { + Some(true) => lhs.physical().is_null(), + Some(false) => lhs.physical().is_not_null(), + None => BooleanChunked::full_null(lhs.name().clone(), lhs.len()), + }; + }; + + eq_phys_scalar(lhs.physical(), T::Native::from_cat(cat)) } -impl ChunkCompareIneq<&CategoricalChunked> for CategoricalChunked { +fn cat_str_scalar_compare_helper( + lhs: &CategoricalChunked, + rhs: &str, + cmp: Cmp, +) -> BooleanChunked +where + Cmp: Fn(&str, &str) -> bool, +{ + lhs.iter_str() + .map(|opt_l| opt_l.map(|l| cmp(l, rhs))) + .collect_ca_trusted(lhs.name().clone()) +} + +fn cat_str_scalar_phys_compare_helper( + lhs: &CategoricalChunked, + rhs: &str, + cmp: Cmp, +) -> PolarsResult +where + Cmp: Fn(T::Native, T::Native) -> bool, +{ + let r = T::Native::from_cat(str_to_cat_enum(lhs.get_mapping(), rhs)?); + Ok(lhs + .physical() + .iter() + .map(|opt_l| opt_l.map(|l| cmp(l, r))) + .collect_ca_trusted(lhs.name().clone())) +} + +impl ChunkCompareEq<&CategoricalChunked> for CategoricalChunked +where + ChunkedArray: + for<'a> ChunkCompareEq<&'a ChunkedArray, Item = BooleanChunked>, +{ type Item = PolarsResult; - fn gt(&self, rhs: &CategoricalChunked) -> Self::Item { - cat_compare_helper(self, rhs, UInt32Chunked::gt, |l, r| l > r) + fn equal(&self, rhs: &Self) -> Self::Item { + cat_equality_helper(self, rhs, |l, r| l.equal(r)) } - fn gt_eq(&self, rhs: &CategoricalChunked) -> Self::Item { - cat_compare_helper(self, rhs, UInt32Chunked::gt_eq, |l, r| l >= r) + fn equal_missing(&self, rhs: &Self) -> Self::Item { + cat_equality_helper(self, rhs, |l, r| l.equal_missing(r)) } - fn lt(&self, rhs: &CategoricalChunked) -> Self::Item { - cat_compare_helper(self, rhs, UInt32Chunked::lt, |l, r| l < r) + fn not_equal(&self, rhs: &Self) -> Self::Item { + cat_equality_helper(self, rhs, |l, r| l.not_equal(r)) } - fn lt_eq(&self, rhs: &CategoricalChunked) -> Self::Item { - cat_compare_helper(self, rhs, UInt32Chunked::lt_eq, |l, r| l <= r) + fn not_equal_missing(&self, rhs: &Self) -> Self::Item { + cat_equality_helper(self, rhs, |l, r| l.not_equal_missing(r)) } } -fn cat_str_equality_helper<'a, Missing, CompareNone, CompareCat, ComparePhys, CompareString>( - lhs: &'a CategoricalChunked, - rhs: &'a StringChunked, - missing_function: Missing, - compare_to_none: CompareNone, - cat_compare_function: CompareCat, - phys_compare_function: ComparePhys, - str_compare_function: CompareString, -) -> PolarsResult +impl ChunkCompareIneq<&CategoricalChunked> for CategoricalChunked where - Missing: Fn(&CategoricalChunked) -> BooleanChunked, - CompareNone: Fn(&CategoricalChunked) -> BooleanChunked, - ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked, - CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult, - CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked, + ChunkedArray: + for<'a> ChunkCompareIneq<&'a ChunkedArray, Item = BooleanChunked>, { - if lhs.is_enum() { - let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?; - cat_compare_function(lhs, rhs_cat.categorical().unwrap()) - } else if rhs.len() == 1 { - match rhs.get(0) { - None => Ok(compare_to_none(lhs)), - Some(s) => { - cat_single_str_equality_helper(lhs, s, missing_function, phys_compare_function) - }, - } - } else { - let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?; - Ok(str_compare_function(lhs_string.str().unwrap(), rhs)) + type Item = PolarsResult; + + fn gt(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_compare_helper(self, rhs, |l, r| l > r, |l, r| l.gt(r)) } -} -fn cat_str_compare_helper<'a, CompareCat, ComparePhys, CompareStringSingle, CompareString>( - lhs: &'a CategoricalChunked, - rhs: &'a StringChunked, - cat_compare_function: CompareCat, - phys_compare_function: ComparePhys, - str_single_compare_function: CompareStringSingle, - str_compare_function: CompareString, -) -> PolarsResult -where - CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap, - ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked, - CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult, - CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked, -{ - if lhs.is_enum() { - let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?; - cat_compare_function(lhs, rhs_cat.categorical().unwrap()) - } else if rhs.len() == 1 { - match rhs.get(0) { - None => Ok(BooleanChunked::full_null(lhs.name().clone(), lhs.len())), - Some(s) => cat_single_str_compare_helper( - lhs, - s, - phys_compare_function, - str_single_compare_function, - ), - } - } else { - let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?; - Ok(str_compare_function(lhs_string.str().unwrap(), rhs)) + fn gt_eq(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_compare_helper(self, rhs, |l, r| l >= r, |l, r| l.gt_eq(r)) + } + + fn lt(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_compare_helper(self, rhs, |l, r| l < r, |l, r| l.lt(r)) + } + + fn lt_eq(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_compare_helper(self, rhs, |l, r| l <= r, |l, r| l.lt_eq(r)) } } -impl ChunkCompareEq<&StringChunked> for CategoricalChunked { - type Item = PolarsResult; +impl ChunkCompareEq<&StringChunked> for CategoricalChunked +where + ChunkedArray: for<'a> ChunkCompareEq, +{ + type Item = BooleanChunked; fn equal(&self, rhs: &StringChunked) -> Self::Item { cat_str_equality_helper( self, rhs, - |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false), - |lhs| BooleanChunked::full_null(lhs.name().clone(), lhs.len()), - |s1, s2| CategoricalChunked::equal(s1, s2), - UInt32Chunked::equal, - StringChunked::equal, + |l, r| l.zip(r).map(|(l, r)| l == r), + |l, c| l.equal(c), + |r, c| r.equal(c), ) } + fn equal_missing(&self, rhs: &StringChunked) -> Self::Item { cat_str_equality_helper( self, rhs, - |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()), - |lhs| lhs.physical().is_null(), - |s1, s2| CategoricalChunked::equal_missing(s1, s2), - UInt32Chunked::equal_missing, - StringChunked::equal_missing, + |l, r| Some(l == r), + |l, c| l.equal_missing(c), + |r, c| r.equal_missing(c), ) } @@ -251,231 +341,118 @@ impl ChunkCompareEq<&StringChunked> for CategoricalChunked { cat_str_equality_helper( self, rhs, - |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true), - |lhs| BooleanChunked::full_null(lhs.name().clone(), lhs.len()), - |s1, s2| CategoricalChunked::not_equal(s1, s2), - UInt32Chunked::not_equal, - StringChunked::not_equal, + |l, r| l.zip(r).map(|(l, r)| l != r), + |l, c| l.not_equal(c), + |r, c| r.not_equal(c), ) } + fn not_equal_missing(&self, rhs: &StringChunked) -> Self::Item { cat_str_equality_helper( self, rhs, - |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()), - |lhs| !lhs.physical().is_null(), - |s1, s2| CategoricalChunked::not_equal_missing(s1, s2), - UInt32Chunked::not_equal_missing, - StringChunked::not_equal_missing, + |l, r| Some(l != r), + |l, c| l.not_equal_missing(c), + |r, c| r.not_equal_missing(c), ) } } -impl ChunkCompareIneq<&StringChunked> for CategoricalChunked { +impl ChunkCompareIneq<&StringChunked> for CategoricalChunked { type Item = PolarsResult; fn gt(&self, rhs: &StringChunked) -> Self::Item { - cat_str_compare_helper( - self, - rhs, - |s1, s2| CategoricalChunked::gt(s1, s2), - UInt32Chunked::gt, - Utf8ViewArray::tot_gt_kernel_broadcast, - StringChunked::gt, - ) + if self.is_enum() { + cat_str_phys_compare_helper(self, rhs, |l, r| l > r) + } else { + Ok(cat_str_compare_helper( + self, + rhs, + |l, r| l > r, + |c, r| r.lt(c), + )) + } } fn gt_eq(&self, rhs: &StringChunked) -> Self::Item { - cat_str_compare_helper( - self, - rhs, - |s1, s2| CategoricalChunked::gt_eq(s1, s2), - UInt32Chunked::gt_eq, - Utf8ViewArray::tot_ge_kernel_broadcast, - StringChunked::gt_eq, - ) + if self.is_enum() { + cat_str_phys_compare_helper(self, rhs, |l, r| l >= r) + } else { + Ok(cat_str_compare_helper( + self, + rhs, + |l, r| l >= r, + |c, r| r.lt_eq(c), + )) + } } fn lt(&self, rhs: &StringChunked) -> Self::Item { - cat_str_compare_helper( - self, - rhs, - |s1, s2| CategoricalChunked::lt(s1, s2), - UInt32Chunked::lt, - Utf8ViewArray::tot_lt_kernel_broadcast, - StringChunked::lt, - ) + if self.is_enum() { + cat_str_phys_compare_helper(self, rhs, |l, r| l < r) + } else { + Ok(cat_str_compare_helper( + self, + rhs, + |l, r| l < r, + |c, r| r.gt(c), + )) + } } fn lt_eq(&self, rhs: &StringChunked) -> Self::Item { - cat_str_compare_helper( - self, - rhs, - |s1, s2| CategoricalChunked::lt_eq(s1, s2), - UInt32Chunked::lt_eq, - Utf8ViewArray::tot_le_kernel_broadcast, - StringChunked::lt_eq, - ) - } -} - -fn cat_single_str_equality_helper<'a, ComparePhys, Missing>( - lhs: &'a CategoricalChunked, - rhs: &'a str, - missing_function: Missing, - phys_compare_function: ComparePhys, -) -> PolarsResult -where - ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked, - Missing: Fn(&CategoricalChunked) -> BooleanChunked, -{ - let rev_map = lhs.get_rev_map(); - let idx = rev_map.find(rhs); - if lhs.is_enum() { - let Some(idx) = idx else { - polars_bail!( - not_in_enum, - value = rhs, - categories = rev_map.get_categories() - ) - }; - Ok(phys_compare_function(lhs.physical(), idx)) - } else { - match rev_map.find(rhs) { - None => Ok(missing_function(lhs)), - Some(idx) => Ok(phys_compare_function(lhs.physical(), idx)), + if self.is_enum() { + cat_str_phys_compare_helper(self, rhs, |l, r| l <= r) + } else { + Ok(cat_str_compare_helper( + self, + rhs, + |l, r| l <= r, + |c, r| r.gt_eq(c), + )) } } } -fn cat_single_str_compare_helper<'a, ComparePhys, CompareStringSingle>( - lhs: &'a CategoricalChunked, - rhs: &'a str, - phys_compare_function: ComparePhys, - str_single_compare_function: CompareStringSingle, -) -> PolarsResult +impl ChunkCompareEq<&str> for CategoricalChunked where - CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap, - ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked, + ChunkedArray: for<'a> ChunkCompareEq, { - let rev_map = lhs.get_rev_map(); - if lhs.is_enum() { - match rev_map.find(rhs) { - None => { - polars_bail!( - not_in_enum, - value = rhs, - categories = rev_map.get_categories() - ) - }, - Some(idx) => Ok(phys_compare_function(lhs.physical(), idx)), - } - } else { - // Apply comparison on categories map and then do a lookup - let bitmap = str_single_compare_function(lhs.get_rev_map().get_categories(), rhs); - - let mask = match lhs.get_rev_map().as_ref() { - RevMapping::Local(_, _) => { - BooleanChunked::from_iter_trusted_length(lhs.physical().into_iter().map( - |opt_idx| { - // SAFETY: indexing into bitmap with same length as original array - opt_idx.map(|idx| unsafe { bitmap.get_bit_unchecked(idx as usize) }) - }, - )) - }, - RevMapping::Global(idx_map, _, _) => { - BooleanChunked::from_iter_trusted_length(lhs.physical().into_iter().map( - |opt_idx| { - // SAFETY: indexing into bitmap with same length as original array - opt_idx.map(|idx| unsafe { - let idx = *idx_map.get(&idx).unwrap(); - bitmap.get_bit_unchecked(idx as usize) - }) - }, - )) - }, - }; - - Ok(mask.with_name(lhs.name().clone())) - } -} - -impl ChunkCompareEq<&str> for CategoricalChunked { - type Item = PolarsResult; + type Item = BooleanChunked; fn equal(&self, rhs: &str) -> Self::Item { - cat_single_str_equality_helper( - self, - rhs, - |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false), - UInt32Chunked::equal, - ) + cat_str_scalar_equality_helper(self, rhs, None, |l, c| l.equal(c)) } fn equal_missing(&self, rhs: &str) -> Self::Item { - cat_single_str_equality_helper( - self, - rhs, - |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()), - UInt32Chunked::equal_missing, - ) + cat_str_scalar_equality_helper(self, rhs, Some(true), |l, c| l.equal_missing(c)) } fn not_equal(&self, rhs: &str) -> Self::Item { - cat_single_str_equality_helper( - self, - rhs, - |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true), - UInt32Chunked::not_equal, - ) + cat_str_scalar_equality_helper(self, rhs, None, |r, c| r.not_equal(c)) } fn not_equal_missing(&self, rhs: &str) -> Self::Item { - cat_single_str_equality_helper( - self, - rhs, - |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()), - UInt32Chunked::equal_missing, - ) + cat_str_scalar_equality_helper(self, rhs, Some(false), |l, c| l.not_equal_missing(c)) } } -impl ChunkCompareIneq<&str> for CategoricalChunked { - type Item = PolarsResult; +impl ChunkCompareIneq<&str> for CategoricalChunked { + type Item = BooleanChunked; fn gt(&self, rhs: &str) -> Self::Item { - cat_single_str_compare_helper( - self, - rhs, - UInt32Chunked::gt, - Utf8ViewArray::tot_gt_kernel_broadcast, - ) + cat_str_scalar_compare_helper(self, rhs, |l, r| l > r) } fn gt_eq(&self, rhs: &str) -> Self::Item { - cat_single_str_compare_helper( - self, - rhs, - UInt32Chunked::gt_eq, - Utf8ViewArray::tot_ge_kernel_broadcast, - ) + cat_str_scalar_compare_helper(self, rhs, |l, r| l >= r) } fn lt(&self, rhs: &str) -> Self::Item { - cat_single_str_compare_helper( - self, - rhs, - UInt32Chunked::lt, - Utf8ViewArray::tot_lt_kernel_broadcast, - ) + cat_str_scalar_compare_helper(self, rhs, |l, r| l < r) } fn lt_eq(&self, rhs: &str) -> Self::Item { - cat_single_str_compare_helper( - self, - rhs, - UInt32Chunked::lt_eq, - Utf8ViewArray::tot_le_kernel_broadcast, - ) + cat_str_scalar_compare_helper(self, rhs, |l, r| l <= r) } } diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index f152b0323f86..ab4d2b9b9ce3 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -1,85 +1,12 @@ -use arrow::compute::concatenate::concatenate_unchecked; - use super::*; #[allow(clippy::all)] fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataType { // ensure we don't get List - let dtype = if let Some(arr) = chunks.get(0) { + if let Some(arr) = chunks.get(0) { DataType::from_arrow_dtype(arr.dtype()) } else { dtype - }; - - match dtype { - #[cfg(feature = "dtype-categorical")] - // arrow dictionaries are not nested as dictionaries, but only by their keys, so we must - // change the list-value array to the keys and store the dictionary values in the datatype. - // if a global string cache is set, we also must modify the keys. - DataType::List(inner) - if matches!( - *inner, - DataType::Categorical(None, _) | DataType::Enum(None, _) - ) => - { - let array = concatenate_unchecked(chunks).unwrap(); - let list_arr = array.as_any().downcast_ref::>().unwrap(); - let values_arr = list_arr.values(); - let cat = unsafe { - Series::_try_from_arrow_unchecked( - PlSmallStr::EMPTY, - vec![values_arr.clone()], - values_arr.dtype(), - ) - .unwrap() - }; - - // we nest only the physical representation - // the mapping is still in our rev-map - let arrow_dtype = ListArray::::default_datatype(ArrowDataType::UInt32); - let new_array = ListArray::new( - arrow_dtype, - list_arr.offsets().clone(), - cat.array_ref(0).clone(), - list_arr.validity().cloned(), - ); - chunks.clear(); - chunks.push(Box::new(new_array)); - DataType::List(Box::new(cat.dtype().clone())) - }, - #[cfg(all(feature = "dtype-array", feature = "dtype-categorical"))] - DataType::Array(inner, width) - if matches!( - *inner, - DataType::Categorical(None, _) | DataType::Enum(None, _) - ) => - { - let array = concatenate_unchecked(chunks).unwrap(); - let list_arr = array.as_any().downcast_ref::().unwrap(); - let values_arr = list_arr.values(); - let cat = unsafe { - Series::_try_from_arrow_unchecked( - PlSmallStr::EMPTY, - vec![values_arr.clone()], - values_arr.dtype(), - ) - .unwrap() - }; - - // we nest only the physical representation - // the mapping is still in our rev-map - let arrow_dtype = FixedSizeListArray::default_datatype(ArrowDataType::UInt32, width); - let new_array = FixedSizeListArray::new( - arrow_dtype, - values_arr.len(), - cat.array_ref(0).clone(), - list_arr.validity().cloned(), - ); - chunks.clear(); - chunks.push(Box::new(new_array)); - DataType::Array(Box::new(cat.dtype().clone()), width) - }, - _ => dtype, } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical.rs b/crates/polars-core/src/chunked_array/logical/categorical.rs new file mode 100644 index 000000000000..4feef7ec8436 --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical.rs @@ -0,0 +1,286 @@ +use std::marker::PhantomData; + +use arrow::bitmap::BitmapBuilder; +use num_traits::Zero; + +use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::flags::StatisticsFlags; +use crate::chunked_array::ops::ChunkFullNull; +use crate::prelude::*; +use crate::series::IsSorted; +use crate::utils::handle_casting_failures; + +pub type CategoricalChunked = Logical::PolarsPhysical>; +pub type Categorical8Chunked = CategoricalChunked; +pub type Categorical16Chunked = CategoricalChunked; +pub type Categorical32Chunked = CategoricalChunked; + +pub trait CategoricalPhysicalDtypeExt { + fn dtype(&self) -> DataType; +} + +impl CategoricalPhysicalDtypeExt for CategoricalPhysical { + fn dtype(&self) -> DataType { + match self { + Self::U8 => DataType::UInt8, + Self::U16 => DataType::UInt16, + Self::U32 => DataType::UInt32, + } + } +} + +impl CategoricalChunked { + pub fn is_enum(&self) -> bool { + matches!(self.dtype(), DataType::Enum(_, _)) + } + + pub(crate) fn get_flags(&self) -> StatisticsFlags { + self.phys.get_flags() + } + + /// Set flags for the ChunkedArray. + pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) { + // We should not set the sorted flag if we are sorting in lexical order. + if self.uses_lexical_ordering() { + flags.set_sorted(IsSorted::Not) + } + self.physical_mut().set_flags(flags) + } + + /// Return whether or not the [`CategoricalChunked`] uses the lexical order + /// of the string values when sorting. + pub fn uses_lexical_ordering(&self) -> bool { + !self.is_enum() + } + + pub fn full_null_with_dtype(name: PlSmallStr, length: usize, dtype: DataType) -> Self { + let phys = + ChunkedArray::<::PolarsPhysical>::full_null(name, length); + unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) } + } + + /// Create a [`CategoricalChunked`] from a physical array and dtype. + /// + /// Checks that all the category ids are valid, mapping invalid ones to nulls. + pub fn from_cats_and_dtype( + mut cat_ids: ChunkedArray, + dtype: DataType, + ) -> Self { + let (DataType::Enum(_, mapping) | DataType::Categorical(_, mapping)) = &dtype else { + panic!("from_cats_and_dtype called on non-categorical type") + }; + assert!(dtype.cat_physical().ok() == Some(T::physical())); + + unsafe { + let mut validity = BitmapBuilder::new(); + for arr in cat_ids.downcast_iter_mut() { + validity.reserve(arr.len()); + if arr.has_nulls() { + for opt_cat_id in arr.iter() { + if let Some(cat_id) = opt_cat_id { + validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some()); + } else { + validity.push_unchecked(false); + } + } + } else { + for cat_id in arr.values_iter() { + validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some()); + } + } + + if arr.null_count() != validity.unset_bits() { + arr.set_validity(core::mem::take(&mut validity).into_opt_validity()); + } else { + validity.clear(); + } + } + } + + Self { + phys: cat_ids, + dtype, + _phantom: PhantomData, + } + } + + /// Create a [`CategoricalChunked`] from a physical array and dtype. + /// + /// # Safety + /// It's not checked that the indices are in-bounds or that the dtype is correct. + pub unsafe fn from_cats_and_dtype_unchecked( + cat_ids: ChunkedArray, + dtype: DataType, + ) -> Self { + debug_assert!(dtype.cat_physical().ok() == Some(T::physical())); + + Self { + phys: cat_ids, + dtype, + _phantom: PhantomData, + } + } + + /// Get a reference to the mapping of categorical types to the string values. + pub fn get_mapping(&self) -> &Arc { + let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = self.dtype() else { + unreachable!() + }; + mapping + } + + /// Create an [`Iterator`] that iterates over the `&str` values of the [`CategoricalChunked`]. + pub fn iter_str(&self) -> impl PolarsIterator> { + let mapping = self.get_mapping(); + self.phys + .iter() + .map(|cat| unsafe { Some(mapping.cat_to_str_unchecked(cat?.as_cat())) }) + } + + /// Converts from strings to this CategoricalChunked. + /// + /// If this dtype is an Enum any non-existing strings get mapped to null. + pub fn from_str_iter<'a, I: IntoIterator>>( + name: PlSmallStr, + dtype: DataType, + strings: I, + ) -> PolarsResult { + let strings = strings.into_iter(); + + let hint = strings.size_hint().0; + let mut cat_ids = Vec::with_capacity(hint); + let mut validity = BitmapBuilder::with_capacity(hint); + + match &dtype { + DataType::Categorical(cats, mapping) => { + assert!(cats.physical() == T::physical()); + for opt_s in strings { + cat_ids.push(if let Some(s) = opt_s { + T::Native::from_cat(mapping.insert_cat(s)?) + } else { + T::Native::zero() + }); + validity.push(opt_s.is_some()); + } + }, + DataType::Enum(fcats, mapping) => { + assert!(fcats.physical() == T::physical()); + for opt_s in strings { + cat_ids.push(if let Some(cat) = opt_s.and_then(|s| mapping.get_cat(s)) { + validity.push(true); + T::Native::from_cat(cat) + } else { + validity.push(false); + T::Native::zero() + }); + } + }, + _ => panic!("from_strings_and_dtype_strict called on non-categorical type"), + } + + let arr = ::Array::from_vec(cat_ids) + .with_validity(validity.into_opt_validity()); + let phys = ChunkedArray::::with_chunk(name, arr); + Ok(unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) }) + } + + pub fn to_arrow(&self, compat_level: CompatLevel) -> DictionaryArray { + let keys = self.physical().rechunk(); + let keys = keys.downcast_as_array(); + let values = self + .get_mapping() + .to_arrow(compat_level != CompatLevel::oldest()); + let values_dtype = Box::new(values.dtype().clone()); + let dtype = + ArrowDataType::Dictionary(::KEY_TYPE, values_dtype, false); + unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() } + } +} + +impl LogicalType for CategoricalChunked { + fn dtype(&self) -> &DataType { + &self.dtype + } + + fn get_any_value(&self, i: usize) -> PolarsResult> { + polars_ensure!(i < self.len(), oob = i, self.len()); + Ok(unsafe { self.get_any_value_unchecked(i) }) + } + + unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { + match self.phys.get_unchecked(i) { + Some(i) => match &self.dtype { + DataType::Enum(_, mapping) => AnyValue::Enum(i.as_cat(), mapping), + DataType::Categorical(_, mapping) => AnyValue::Categorical(i.as_cat(), mapping), + _ => unreachable!(), + }, + None => AnyValue::Null, + } + } + + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + if &self.dtype == dtype { + return Ok(self.clone().into_series()); + } + + match dtype { + DataType::String => { + let mapping = self.get_mapping(); + + // TODO @ cat-rework:, if len >= mapping.upper_bound(), cast categories to ViewArray, then construct array of Views. + + let mut builder = StringChunkedBuilder::new(self.phys.name().clone(), self.len()); + let to_str = |cat_id: CatSize| unsafe { mapping.cat_to_str_unchecked(cat_id) }; + if !self.phys.has_nulls() { + for cat_id in self.phys.into_no_null_iter() { + builder.append_value(to_str(cat_id.as_cat())); + } + } else { + for opt_cat_id in self.phys.into_iter() { + let opt_cat_id: Option<_> = opt_cat_id; + builder.append_option(opt_cat_id.map(|c| to_str(c.as_cat()))); + } + } + + let ca = builder.finish(); + Ok(ca.into_series()) + }, + + DataType::Enum(fcats, _mapping) => { + // TODO @ cat-rework: if len >= self.mapping().upper_bound(), remap categories then index into array. + let ret = with_match_categorical_physical_type!(fcats.physical(), |$C| { + CategoricalChunked::<$C>::from_str_iter( + self.name().clone(), + dtype.clone(), + self.iter_str() + )?.into_series() + }); + + if options.is_strict() && self.null_count() != ret.null_count() { + handle_casting_failures(&self.clone().into_series(), &ret)?; + } + + Ok(ret) + }, + + DataType::Categorical(cats, _mapping) => { + // TODO @ cat-rework: if len >= self.mapping().upper_bound(), remap categories then index into array. + Ok( + with_match_categorical_physical_type!(cats.physical(), |$C| { + CategoricalChunked::<$C>::from_str_iter( + self.name().clone(), + dtype.clone(), + self.iter_str() + )?.into_series() + }), + ) + }, + + // LEGACY + // TODO @ cat-rework: remove after exposing to/from physical functions. + dt if dt.is_integer() => self.phys.clone().cast_with_options(dtype, options), + + _ => polars_bail!(ComputeError: "cannot cast categorical types to {dtype:?}"), + } + } +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs deleted file mode 100644 index e44e0cda8748..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs +++ /dev/null @@ -1,507 +0,0 @@ -#![allow(unsafe_op_in_unsafe_fn)] -use std::hash::BuildHasher; - -use arrow::array::*; -use arrow::legacy::trusted_len::TrustedLenPush; -use hashbrown::hash_map::Entry; -use hashbrown::hash_table::{Entry as HTEntry, HashTable}; -use polars_utils::itertools::Itertools; - -use crate::hashing::_HASHMAP_INIT_SIZE; -use crate::prelude::*; -use crate::{POOL, StringCache, using_string_cache}; - -pub struct CategoricalChunkedBuilder { - cat_builder: UInt32Vec, - name: PlSmallStr, - ordering: CategoricalOrdering, - categories: MutablePlString, - local_mapping: HashTable, - local_hasher: PlFixedStateQuality, -} - -impl CategoricalChunkedBuilder { - pub fn new(name: PlSmallStr, capacity: usize, ordering: CategoricalOrdering) -> Self { - Self { - cat_builder: UInt32Vec::with_capacity(capacity), - name, - ordering, - categories: MutablePlString::with_capacity(_HASHMAP_INIT_SIZE), - local_mapping: HashTable::with_capacity(capacity / 10), - local_hasher: StringCache::get_hash_builder(), - } - } - - fn get_cat_idx(&mut self, s: &str, h: u64) -> (u32, bool) { - let len = self.local_mapping.len() as u32; - - // SAFETY: index in hashmap are within bounds of categories - unsafe { - let r = self.local_mapping.entry( - h, - |k| self.categories.value_unchecked(*k as usize) == s, - |k| { - self.local_hasher - .hash_one(self.categories.value_unchecked(*k as usize)) - }, - ); - - match r { - HTEntry::Occupied(v) => (*v.get(), false), - HTEntry::Vacant(slot) => { - self.categories.push(Some(s)); - slot.insert(len); - (len, true) - }, - } - } - } - - fn try_get_cat_idx(&mut self, s: &str, h: u64) -> Option { - // SAFETY: index in hashmap are within bounds of categories - unsafe { - let r = self.local_mapping.entry( - h, - |k| self.categories.value_unchecked(*k as usize) == s, - |k| { - self.local_hasher - .hash_one(self.categories.value_unchecked(*k as usize)) - }, - ); - - match r { - HTEntry::Occupied(v) => Some(*v.get()), - HTEntry::Vacant(_) => None, - } - } - } - - /// Append a new category, but fail if it didn't exist yet in the category state. - /// You can register categories up front with `register_value`, or via `append`. - #[inline] - pub fn try_append_value(&mut self, s: &str) -> PolarsResult<()> { - let h = self.local_hasher.hash_one(s); - let idx = self.try_get_cat_idx(s, h).ok_or_else( - || polars_err!(ComputeError: "category {} doesn't exist in Enum dtype", s), - )?; - self.cat_builder.push(Some(idx)); - Ok(()) - } - - /// Append a new category, but fail if it didn't exist yet in the category state. - /// You can register categories up front with `register_value`, or via `append`. - #[inline] - pub fn try_append(&mut self, opt_s: Option<&str>) -> PolarsResult<()> { - match opt_s { - None => self.append_null(), - Some(s) => self.try_append_value(s)?, - } - Ok(()) - } - - /// Registers a value to a categorical index without pushing it. - /// Returns the index and if the value was new. - #[inline] - pub fn register_value(&mut self, s: &str) -> (u32, bool) { - let h = self.local_hasher.hash_one(s); - self.get_cat_idx(s, h) - } - - #[inline] - pub fn append_value(&mut self, s: &str) { - let h = self.local_hasher.hash_one(s); - let idx = self.get_cat_idx(s, h).0; - self.cat_builder.push(Some(idx)); - } - - #[inline] - pub fn append_null(&mut self) { - self.cat_builder.push(None) - } - - #[inline] - pub fn append(&mut self, opt_s: Option<&str>) { - match opt_s { - None => self.append_null(), - Some(s) => self.append_value(s), - } - } - - fn drain_iter<'a, I>(&mut self, i: I) - where - I: IntoIterator>, - { - for opt_s in i.into_iter() { - self.append(opt_s); - } - } - - /// Fast path for global categorical which preserves hashes and saves an allocation by - /// altering the keys in place. - fn drain_iter_global_and_finish<'a, I>(&mut self, i: I) -> CategoricalChunked - where - I: IntoIterator>, - { - let iter = i.into_iter(); - // Save hashes for later when inserting into the global hashmap. - let mut hashes = Vec::with_capacity(_HASHMAP_INIT_SIZE); - for s in self.categories.values_iter() { - hashes.push(self.local_hasher.hash_one(s)); - } - - for opt_s in iter { - match opt_s { - None => self.append_null(), - Some(s) => { - let hash = self.local_hasher.hash_one(s); - let (cat_idx, new) = self.get_cat_idx(s, hash); - self.cat_builder.push(Some(cat_idx)); - if new { - // We appended a value to the map. - hashes.push(hash); - } - }, - } - } - - let categories = std::mem::take(&mut self.categories).freeze(); - - // We will create a mapping from our local categoricals to global categoricals - // and a mapping from global categoricals to our local categoricals. - let mut local_to_global: Vec = Vec::with_capacity(categories.len()); - let (id, local_to_global) = crate::STRING_CACHE.apply(|cache| { - for (s, h) in categories.values_iter().zip(hashes) { - // SAFETY: we allocated enough. - unsafe { local_to_global.push_unchecked(cache.insert_from_hash(h, s)) } - } - local_to_global - }); - - // Change local indices inplace to their global counterparts. - let update_cats = || { - if !local_to_global.is_empty() { - // when all categorical are null, `local_to_global` is empty and all cats physical values are 0. - self.cat_builder.apply_values(|cats| { - for cat in cats { - debug_assert!((*cat as usize) < local_to_global.len()); - *cat = *unsafe { local_to_global.get_unchecked(*cat as usize) }; - } - }) - } - }; - - let mut global_to_local = PlHashMap::with_capacity(local_to_global.len()); - POOL.join( - || fill_global_to_local(&local_to_global, &mut global_to_local), - update_cats, - ); - - let indices = std::mem::take(&mut self.cat_builder).into(); - let indices = UInt32Chunked::with_chunk(self.name.clone(), indices); - - // SAFETY: indices are in bounds of new rev_map - unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - indices, - Arc::new(RevMapping::Global(global_to_local, categories, id)), - false, - self.ordering, - ) - .with_fast_unique(true) - } - } - - pub fn drain_iter_and_finish<'a, I>(mut self, i: I) -> CategoricalChunked - where - I: IntoIterator>, - { - if using_string_cache() { - self.drain_iter_global_and_finish(i) - } else { - self.drain_iter(i); - self.finish() - } - } - - pub fn finish(self) -> CategoricalChunked { - // SAFETY: keys and values are in bounds - unsafe { - CategoricalChunked::from_keys_and_values( - self.name.clone(), - &self.cat_builder.into(), - &self.categories.into(), - self.ordering, - ) - .with_fast_unique(true) - } - } -} - -fn fill_global_to_local(local_to_global: &[u32], global_to_local: &mut PlHashMap) { - let mut local_idx = 0; - #[allow(clippy::explicit_counter_loop)] - for global_idx in local_to_global { - // we know the keys are unique so this is much faster - unsafe { - global_to_local.insert_unique_unchecked(*global_idx, local_idx); - } - local_idx += 1; - } -} - -impl CategoricalChunked { - /// Create a [`CategoricalChunked`] from a categorical indices. The indices will - /// probe the global string cache. - pub(crate) fn from_global_indices( - cats: UInt32Chunked, - ordering: CategoricalOrdering, - ) -> PolarsResult { - let len = crate::STRING_CACHE.read_map().len() as u32; - let oob = cats.into_iter().flatten().any(|cat| cat >= len); - polars_ensure!( - !oob, - ComputeError: - "cannot construct Categorical from these categories; at least one of them is out of bounds" - ); - Ok(unsafe { Self::from_global_indices_unchecked(cats, ordering) }) - } - - /// Create a [`CategoricalChunked`] from a categorical indices. The indices will - /// probe the global string cache. - /// - /// # Safety - /// This does not do any bound checks - pub unsafe fn from_global_indices_unchecked( - cats: UInt32Chunked, - ordering: CategoricalOrdering, - ) -> CategoricalChunked { - let cache = crate::STRING_CACHE.read_map(); - - let cap = std::cmp::min(std::cmp::min(cats.len(), cache.len()), _HASHMAP_INIT_SIZE); - let mut rev_map = PlHashMap::with_capacity(cap); - let mut str_values = MutablePlString::with_capacity(cap); - - for arr in cats.downcast_iter() { - for cat in arr.into_iter().flatten().copied() { - let offset = str_values.len() as u32; - - if let Entry::Vacant(entry) = rev_map.entry(cat) { - entry.insert(offset); - let str_val = cache.get_unchecked(cat); - str_values.push(Some(str_val)) - } - } - } - - let rev_map = RevMapping::Global(rev_map, str_values.into(), cache.uuid); - - CategoricalChunked::from_cats_and_rev_map_unchecked( - cats, - Arc::new(rev_map), - false, - ordering, - ) - } - - pub(crate) unsafe fn from_keys_and_values_global( - name: PlSmallStr, - keys: impl IntoIterator> + Send, - capacity: usize, - values: &Utf8ViewArray, - ordering: CategoricalOrdering, - ) -> Self { - // Vec where the index is local and the value is the global index - let mut local_to_global: Vec = Vec::with_capacity(values.len()); - let (id, local_to_global) = crate::STRING_CACHE.apply(|cache| { - // locally we don't need a hashmap because we all categories are 1 integer apart - // so the index is local, and the values is global - for s in values.values_iter() { - // SAFETY: we allocated enough - unsafe { local_to_global.push_unchecked(cache.insert(s)) } - } - local_to_global - }); - - let compute_cats = || { - let mut result = UInt32Vec::with_capacity(capacity); - - for opt_value in keys.into_iter() { - result.push(opt_value.map(|cat| { - debug_assert!((cat as usize) < local_to_global.len()); - *unsafe { local_to_global.get_unchecked(cat as usize) } - })); - } - result - }; - - let mut global_to_local = PlHashMap::with_capacity(local_to_global.len()); - let (_, cats) = POOL.join( - || fill_global_to_local(&local_to_global, &mut global_to_local), - compute_cats, - ); - unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - UInt32Chunked::with_chunk(name, cats.into()), - Arc::new(RevMapping::Global(global_to_local, values.clone(), id)), - false, - ordering, - ) - } - } - - pub(crate) unsafe fn from_keys_and_values_local( - name: PlSmallStr, - keys: &PrimitiveArray, - values: &Utf8ViewArray, - ordering: CategoricalOrdering, - ) -> CategoricalChunked { - CategoricalChunked::from_cats_and_rev_map_unchecked( - UInt32Chunked::with_chunk(name, keys.clone()), - Arc::new(RevMapping::build_local(values.clone())), - false, - ordering, - ) - } - - /// # Safety - /// The caller must ensure that index values in the `keys` are in within bounds of the `values` length. - pub(crate) unsafe fn from_keys_and_values( - name: PlSmallStr, - keys: &PrimitiveArray, - values: &Utf8ViewArray, - ordering: CategoricalOrdering, - ) -> Self { - if !using_string_cache() { - CategoricalChunked::from_keys_and_values_local(name, keys, values, ordering) - } else { - CategoricalChunked::from_keys_and_values_global( - name, - keys.into_iter().map(|c| c.copied()), - keys.len(), - values, - ordering, - ) - } - } - - /// Create a [`CategoricalChunked`] from a fixed list of categories and a List of strings. - /// This will error if a string is not in the fixed list of categories - pub fn from_string_to_enum( - values: &StringChunked, - categories: &Utf8ViewArray, - ordering: CategoricalOrdering, - ) -> PolarsResult { - polars_ensure!(categories.null_count() == 0, ComputeError: "categories can not contain null values"); - - // Build a mapping string -> idx - let mut map = PlHashMap::with_capacity(categories.len()); - for (idx, cat) in categories.values_iter().enumerate_idx() { - #[allow(clippy::unnecessary_cast)] - map.insert(cat, idx as u32); - } - // Find idx of every value in the map - let iter = values.downcast_iter().map(|arr| { - arr.iter() - .map(|opt_s: Option<&str>| opt_s.and_then(|s| map.get(s).copied())) - .collect_arr() - }); - let mut keys: UInt32Chunked = ChunkedArray::from_chunk_iter(values.name().clone(), iter); - keys.rename(values.name().clone()); - let rev_map = RevMapping::build_local(categories.clone()); - unsafe { - Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( - keys, - Arc::new(rev_map), - true, - ordering, - ) - .with_fast_unique(false)) - } - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache}; - - #[test] - fn test_categorical_rev() -> PolarsResult<()> { - let _lock = SINGLE_LOCK.lock(); - disable_string_cache(); - let slice = &[ - Some("foo"), - None, - Some("bar"), - Some("foo"), - Some("foo"), - Some("bar"), - ]; - let ca = StringChunked::new(PlSmallStr::from_static("a"), slice); - let out = ca.cast(&DataType::Categorical(None, Default::default()))?; - let out = out.categorical().unwrap().clone(); - assert_eq!(out.get_rev_map().len(), 2); - - // test the global branch - enable_string_cache(); - // empty global cache - let out = ca.cast(&DataType::Categorical(None, Default::default()))?; - let out = out.categorical().unwrap().clone(); - assert_eq!(out.get_rev_map().len(), 2); - // full global cache - let out = ca.cast(&DataType::Categorical(None, Default::default()))?; - let out = out.categorical().unwrap().clone(); - assert_eq!(out.get_rev_map().len(), 2); - - // Check that we don't panic if we append two categorical arrays - // build under the same string cache - // https://github.com/pola-rs/polars/issues/1115 - let ca1 = StringChunked::new(PlSmallStr::from_static("a"), slice) - .cast(&DataType::Categorical(None, Default::default()))?; - let mut ca1 = ca1.categorical().unwrap().clone(); - let ca2 = StringChunked::new(PlSmallStr::from_static("a"), slice) - .cast(&DataType::Categorical(None, Default::default()))?; - let ca2 = ca2.categorical().unwrap(); - ca1.append(ca2).unwrap(); - - Ok(()) - } - - #[test] - fn test_categorical_builder() { - use crate::{disable_string_cache, enable_string_cache}; - let _lock = crate::SINGLE_LOCK.lock(); - for use_string_cache in [false, true] { - disable_string_cache(); - if use_string_cache { - enable_string_cache(); - } - - // Use 2 builders to check if the global string cache - // does not interfere with the index mapping - let builder1 = CategoricalChunkedBuilder::new( - PlSmallStr::from_static("foo"), - 10, - Default::default(), - ); - let builder2 = CategoricalChunkedBuilder::new( - PlSmallStr::from_static("foo"), - 10, - Default::default(), - ); - let s = builder1 - .drain_iter_and_finish(vec![None, Some("hello"), Some("vietnam")]) - .into_series(); - assert_eq!(s.str_value(0).unwrap(), "null"); - assert_eq!(s.str_value(1).unwrap(), "hello"); - assert_eq!(s.str_value(2).unwrap(), "vietnam"); - - let s = builder2 - .drain_iter_and_finish(vec![Some("hello"), None, Some("world")]) - .into_series(); - assert_eq!(s.str_value(0).unwrap(), "hello"); - assert_eq!(s.str_value(1).unwrap(), "null"); - assert_eq!(s.str_value(2).unwrap(), "world"); - } - } -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/from.rs b/crates/polars-core/src/chunked_array/logical/categorical/from.rs deleted file mode 100644 index 94ddaab1bfb4..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/from.rs +++ /dev/null @@ -1,100 +0,0 @@ -use arrow::datatypes::IntegerType; -use polars_compute::cast::{CastOptionsImpl, cast, utf8view_to_utf8}; - -use super::*; - -fn convert_values(arr: &Utf8ViewArray, compat_level: CompatLevel) -> ArrayRef { - if compat_level.0 >= 1 { - arr.clone().boxed() - } else { - utf8view_to_utf8::(arr).boxed() - } -} - -impl CategoricalChunked { - pub fn to_arrow(&self, compat_level: CompatLevel, as_i64: bool) -> ArrayRef { - if as_i64 { - self.to_i64(compat_level).boxed() - } else { - self.to_u32(compat_level).boxed() - } - } - - fn to_u32(&self, compat_level: CompatLevel) -> DictionaryArray { - let values_dtype = if compat_level.0 >= 1 { - ArrowDataType::Utf8View - } else { - ArrowDataType::LargeUtf8 - }; - let keys = self.physical().rechunk(); - let keys = keys.downcast_as_array(); - let map = &**self.get_rev_map(); - let dtype = ArrowDataType::Dictionary(IntegerType::UInt32, Box::new(values_dtype), false); - match map { - RevMapping::Local(arr, _) => { - let values = convert_values(arr, compat_level); - - // SAFETY: - // the keys are in bounds - unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() } - }, - RevMapping::Global(reverse_map, values, _uuid) => { - let iter = keys - .iter() - .map(|opt_k| opt_k.map(|k| *reverse_map.get(k).unwrap())); - let keys = PrimitiveArray::from_trusted_len_iter(iter); - - let values = convert_values(values, compat_level); - - // SAFETY: - // the keys are in bounds - unsafe { DictionaryArray::try_new_unchecked(dtype, keys, values).unwrap() } - }, - } - } - - fn to_i64(&self, compat_level: CompatLevel) -> DictionaryArray { - let values_dtype = if compat_level.0 >= 1 { - ArrowDataType::Utf8View - } else { - ArrowDataType::LargeUtf8 - }; - let keys = self.physical().rechunk(); - let keys = keys.downcast_as_array(); - let map = &**self.get_rev_map(); - let dtype = ArrowDataType::Dictionary(IntegerType::Int64, Box::new(values_dtype), false); - match map { - RevMapping::Local(arr, _) => { - let values = convert_values(arr, compat_level); - - // SAFETY: - // the keys are in bounds - unsafe { - DictionaryArray::try_new_unchecked( - dtype, - cast(keys, &ArrowDataType::Int64, CastOptionsImpl::unchecked()) - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap() - .clone(), - values, - ) - .unwrap() - } - }, - RevMapping::Global(reverse_map, values, _uuid) => { - let iter = keys - .iter() - .map(|opt_k| opt_k.map(|k| *reverse_map.get(k).unwrap() as i64)); - let keys = PrimitiveArray::from_trusted_len_iter(iter); - - let values = convert_values(values, compat_level); - - // SAFETY: - // the keys are in bounds - unsafe { DictionaryArray::try_new_unchecked(dtype, keys, values).unwrap() } - }, - } - } -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs deleted file mode 100644 index 28e92cb4f11e..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs +++ /dev/null @@ -1,260 +0,0 @@ -use std::borrow::Cow; - -use super::*; -use crate::series::IsSorted; -use crate::utils::align_chunks_binary; - -fn slots_to_mut(slots: &Utf8ViewArray) -> MutablePlString { - slots.clone().make_mut() -} - -struct State { - map: PlHashMap, - slots: MutablePlString, -} - -#[derive(Default)] -pub struct GlobalRevMapMerger { - id: u32, - original: Arc, - // only initiate state when - // we encounter a rev-map from a different source, - // but from the same string cache - state: Option, -} - -impl GlobalRevMapMerger { - pub fn new(rev_map: Arc) -> Self { - let RevMapping::Global(_, _, id) = rev_map.as_ref() else { - unreachable!() - }; - - GlobalRevMapMerger { - state: None, - id: *id, - original: rev_map, - } - } - - fn init_state(&mut self) { - let RevMapping::Global(map, slots, _) = self.original.as_ref() else { - unreachable!() - }; - self.state = Some(State { - map: (*map).clone(), - slots: slots_to_mut(slots), - }) - } - - pub fn merge_map(&mut self, rev_map: &Arc) -> PolarsResult<()> { - // happy path they come from the same source - if Arc::ptr_eq(&self.original, rev_map) { - return Ok(()); - } - - let RevMapping::Global(map, slots, id) = rev_map.as_ref() else { - polars_bail!(string_cache_mismatch) - }; - polars_ensure!(*id == self.id, string_cache_mismatch); - - if self.state.is_none() { - self.init_state() - } - let state = self.state.as_mut().unwrap(); - - for (cat, idx) in map.iter() { - state.map.entry(*cat).or_insert_with(|| { - // SAFETY: - // within bounds - let str_val = unsafe { slots.value_unchecked(*idx as usize) }; - let new_idx = state.slots.len() as u32; - state.slots.push(Some(str_val)); - - new_idx - }); - } - Ok(()) - } - - pub fn finish(self) -> Arc { - match self.state { - None => self.original, - Some(state) => { - let new_rev = RevMapping::Global(state.map, state.slots.into(), self.id); - Arc::new(new_rev) - }, - } - } -} - -fn merge_local_rhs_categorical<'a>( - categories: &'a Utf8ViewArray, - ca_right: &'a CategoricalChunked, -) -> Result<(UInt32Chunked, Arc), PolarsError> { - // Counterpart of the GlobalRevmapMerger. - // In case of local categorical we also need to change the physicals not only the revmap - - polars_warn!( - CategoricalRemappingWarning, - "Local categoricals have different encodings, expensive re-encoding is done \ - to perform this merge operation. Consider using a StringCache or an Enum type \ - if the categories are known in advance" - ); - - let RevMapping::Local(cats_right, _) = &**ca_right.get_rev_map() else { - unreachable!() - }; - - let cats_left_hashmap = PlHashMap::from_iter( - categories - .values_iter() - .enumerate() - .map(|(k, v)| (v, k as u32)), - ); - let mut new_categories = slots_to_mut(categories); - let mut idx_mapping = PlHashMap::with_capacity(cats_right.len()); - - for (idx, s) in cats_right.values_iter().enumerate() { - if let Some(v) = cats_left_hashmap.get(&s) { - idx_mapping.insert(idx as u32, *v); - } else { - idx_mapping.insert(idx as u32, new_categories.len() as u32); - new_categories.push(Some(s)); - } - } - let new_rev_map = Arc::new(RevMapping::build_local(new_categories.into())); - Ok(( - ca_right - .physical - .apply(|opt_v| opt_v.map(|v| *idx_mapping.get(&v).unwrap())), - new_rev_map, - )) -} - -pub trait CategoricalMergeOperation { - fn finish(self, lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult; -} - -// Make the right categorical compatible with the left while applying the merge operation -pub fn call_categorical_merge_operation( - cat_left: &CategoricalChunked, - cat_right: &CategoricalChunked, - merge_ops: I, -) -> PolarsResult { - let rev_map_left = cat_left.get_rev_map(); - let rev_map_right = cat_right.get_rev_map(); - let (mut new_physical, new_rev_map) = match (&**rev_map_left, &**rev_map_right) { - (RevMapping::Global(_, _, idl), RevMapping::Global(_, _, idr)) if idl == idr => { - let mut rev_map_merger = GlobalRevMapMerger::new(rev_map_left.clone()); - rev_map_merger.merge_map(rev_map_right)?; - ( - merge_ops.finish(cat_left.physical(), cat_right.physical())?, - rev_map_merger.finish(), - ) - }, - (RevMapping::Local(_, idl), RevMapping::Local(_, idr)) - if idl == idr && cat_left.is_enum() == cat_right.is_enum() => - { - ( - merge_ops.finish(cat_left.physical(), cat_right.physical())?, - rev_map_left.clone(), - ) - }, - (RevMapping::Local(categorical, _), RevMapping::Local(_, _)) - if !cat_left.is_enum() && !cat_right.is_enum() => - { - let (rhs_physical, rev_map) = merge_local_rhs_categorical(categorical, cat_right)?; - ( - merge_ops.finish(cat_left.physical(), &rhs_physical)?, - rev_map, - ) - }, - (RevMapping::Local(_, _), RevMapping::Local(_, _)) - if cat_left.is_enum() | cat_right.is_enum() => - { - polars_bail!(ComputeError: "can not merge incompatible Enum types") - }, - _ => polars_bail!(string_cache_mismatch), - }; - // During merge operation, the sorted flag might get set on the underlying physical. - // Ensure that the sorted flag is not set if we use lexical order - if cat_left.uses_lexical_ordering() { - new_physical.set_sorted_flag(IsSorted::Not) - } - - // SAFETY: physical and rev map are correctly constructed above - unsafe { - Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( - new_physical, - new_rev_map, - cat_left.is_enum(), - cat_left.get_ordering(), - )) - } -} - -struct DoNothing; -impl CategoricalMergeOperation for DoNothing { - fn finish(self, _lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult { - Ok(rhs.clone()) - } -} - -// Make the right categorical compatible with the left -pub fn make_rhs_categoricals_compatible( - ca_left: &CategoricalChunked, - ca_right: &CategoricalChunked, -) -> PolarsResult<(CategoricalChunked, CategoricalChunked)> { - let new_ca_right = call_categorical_merge_operation(ca_left, ca_right, DoNothing)?; - - // Alter rev map of left - let mut new_ca_left = ca_left.clone(); - // SAFETY: We just made both rev maps compatible only appended categories - unsafe { - new_ca_left.set_rev_map( - new_ca_right.get_rev_map().clone(), - ca_left.get_rev_map().len() == new_ca_right.get_rev_map().len(), - ) - }; - - Ok((new_ca_left, new_ca_right)) -} - -pub fn make_rhs_list_categoricals_compatible( - mut list_ca_left: ListChunked, - list_ca_right: ListChunked, -) -> PolarsResult<(ListChunked, ListChunked)> { - // Make categoricals compatible - - let cat_left = list_ca_left.get_inner(); - let cat_right = list_ca_right.get_inner(); - let (cat_left, cat_right) = - make_rhs_categoricals_compatible(cat_left.categorical()?, cat_right.categorical()?)?; - - // we only appended categories to the rev_map at the end, so only change the inner dtype - list_ca_left.set_inner_dtype(cat_left.dtype().clone()); - - // We changed the physicals and the rev_map, offsets and validity buffers are still good - let (list_ca_right, cat_physical): (Cow, Cow) = - align_chunks_binary(&list_ca_right, cat_right.physical()); - let mut list_ca_right = list_ca_right.into_owned(); - // SAFETY: - // Chunks are aligned, length / dtype remains correct - unsafe { - list_ca_right - .downcast_iter_mut() - .zip(cat_physical.chunks()) - .for_each(|(arr, new_phys)| { - *arr = ListArray::new( - arr.dtype().clone(), - arr.offsets().clone(), - new_phys.clone(), - arr.validity().cloned(), - ) - }); - } - // reset the sorted flag and add extra categories back in - list_ca_right.set_sorted_flag(IsSorted::Not); - list_ca_right.set_inner_dtype(cat_right.dtype().clone()); - Ok((list_ca_left, list_ca_right)) -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs deleted file mode 100644 index 6d14f97d5653..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ /dev/null @@ -1,598 +0,0 @@ -mod builder; -mod from; -mod merge; -mod ops; -pub mod revmap; -pub mod string_cache; - -use bitflags::bitflags; -pub use builder::*; -pub use merge::*; -use polars_utils::itertools::Itertools; -use polars_utils::sync::SyncPtr; -pub use revmap::*; - -use super::*; -use crate::chunked_array::cast::CastOptions; -use crate::chunked_array::flags::StatisticsFlags; -use crate::prelude::*; -use crate::series::IsSorted; -use crate::using_string_cache; - -bitflags! { - #[derive(Default, Clone)] - struct BitSettings: u8 { - const ORIGINAL = 0x01; - } -} - -#[derive(Clone)] -pub struct CategoricalChunked { - physical: Logical, - /// 1st bit: original local categorical - /// meaning that n_unique is the same as the cat map length - bit_settings: BitSettings, -} - -impl CategoricalChunked { - pub(crate) fn field(&self) -> Field { - let name = self.physical().name(); - Field::new(name.clone(), self.dtype().clone()) - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - #[inline] - pub fn len(&self) -> usize { - self.physical.len() - } - - #[inline] - pub fn null_count(&self) -> usize { - self.physical.null_count() - } - - pub fn name(&self) -> &PlSmallStr { - self.physical.name() - } - - /// Get the physical array (the category indexes). - pub fn into_physical(self) -> UInt32Chunked { - self.physical.phys - } - - // TODO: Rename this - /// Get a reference to the physical array (the categories). - pub fn physical(&self) -> &UInt32Chunked { - &self.physical - } - - /// Get a mutable reference to the physical array (the categories). - pub(crate) fn physical_mut(&mut self) -> &mut UInt32Chunked { - &mut self.physical - } - - pub fn is_enum(&self) -> bool { - matches!(self.dtype(), DataType::Enum(_, _)) - } - - /// Convert a categorical column to its local representation. - pub fn to_local(&self) -> Self { - let rev_map = self.get_rev_map(); - let (physical_map, categories) = match rev_map.as_ref() { - RevMapping::Global(m, c, _) => (m, c), - RevMapping::Local(_, _) if !self.is_enum() => return self.clone(), - RevMapping::Local(_, _) => { - // Change dtype from Enum to Categorical - let mut local = self.clone(); - local.physical.dtype = - DataType::Categorical(Some(rev_map.clone()), self.get_ordering()); - return local; - }, - }; - - let local_rev_map = RevMapping::build_local(categories.clone()); - // TODO: A fast path can possibly be implemented here: - // if all physical map keys are equal to their values, - // we can skip the apply and only update the rev_map - let local_ca = self - .physical() - .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap())); - - let mut out = unsafe { - Self::from_cats_and_rev_map_unchecked( - local_ca, - local_rev_map.into(), - false, - self.get_ordering(), - ) - }; - out.set_fast_unique(self._can_fast_unique()); - - out - } - - pub fn to_global(&self) -> PolarsResult { - polars_ensure!(using_string_cache(), string_cache_mismatch); - // Fast path - let categories = match &**self.get_rev_map() { - RevMapping::Global(_, _, _) => return Ok(self.clone()), - RevMapping::Local(categories, _) => categories, - }; - - // SAFETY: keys and values are in bounds - unsafe { - Ok(CategoricalChunked::from_keys_and_values_global( - self.name().clone(), - self.physical(), - self.len(), - categories, - self.get_ordering(), - )) - } - } - - // Convert to fixed enum. Values not in categories are mapped to None. - pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self { - // Fast paths - match self.get_rev_map().as_ref() { - RevMapping::Local(_, cur_hash) if hash == *cur_hash => { - return unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - self.physical().clone(), - self.get_rev_map().clone(), - true, - self.get_ordering(), - ) - }; - }, - _ => (), - }; - // Make a mapping from old idx to new idx - let old_rev_map = self.get_rev_map(); - - // Create map of old category -> idx for fast lookup. - let old_categories = old_rev_map.get_categories(); - let old_idx_map: PlHashMap<&str, u32> = old_categories - .values_iter() - .zip(0..old_categories.len() as u32) - .collect(); - - #[allow(clippy::unnecessary_cast)] - let idx_map: PlHashMap = categories - .values_iter() - .enumerate_idx() - .filter_map(|(new_idx, s)| old_idx_map.get(s).map(|old_idx| (*old_idx, new_idx as u32))) - .collect(); - - // Loop over the physicals and try get new idx - let new_phys: UInt32Chunked = self - .physical() - .into_iter() - .map(|opt_v: Option| opt_v.and_then(|v| idx_map.get(&v).copied())) - .collect(); - - // SAFETY: we created the physical from the enum categories - unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - new_phys, - Arc::new(RevMapping::Local(categories.clone(), hash)), - true, - self.get_ordering(), - ) - } - } - - pub(crate) fn get_flags(&self) -> StatisticsFlags { - self.physical().get_flags() - } - - /// Set flags for the Chunked Array - pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) { - // We should not set the sorted flag if we are sorting in lexical order - if self.uses_lexical_ordering() { - flags.set_sorted(IsSorted::Not) - } - self.physical_mut().set_flags(flags) - } - - /// Return whether or not the [`CategoricalChunked`] uses the lexical order - /// of the string values when sorting. - pub fn uses_lexical_ordering(&self) -> bool { - self.get_ordering() == CategoricalOrdering::Lexical - } - - pub fn get_ordering(&self) -> CategoricalOrdering { - if let DataType::Categorical(_, ordering) | DataType::Enum(_, ordering) = - &self.physical.dtype - { - *ordering - } else { - panic!("implementation error") - } - } - - /// Create a [`CategoricalChunked`] from a physical array and dtype. - /// - /// # Safety - /// It's not checked that the indices are in-bounds or that the dtype is - /// correct. - pub unsafe fn from_cats_and_dtype_unchecked(idx: UInt32Chunked, dtype: DataType) -> Self { - debug_assert!(matches!( - dtype, - DataType::Enum { .. } | DataType::Categorical { .. } - )); - Self { - physical: Logical::new_logical(idx, dtype), - bit_settings: Default::default(), - } - } - - /// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]: `rev_map`. - /// - /// # Safety - /// Invariant in `v < rev_map.len() for v in idx` must hold. - pub unsafe fn from_cats_and_rev_map_unchecked( - idx: UInt32Chunked, - rev_map: Arc, - is_enum: bool, - ordering: CategoricalOrdering, - ) -> Self { - let dtype = if is_enum { - DataType::Enum(Some(rev_map), ordering) - } else { - DataType::Categorical(Some(rev_map), ordering) - }; - Self { - physical: Logical::new_logical(idx, dtype), - bit_settings: Default::default(), - } - } - - pub(crate) fn set_ordering( - mut self, - ordering: CategoricalOrdering, - keep_fast_unique: bool, - ) -> Self { - self.physical.dtype = match self.dtype() { - DataType::Enum(_, _) => DataType::Enum(Some(self.get_rev_map().clone()), ordering), - DataType::Categorical(_, _) => { - DataType::Categorical(Some(self.get_rev_map().clone()), ordering) - }, - _ => panic!("implementation error"), - }; - - if !keep_fast_unique { - self.set_fast_unique(false) - } - self - } - - /// # Safety - /// The existing index values must be in bounds of the new [`RevMapping`]. - pub(crate) unsafe fn set_rev_map(&mut self, rev_map: Arc, keep_fast_unique: bool) { - self.physical.dtype = match self.dtype() { - DataType::Enum(_, _) => DataType::Enum(Some(rev_map), self.get_ordering()), - DataType::Categorical(_, _) => { - DataType::Categorical(Some(rev_map), self.get_ordering()) - }, - _ => panic!("implementation error"), - }; - - if !keep_fast_unique { - self.set_fast_unique(false) - } - } - - /// True if all categories are represented in this array. When this is the case, the unique - /// values of the array are the categories. - pub fn _can_fast_unique(&self) -> bool { - self.bit_settings.contains(BitSettings::ORIGINAL) - && self.physical.chunks.len() == 1 - && self.null_count() == 0 - } - - pub(crate) fn set_fast_unique(&mut self, toggle: bool) { - if toggle { - self.bit_settings.insert(BitSettings::ORIGINAL); - } else { - self.bit_settings.remove(BitSettings::ORIGINAL); - } - } - - /// Set `FAST_UNIQUE` metadata - /// # Safety - /// This invariant must hold `unique(categories) == unique(self)` - pub(crate) unsafe fn with_fast_unique(mut self, toggle: bool) -> Self { - self.set_fast_unique(toggle); - self - } - - /// Set `FAST_UNIQUE` metadata - /// # Safety - /// This invariant must hold `unique(categories) == unique(self)` - pub unsafe fn _with_fast_unique(self, toggle: bool) -> Self { - self.with_fast_unique(toggle) - } - - /// Get a reference to the mapping of categorical types to the string values. - pub fn get_rev_map(&self) -> &Arc { - if let DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _) = - &self.physical.dtype - { - rev_map - } else { - panic!("implementation error") - } - } - - /// Create an [`Iterator`] that iterates over the `&str` values of the [`CategoricalChunked`]. - pub fn iter_str(&self) -> CatIter<'_> { - let iter = self.physical().into_iter(); - CatIter { - rev: self.get_rev_map(), - iter, - } - } -} - -impl LogicalType for CategoricalChunked { - fn dtype(&self) -> &DataType { - &self.physical.dtype - } - - fn get_any_value(&self, i: usize) -> PolarsResult> { - polars_ensure!(i < self.len(), oob = i, self.len()); - Ok(unsafe { self.get_any_value_unchecked(i) }) - } - - unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { - match self.physical.phys.get_unchecked(i) { - Some(i) => match self.dtype() { - DataType::Enum(_, _) => AnyValue::Enum(i, self.get_rev_map(), SyncPtr::new_null()), - DataType::Categorical(_, _) => { - AnyValue::Categorical(i, self.get_rev_map(), SyncPtr::new_null()) - }, - _ => unimplemented!(), - }, - None => AnyValue::Null, - } - } - - fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { - match dtype { - DataType::String => { - let mapping = &**self.get_rev_map(); - - let mut builder = - StringChunkedBuilder::new(self.physical.name().clone(), self.len()); - - let f = |idx: u32| mapping.get(idx); - - if !self.physical.has_nulls() { - self.physical - .into_no_null_iter() - .for_each(|idx| builder.append_value(f(idx))); - } else { - self.physical.into_iter().for_each(|opt_idx| { - builder.append_option(opt_idx.map(f)); - }); - } - - let ca = builder.finish(); - Ok(ca.into_series()) - }, - DataType::UInt32 => { - let ca = unsafe { - UInt32Chunked::from_chunks( - self.physical.name().clone(), - self.physical.chunks.clone(), - ) - }; - Ok(ca.into_series()) - }, - #[cfg(feature = "dtype-categorical")] - DataType::Enum(Some(rev_map), ordering) => { - let RevMapping::Local(categories, hash) = &**rev_map else { - polars_bail!(ComputeError: "can not cast to enum with global mapping") - }; - Ok(self - .to_enum(categories, *hash) - .set_ordering(*ordering, true) - .into_series() - .with_name(self.name().clone())) - }, - DataType::Enum(None, _) => { - polars_bail!(ComputeError: "can not cast to enum without categories present") - }, - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(rev_map, ordering) => { - // Casting from an Enum to a local or global - if matches!(self.dtype(), DataType::Enum(_, _)) && rev_map.is_none() { - if using_string_cache() { - return Ok(self - .to_global()? - .set_ordering(*ordering, true) - .into_series()); - } else { - return Ok(self.to_local().set_ordering(*ordering, true).into_series()); - } - } - // If casting to lexical categorical, set sorted flag as not set - - let mut ca = self.clone().set_ordering(*ordering, true); - if ca.uses_lexical_ordering() { - ca.physical.set_sorted_flag(IsSorted::Not); - } - Ok(ca.into_series()) - }, - dt if dt.is_primitive_numeric() => { - // Apply the cast to the categories and then index into the casted series. - // This has to be local for the gather. - let slf = self.to_local(); - let categories = StringChunked::with_chunk( - slf.physical.name().clone(), - slf.get_rev_map().get_categories().clone(), - ); - let casted_series = categories.cast_with_options(dtype, options)?; - - #[cfg(feature = "bigidx")] - { - let s = slf.physical.cast_with_options(&DataType::UInt64, options)?; - Ok(unsafe { casted_series.take_unchecked(s.u64()?) }) - } - #[cfg(not(feature = "bigidx"))] - { - // SAFETY: Invariant of categorical means indices are in bound - Ok(unsafe { casted_series.take_unchecked(&slf.physical) }) - } - }, - _ => self.physical.cast_with_options(dtype, options), - } - } -} - -pub struct CatIter<'a> { - rev: &'a RevMapping, - iter: Box> + 'a>, -} - -unsafe impl TrustedLen for CatIter<'_> {} - -impl<'a> Iterator for CatIter<'a> { - type Item = Option<&'a str>; - - fn next(&mut self) -> Option { - self.iter.next().map(|item| { - item.map(|idx| { - // SAFETY: - // all categories are in bound - unsafe { self.rev.get_unchecked(idx) } - }) - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } -} - -impl DoubleEndedIterator for CatIter<'_> { - fn next_back(&mut self) -> Option { - self.iter.next_back().map(|item| { - item.map(|idx| { - // SAFETY: - // all categories are in bound - unsafe { self.rev.get_unchecked(idx) } - }) - }) - } -} - -impl ExactSizeIterator for CatIter<'_> {} - -#[cfg(test)] -mod test { - use super::*; - use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache}; - - #[test] - fn test_categorical_round_trip() -> PolarsResult<()> { - let _lock = SINGLE_LOCK.lock(); - disable_string_cache(); - let slice = &[ - Some("foo"), - None, - Some("bar"), - Some("foo"), - Some("foo"), - Some("bar"), - ]; - let ca = StringChunked::new(PlSmallStr::from_static("a"), slice); - let ca = ca.cast(&DataType::Categorical(None, Default::default()))?; - let ca = ca.categorical().unwrap(); - - let arr = ca.to_arrow(CompatLevel::newest(), false); - let s = Series::try_from((PlSmallStr::from_static("foo"), arr))?; - assert!(matches!(s.dtype(), &DataType::Categorical(_, _))); - assert_eq!(s.null_count(), 1); - assert_eq!(s.len(), 6); - - Ok(()) - } - - #[test] - fn test_append_categorical() { - let _lock = SINGLE_LOCK.lock(); - disable_string_cache(); - enable_string_cache(); - - let mut s1 = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"]) - .cast(&DataType::Categorical(None, Default::default())) - .unwrap(); - let s2 = Series::new(PlSmallStr::from_static("2"), vec!["a", "x", "y"]) - .cast(&DataType::Categorical(None, Default::default())) - .unwrap(); - let appended = s1.append(&s2).unwrap(); - assert_eq!(appended.str_value(0).unwrap(), "a"); - assert_eq!(appended.str_value(1).unwrap(), "b"); - assert_eq!(appended.str_value(4).unwrap(), "x"); - assert_eq!(appended.str_value(5).unwrap(), "y"); - } - - #[test] - fn test_fast_unique() { - let _lock = SINGLE_LOCK.lock(); - let s = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"]) - .cast(&DataType::Categorical(None, Default::default())) - .unwrap(); - - assert_eq!(s.n_unique().unwrap(), 3); - // Make sure that it does not take the fast path after take/slice. - let out = s.take(&IdxCa::new(PlSmallStr::EMPTY, [1, 2])).unwrap(); - assert_eq!(out.n_unique().unwrap(), 2); - let out = s.slice(1, 2); - assert_eq!(out.n_unique().unwrap(), 2); - } - - #[test] - fn test_categorical_flow() -> PolarsResult<()> { - let _lock = SINGLE_LOCK.lock(); - disable_string_cache(); - - // tests several things that may lose the dtype information - let s = Series::new(PlSmallStr::from_static("a"), vec!["a", "b", "c"]) - .cast(&DataType::Categorical(None, Default::default()))?; - - assert_eq!( - s.field().into_owned(), - Field::new( - PlSmallStr::from_static("a"), - DataType::Categorical(None, Default::default()) - ) - ); - assert!(matches!( - s.get(0)?, - AnyValue::Categorical(0, RevMapping::Local(_, _), _) - )); - - let groups = s.group_tuples(false, true); - let aggregated = unsafe { s.agg_list(&groups?) }; - match aggregated.get(0)? { - AnyValue::List(s) => { - assert!(matches!(s.dtype(), DataType::Categorical(_, _))); - let str_s = s.cast(&DataType::String).unwrap(); - assert_eq!(str_s.get(0)?, AnyValue::String("a")); - assert_eq!(s.len(), 1); - }, - _ => panic!(), - } - let flat = aggregated.explode(false)?; - let ca = flat.categorical().unwrap(); - let vals = ca.iter_str().map(|v| v.unwrap()).collect::>(); - assert_eq!(vals, &["a", "b", "c"]); - Ok(()) - } -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs deleted file mode 100644 index 8802cde9e6ff..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs +++ /dev/null @@ -1,66 +0,0 @@ -use polars_error::constants::LENGTH_LIMIT_MSG; - -use super::*; -use crate::chunked_array::ops::append::new_chunks; - -struct CategoricalAppend; - -impl CategoricalMergeOperation for CategoricalAppend { - fn finish(self, lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult { - let mut lhs_mut = lhs.clone(); - lhs_mut.append(rhs)?; - Ok(lhs_mut) - } -} - -impl CategoricalChunked { - fn set_lengths(&mut self, other: &Self) { - let length_self = &mut self.physical_mut().length; - *length_self = length_self - .checked_add(other.len()) - .expect(LENGTH_LIMIT_MSG); - - assert!( - IdxSize::try_from(*length_self).is_ok(), - "{}", - LENGTH_LIMIT_MSG - ); - self.physical_mut().null_count += other.null_count(); - } - - pub fn take(&mut self) -> Self { - Self { - physical: Logical { - phys: core::mem::take(&mut self.physical.phys), - dtype: self.physical.dtype.clone(), - _phantom: PhantomData, - }, - bit_settings: self.bit_settings.clone(), - } - } - - pub fn append(&mut self, other: &Self) -> PolarsResult<()> { - polars_ensure!(!self.is_enum() || self.dtype() == other.dtype(), append); - - // fast path all nulls - if self.physical.null_count() == self.len() && other.physical.null_count() == other.len() { - let len = self.len(); - self.set_lengths(other); - new_chunks(&mut self.physical.chunks, &other.physical().chunks, len); - return Ok(()); - } - - if self.is_enum() { - self.physical_mut().append(other.physical())?; - } else { - let mut new_self = call_categorical_merge_operation(self, other, CategoricalAppend)?; - std::mem::swap(self, &mut new_self); - } - Ok(()) - } - - pub fn append_owned(&mut self, other: Self) -> PolarsResult<()> { - // @TODO: Move the implementation to append_owned and make append dispatch here. - self.append(&other) - } -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/full.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/full.rs deleted file mode 100644 index ed53722d163e..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/full.rs +++ /dev/null @@ -1,21 +0,0 @@ -use super::*; - -impl CategoricalChunked { - pub fn full_null( - name: PlSmallStr, - is_enum: bool, - length: usize, - ordering: CategoricalOrdering, - ) -> CategoricalChunked { - let cats = UInt32Chunked::full_null(name, length); - - unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - cats, - Arc::new(RevMapping::default()), - is_enum, - ordering, - ) - } - } -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs deleted file mode 100644 index 759628b322cb..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod append; -mod full; -#[cfg(feature = "algorithm_group_by")] -mod unique; -#[cfg(feature = "zip_with")] -mod zip; - -use super::*; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs deleted file mode 100644 index 17752f828d8d..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs +++ /dev/null @@ -1,89 +0,0 @@ -use super::*; - -impl CategoricalChunked { - pub fn unique(&self) -> PolarsResult { - let cat_map = self.get_rev_map(); - if self.is_empty() { - // SAFETY: rev map is valid. - unsafe { - return Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( - UInt32Chunked::full_null(self.name().clone(), 0), - cat_map.clone(), - self.is_enum(), - self.get_ordering(), - )); - } - }; - - if self._can_fast_unique() { - let ca = match &**cat_map { - RevMapping::Local(a, _) => UInt32Chunked::from_iter_values( - self.physical().name().clone(), - 0..(a.len() as u32), - ), - RevMapping::Global(map, _, _) => UInt32Chunked::from_iter_values( - self.physical().name().clone(), - map.keys().copied(), - ), - }; - // SAFETY: - // we only removed some indexes so we are still in bounds - unsafe { - let mut out = CategoricalChunked::from_cats_and_rev_map_unchecked( - ca, - cat_map.clone(), - self.is_enum(), - self.get_ordering(), - ); - out.set_fast_unique(true); - Ok(out) - } - } else { - let ca = self.physical().unique()?; - // SAFETY: - // we only removed some indexes so we are still in bounds - unsafe { - Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( - ca, - cat_map.clone(), - self.is_enum(), - self.get_ordering(), - )) - } - } - } - - pub fn n_unique(&self) -> PolarsResult { - if self._can_fast_unique() { - Ok(self.get_rev_map().len()) - } else { - self.physical().n_unique() - } - } - - pub fn value_counts(&self) -> PolarsResult { - let groups = self.physical().group_tuples(true, false).unwrap(); - let physical_values = unsafe { - self.physical() - .clone() - .into_series() - .agg_first(&groups) - .u32() - .unwrap() - .clone() - }; - - let mut values = self.clone(); - *values.physical_mut() = physical_values; - - let mut counts = groups.group_count(); - counts.rename(PlSmallStr::from_static("counts")); - let height = counts.len(); - let cols = vec![values.into_series().into(), counts.into_series().into()]; - let df = unsafe { DataFrame::new_no_checks(height, cols) }; - df.sort( - ["counts"], - SortMultipleOptions::default().with_order_descending(true), - ) - } -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs deleted file mode 100644 index b92f46c68e17..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs +++ /dev/null @@ -1,18 +0,0 @@ -use super::*; - -struct CategoricalZipWith<'a>(&'a BooleanChunked); - -impl CategoricalMergeOperation for CategoricalZipWith<'_> { - fn finish(self, lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult { - lhs.zip_with(self.0, rhs) - } -} -impl CategoricalChunked { - pub(crate) fn zip_with( - &self, - mask: &BooleanChunked, - other: &CategoricalChunked, - ) -> PolarsResult { - call_categorical_merge_operation(self, other, CategoricalZipWith(mask)) - } -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs deleted file mode 100644 index 0b27aaec5bbe..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs +++ /dev/null @@ -1,170 +0,0 @@ -#![allow(unsafe_op_in_unsafe_fn)] -use std::fmt::{Debug, Formatter}; -use std::hash::{BuildHasher, Hash, Hasher}; - -use arrow::array::*; -use polars_utils::aliases::PlFixedStateQuality; - -use crate::datatypes::PlHashMap; -use crate::{StringCache, using_string_cache}; - -#[derive(Clone)] -pub enum RevMapping { - /// Hashmap: maps the indexes from the global cache/categorical array to indexes in the local Utf8Array - /// Utf8Array: caches the string values - Global(PlHashMap, Utf8ViewArray, u32), - /// Utf8Array: caches the string values and a hash of all values for quick comparison - Local(Utf8ViewArray, u128), -} - -impl Debug for RevMapping { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - RevMapping::Global(_, _, _) => { - write!(f, "global") - }, - RevMapping::Local(_, _) => { - write!(f, "local") - }, - } - } -} - -impl Default for RevMapping { - fn default() -> Self { - let slice: &[Option<&str>] = &[]; - let cats = Utf8ViewArray::from_slice(slice); - if using_string_cache() { - let cache = &mut crate::STRING_CACHE.lock_map(); - let id = cache.uuid; - RevMapping::Global(Default::default(), cats, id) - } else { - RevMapping::build_local(cats) - } - } -} - -#[allow(clippy::len_without_is_empty)] -impl RevMapping { - pub fn is_active_global(&self) -> bool { - match self { - Self::Global(_, _, id) => *id == StringCache::active_cache_id(), - _ => false, - } - } - - pub fn is_global(&self) -> bool { - matches!(self, Self::Global(_, _, _)) - } - - pub fn is_local(&self) -> bool { - matches!(self, Self::Local(_, _)) - } - - /// Get the categories in this [`RevMapping`] - pub fn get_categories(&self) -> &Utf8ViewArray { - match self { - Self::Global(_, a, _) => a, - Self::Local(a, _) => a, - } - } - - fn build_hash(categories: &Utf8ViewArray) -> u128 { - // TODO! we must also validate the cases of duplicates! - let mut hb = PlFixedStateQuality::with_seed(0).build_hasher(); - categories.values_iter().for_each(|val| { - val.hash(&mut hb); - }); - let hash = hb.finish(); - ((hash as u128) << 64) | (categories.total_buffer_len() as u128) - } - - pub fn build_local(categories: Utf8ViewArray) -> Self { - debug_assert_eq!(categories.null_count(), 0); - let hash = Self::build_hash(&categories); - Self::Local(categories, hash) - } - - /// Get the length of the [`RevMapping`] - pub fn len(&self) -> usize { - self.get_categories().len() - } - - /// [`Categorical`] to [`str`] - /// - /// [`Categorical`]: crate::datatypes::DataType::Categorical - pub fn get(&self, idx: u32) -> &str { - match self { - Self::Global(map, a, _) => { - let idx = *map.get(&idx).unwrap(); - a.value(idx as usize) - }, - Self::Local(a, _) => a.value(idx as usize), - } - } - - pub fn get_optional(&self, idx: u32) -> Option<&str> { - match self { - Self::Global(map, a, _) => { - let idx = *map.get(&idx)?; - a.get(idx as usize) - }, - Self::Local(a, _) => a.get(idx as usize), - } - } - - /// [`Categorical`] to [`str`] - /// - /// [`Categorical`]: crate::datatypes::DataType::Categorical - /// - /// # Safety - /// This doesn't do any bound checking - pub(crate) unsafe fn get_unchecked(&self, idx: u32) -> &str { - match self { - Self::Global(map, a, _) => { - let idx = *map.get(&idx).unwrap(); - a.value_unchecked(idx as usize) - }, - Self::Local(a, _) => a.value_unchecked(idx as usize), - } - } - /// Check if the categoricals have a compatible mapping - #[inline] - pub fn same_src(&self, other: &Self) -> bool { - match (self, other) { - (RevMapping::Global(_, _, l), RevMapping::Global(_, _, r)) => *l == *r, - (RevMapping::Local(_, l_hash), RevMapping::Local(_, r_hash)) => l_hash == r_hash, - _ => false, - } - } - - /// [`str`] to [`Categorical`] - /// - /// - /// [`Categorical`]: crate::datatypes::DataType::Categorical - pub fn find(&self, value: &str) -> Option { - match self { - Self::Global(rev_map, a, id) => { - // fast path is check - if using_string_cache() { - let map = crate::STRING_CACHE.read_map(); - if map.uuid == *id { - return map.get_cat(value); - } - } - rev_map - .iter() - // SAFETY: - // value is always within bounds - .find(|&(_k, &v)| (unsafe { a.value_unchecked(v as usize) } == value)) - .map(|(k, _v)| *k) - }, - - Self::Local(a, _) => { - // SAFETY: within bounds - unsafe { (0..a.len()).find(|idx| a.value_unchecked(*idx) == value) } - .map(|idx| idx as u32) - }, - } - } -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs deleted file mode 100644 index d6fc543fd273..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs +++ /dev/null @@ -1,258 +0,0 @@ -use std::hash::{BuildHasher, Hash, Hasher}; -use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; -use std::sync::{LazyLock, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; - -use hashbrown::HashTable; -use hashbrown::hash_table::Entry; -use polars_utils::aliases::PlFixedStateQuality; -use polars_utils::pl_str::PlSmallStr; - -use crate::hashing::_HASHMAP_INIT_SIZE; - -/// We use atomic reference counting to determine how many threads use the -/// string cache. If the refcount is zero, we may clear the string cache. -static STRING_CACHE_REFCOUNT: Mutex = Mutex::new(0); -static STRING_CACHE_ENABLED_GLOBALLY: AtomicBool = AtomicBool::new(false); -static STRING_CACHE_UUID_CTR: AtomicU32 = AtomicU32::new(0); - -/// Enable the global string cache as long as the object is alive ([RAII]). -/// -/// # Examples -/// -/// Enable the string cache by initializing the object: -/// -/// ``` -/// use polars_core::StringCacheHolder; -/// -/// let _sc = StringCacheHolder::hold(); -/// ``` -/// -/// The string cache is enabled until `handle` is dropped. -/// -/// # De-allocation -/// -/// Multiple threads can hold the string cache at the same time. -/// The contents of the cache will only get dropped when no thread holds it. -/// -/// [RAII]: https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization -pub struct StringCacheHolder { - // only added so that it will never be constructed directly - #[allow(dead_code)] - private_zst: (), -} - -impl Default for StringCacheHolder { - fn default() -> Self { - Self::hold() - } -} - -impl StringCacheHolder { - /// Hold the StringCache - pub fn hold() -> StringCacheHolder { - increment_string_cache_refcount(); - StringCacheHolder { private_zst: () } - } -} - -impl Drop for StringCacheHolder { - fn drop(&mut self) { - decrement_string_cache_refcount(); - } -} - -fn increment_string_cache_refcount() { - let mut refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); - *refcount += 1; -} -fn decrement_string_cache_refcount() { - let mut refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); - *refcount -= 1; - if *refcount == 0 { - STRING_CACHE.clear() - } -} - -/// Enable the global string cache. -/// -/// [`Categorical`] columns created under the same global string cache have the -/// same underlying physical value when string values are equal. This allows the -/// columns to be concatenated or used in a join operation, for example. -/// -/// Note that enabling the global string cache introduces some overhead. -/// The amount of overhead depends on the number of categories in your data. -/// It is advised to enable the global string cache only when strictly necessary. -/// -/// [`Categorical`]: crate::datatypes::DataType::Categorical -pub fn enable_string_cache() { - let was_enabled = STRING_CACHE_ENABLED_GLOBALLY.swap(true, Ordering::AcqRel); - if !was_enabled { - increment_string_cache_refcount(); - } -} - -/// Disable and clear the global string cache. -/// -/// Note: Consider using [`StringCacheHolder`] for a more reliable way of -/// enabling and disabling the string cache. -pub fn disable_string_cache() { - let was_enabled = STRING_CACHE_ENABLED_GLOBALLY.swap(false, Ordering::AcqRel); - if was_enabled { - decrement_string_cache_refcount(); - } -} - -/// Check whether the global string cache is enabled. -pub fn using_string_cache() -> bool { - let refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); - *refcount > 0 -} - -// This is the hash and the Index offset in the linear buffer -#[derive(Copy, Clone)] -struct Key { - pub(super) hash: u64, - pub(super) idx: u32, -} - -impl Key { - #[inline] - pub(super) fn new(hash: u64, idx: u32) -> Self { - Self { hash, idx } - } -} - -impl Hash for Key { - #[inline] - fn hash(&self, state: &mut H) { - state.write_u64(self.hash) - } -} - -pub(crate) struct SCacheInner { - map: HashTable, - pub(crate) uuid: u32, - payloads: Vec, -} - -impl SCacheInner { - #[inline] - pub(crate) unsafe fn get_unchecked(&self, cat: u32) -> &str { - self.payloads.get_unchecked(cat as usize).as_str() - } - - pub(crate) fn len(&self) -> usize { - self.map.len() - } - - #[inline] - pub(crate) fn insert_from_hash(&mut self, h: u64, s: &str) -> u32 { - let mut global_idx = self.payloads.len() as u32; - let entry = self.map.entry( - h, - |k| { - let value = unsafe { self.payloads.get_unchecked(k.idx as usize) }; - s == value.as_str() - }, - |k| k.hash, - ); - - match entry { - Entry::Occupied(entry) => { - global_idx = entry.get().idx; - }, - Entry::Vacant(entry) => { - let idx = self.payloads.len() as u32; - let key = Key::new(h, idx); - entry.insert(key); - self.payloads.push(PlSmallStr::from_str(s)); - }, - } - global_idx - } - - #[inline] - pub(crate) fn get_cat(&self, s: &str) -> Option { - let h = StringCache::get_hash_builder().hash_one(s); - self.map - .find(h, |k| { - let value = unsafe { self.payloads.get_unchecked(k.idx as usize) }; - s == value.as_str() - }) - .map(|k| k.idx) - } - - #[inline] - pub(crate) fn insert(&mut self, s: &str) -> u32 { - let h = StringCache::get_hash_builder().hash_one(s); - self.insert_from_hash(h, s) - } - - #[inline] - pub(crate) fn get_current_payloads(&self) -> &[PlSmallStr] { - &self.payloads - } -} - -impl Default for SCacheInner { - fn default() -> Self { - Self { - map: HashTable::with_capacity(_HASHMAP_INIT_SIZE), - uuid: STRING_CACHE_UUID_CTR.fetch_add(1, Ordering::AcqRel), - payloads: Vec::with_capacity(_HASHMAP_INIT_SIZE), - } - } -} - -/// Used by categorical data that need to share global categories. -/// In *eager* you need to specifically toggle global string cache to have a global effect. -/// In *lazy* it is toggled on at the start of a computation run and turned of (deleted) when a -/// result is produced. -#[derive(Default)] -pub(crate) struct StringCache(pub(crate) RwLock); - -impl StringCache { - /// The global `StringCache` will always use a predictable seed. This allows local builders to mimic - /// the hashes in case of contention. - #[inline] - pub(crate) fn get_hash_builder() -> PlFixedStateQuality { - PlFixedStateQuality::with_seed(0) - } - - pub(crate) fn active_cache_id() -> u32 { - STRING_CACHE_UUID_CTR - .load(Ordering::Relaxed) - .wrapping_sub(1) - } - - /// Lock the string cache - pub(crate) fn lock_map(&self) -> RwLockWriteGuard<'_, SCacheInner> { - self.0.write().unwrap() - } - - pub(crate) fn read_map(&self) -> RwLockReadGuard<'_, SCacheInner> { - self.0.read().unwrap() - } - - pub(crate) fn clear(&self) { - let mut lock = self.lock_map(); - *lock = Default::default(); - } - - pub(crate) fn apply(&self, fun: F) -> (u32, T) - where - F: FnOnce(&mut RwLockWriteGuard) -> T, - { - let cache = &mut crate::STRING_CACHE.lock_map(); - - let result = fun(cache); - - if cache.len() > u32::MAX as usize { - panic!("not more than {} categories supported", u32::MAX) - }; - - (cache.uuid, result) - } -} - -pub(crate) static STRING_CACHE: LazyLock = LazyLock::new(Default::default); diff --git a/crates/polars-core/src/chunked_array/logical/date.rs b/crates/polars-core/src/chunked_array/logical/date.rs index 00a59d65d096..ab4f995a2050 100644 --- a/crates/polars-core/src/chunked_array/logical/date.rs +++ b/crates/polars-core/src/chunked_array/logical/date.rs @@ -4,7 +4,8 @@ pub type DateChunked = Logical; impl Int32Chunked { pub fn into_date(self) -> DateChunked { - DateChunked::new_logical(self, DataType::Date) + // SAFETY: no invalid states. + unsafe { DateChunked::new_logical(self, DataType::Date) } } } @@ -39,7 +40,7 @@ impl LogicalType for DateChunked { TimeUnit::Milliseconds => MS_IN_DAY, }; Ok(casted - .deref() + .physical() .checked_mul_scalar(conversion) .into_datetime(*tu, tz.clone()) .into_series()) diff --git a/crates/polars-core/src/chunked_array/logical/datetime.rs b/crates/polars-core/src/chunked_array/logical/datetime.rs index 6e709ecac556..cde473b5bd01 100644 --- a/crates/polars-core/src/chunked_array/logical/datetime.rs +++ b/crates/polars-core/src/chunked_array/logical/datetime.rs @@ -6,7 +6,8 @@ pub type DatetimeChunked = Logical; impl Int64Chunked { pub fn into_datetime(self, timeunit: TimeUnit, tz: Option) -> DatetimeChunked { - DatetimeChunked::new_logical(self, DataType::Datetime(timeunit, tz)) + // SAFETY: no invalid states. + unsafe { DatetimeChunked::new_logical(self, DataType::Datetime(timeunit, tz)) } } } @@ -76,7 +77,7 @@ impl LogicalType for DatetimeChunked { .unwrap() .into_date() .into_series(); - dt.set_sorted_flag(self.is_sorted_flag()); + dt.set_sorted_flag(self.physical().is_sorted_flag()); Ok(dt) }; match self.time_unit() { @@ -115,7 +116,7 @@ impl LogicalType for DatetimeChunked { out.map(|mut s| { // TODO!; implement the divisions/multipliers above // in a checked manner so that we raise on overflow - s.set_sorted_flag(self.is_sorted_flag()); + s.set_sorted_flag(self.physical().is_sorted_flag()); s }) } diff --git a/crates/polars-core/src/chunked_array/logical/decimal.rs b/crates/polars-core/src/chunked_array/logical/decimal.rs index 58310e029917..ee3a69e86644 100644 --- a/crates/polars-core/src/chunked_array/logical/decimal.rs +++ b/crates/polars-core/src/chunked_array/logical/decimal.rs @@ -9,7 +9,8 @@ pub type DecimalChunked = Logical; impl Int128Chunked { #[inline] pub fn into_decimal_unchecked(self, precision: Option, scale: usize) -> DecimalChunked { - DecimalChunked::new_logical(self, DataType::Decimal(precision, Some(scale))) + // SAFETY: no invalid states. + unsafe { DecimalChunked::new_logical(self, DataType::Decimal(precision, Some(scale))) } } pub fn into_decimal( @@ -75,6 +76,7 @@ impl LogicalType for DecimalChunked { let arrow_dtype = self.dtype().to_arrow(CompatLevel::newest()); let chunks = self + .physical() .chunks .iter() .map(|arr| { diff --git a/crates/polars-core/src/chunked_array/logical/duration.rs b/crates/polars-core/src/chunked_array/logical/duration.rs index 36f3b9fe85cf..cb816f07426e 100644 --- a/crates/polars-core/src/chunked_array/logical/duration.rs +++ b/crates/polars-core/src/chunked_array/logical/duration.rs @@ -5,7 +5,8 @@ pub type DurationChunked = Logical; impl Int64Chunked { pub fn into_duration(self, timeunit: TimeUnit) -> DurationChunked { - DurationChunked::new_logical(self, DataType::Duration(timeunit)) + // SAFETY: no invalid states. + unsafe { DurationChunked::new_logical(self, DataType::Duration(timeunit)) } } } diff --git a/crates/polars-core/src/chunked_array/logical/enum_/mod.rs b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs deleted file mode 100644 index 7894a8c6bf91..000000000000 --- a/crates/polars-core/src/chunked_array/logical/enum_/mod.rs +++ /dev/null @@ -1,109 +0,0 @@ -use std::sync::Arc; - -use arrow::array::UInt32Vec; -use arrow::bitmap::MutableBitmap; -use polars_error::{PolarsResult, polars_bail, polars_err}; -use polars_utils::aliases::{InitHashMaps, PlHashMap}; -use polars_utils::pl_str::PlSmallStr; - -use super::{CategoricalChunked, CategoricalOrdering, DataType, Field, RevMapping, UInt32Chunked}; - -pub struct EnumChunkedBuilder { - name: PlSmallStr, - enum_builder: UInt32Vec, - - rev: Arc, - ordering: CategoricalOrdering, - seen: MutableBitmap, - - // Mapping to amortize the costs of lookups. - mapping: PlHashMap, - strict: bool, -} - -impl EnumChunkedBuilder { - pub fn new( - name: PlSmallStr, - capacity: usize, - rev: Arc, - ordering: CategoricalOrdering, - strict: bool, - ) -> Self { - let seen = MutableBitmap::from_len_zeroed(rev.len()); - - Self { - name, - enum_builder: UInt32Vec::with_capacity(capacity), - - rev, - ordering, - seen, - - mapping: PlHashMap::new(), - strict, - } - } - - pub fn append_str(&mut self, v: &str) -> PolarsResult<&mut Self> { - match self.mapping.get(v) { - Some(v) => self.enum_builder.push(Some(*v)), - None => { - let Some(iv) = self.rev.find(v) else { - if self.strict { - polars_bail!(InvalidOperation: "cannot append '{v}' to enum without that variant"); - } else { - self.enum_builder.push(None); - return Ok(self); - } - }; - self.seen.set(iv as usize, true); - self.mapping.insert(v.into(), iv); - self.enum_builder.push(Some(iv)); - }, - } - - Ok(self) - } - - pub fn append_null(&mut self) -> &mut Self { - self.enum_builder.push(None); - self - } - - pub fn append_enum(&mut self, v: u32, rev: &RevMapping) -> PolarsResult<&mut Self> { - if !self.rev.same_src(rev) { - if self.strict { - return Err(polars_err!(ComputeError: "incompatible enum types")); - } else { - self.enum_builder.push(None); - } - } else { - self.seen.set(v as usize, true); - self.enum_builder.push(Some(v)); - } - - Ok(self) - } - - pub fn finish(self) -> CategoricalChunked { - let arr = self.enum_builder.freeze(); - let null_count = arr.validity().map_or(0, |a| a.unset_bits()); - let length = arr.len(); - let ca = unsafe { - UInt32Chunked::new_with_dims( - Arc::new(Field::new(self.name, DataType::UInt32)), - vec![Box::new(arr)], - length, - null_count, - ) - }; - // Fast Unique <=> unique(rev) == unique(ca) - let fast_unique = !ca.has_nulls() && self.seen.unset_bits() == 0; - - // SAFETY: keys and values are in bounds - unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked(ca, self.rev, true, self.ordering) - .with_fast_unique(fast_unique) - } - } -} diff --git a/crates/polars-core/src/chunked_array/logical/mod.rs b/crates/polars-core/src/chunked_array/logical/mod.rs index f97a60ad302a..fcbf5578cc2b 100644 --- a/crates/polars-core/src/chunked_array/logical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/mod.rs @@ -16,13 +16,10 @@ mod duration; pub use duration::*; #[cfg(feature = "dtype-categorical")] pub mod categorical; -#[cfg(feature = "dtype-categorical")] -pub mod enum_; #[cfg(feature = "dtype-time")] mod time; use std::marker::PhantomData; -use std::ops::{Deref, DerefMut}; #[cfg(feature = "dtype-categorical")] pub use categorical::*; @@ -50,22 +47,10 @@ impl Clone for Logical { } } -impl Deref for Logical { - type Target = ChunkedArray; - - fn deref(&self) -> &Self::Target { - &self.phys - } -} - -impl DerefMut for Logical { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.phys - } -} - impl Logical { - pub fn new_logical(phys: ChunkedArray, dtype: DataType) -> Logical { + /// # Safety + /// You must uphold the logical types' invariants. + pub unsafe fn new_logical(phys: ChunkedArray, dtype: DataType) -> Logical { Logical { phys, dtype, @@ -100,11 +85,80 @@ impl Logical where Self: LogicalType, { - pub fn physical(&self) -> &ChunkedArray { - &self.phys + #[inline(always)] + pub fn name(&self) -> &PlSmallStr { + self.phys.name() + } + + #[inline(always)] + pub fn rename(&mut self, name: PlSmallStr) { + self.phys.rename(name) + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.phys.len() + } + + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[inline(always)] + pub fn null_count(&self) -> usize { + self.phys.null_count() } + + #[inline(always)] + pub fn has_nulls(&self) -> bool { + self.phys.has_nulls() + } + + #[inline(always)] + pub fn is_null(&self) -> BooleanChunked { + self.phys.is_null() + } + + #[inline(always)] + pub fn is_not_null(&self) -> BooleanChunked { + self.phys.is_not_null() + } + + #[inline(always)] + pub fn split_at(&self, offset: i64) -> (Self, Self) { + let (left, right) = self.phys.split_at(offset); + unsafe { + ( + Self::new_logical(left, self.dtype.clone()), + Self::new_logical(right, self.dtype.clone()), + ) + } + } + + #[inline(always)] + pub fn slice(&self, offset: i64, length: usize) -> Self { + unsafe { Self::new_logical(self.phys.slice(offset, length), self.dtype.clone()) } + } + + #[inline(always)] pub fn field(&self) -> Field { let name = self.phys.ref_field().name(); Field::new(name.clone(), LogicalType::dtype(self).clone()) } + + #[inline(always)] + pub fn physical(&self) -> &ChunkedArray { + &self.phys + } + + #[inline(always)] + pub fn physical_mut(&mut self) -> &mut ChunkedArray { + &mut self.phys + } + + #[inline(always)] + pub fn into_physical(self) -> ChunkedArray { + self.phys + } } diff --git a/crates/polars-core/src/chunked_array/logical/time.rs b/crates/polars-core/src/chunked_array/logical/time.rs index 62e2622b5ce4..9d6c3240f02a 100644 --- a/crates/polars-core/src/chunked_array/logical/time.rs +++ b/crates/polars-core/src/chunked_array/logical/time.rs @@ -42,7 +42,8 @@ impl Int64Chunked { let int64chunked = unsafe { Self::new_with_dims(self.field.clone(), chunks, self.length, null_count) }; - TimeChunked::new_logical(int64chunked, DataType::Time) + // SAFETY: no invalid states. + unsafe { TimeChunked::new_logical(int64chunked, DataType::Time) } } } diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index b110cfa9b5d4..2092cbf41428 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -1011,17 +1011,17 @@ pub(crate) mod test { #[test] #[cfg(feature = "dtype-categorical")] fn test_iter_categorical() { - use crate::{SINGLE_LOCK, disable_string_cache}; - let _lock = SINGLE_LOCK.lock(); - disable_string_cache(); let ca = StringChunked::new( PlSmallStr::EMPTY, &[Some("foo"), None, Some("bar"), Some("ham")], ); - let ca = ca - .cast(&DataType::Categorical(None, Default::default())) - .unwrap(); - let ca = ca.categorical().unwrap(); + let cats = Categories::new( + PlSmallStr::EMPTY, + PlSmallStr::EMPTY, + CategoricalPhysical::U32, + ); + let ca = ca.cast(&DataType::from_categories(cats.clone())).unwrap(); + let ca = ca.cat32().unwrap(); let v: Vec<_> = ca.physical().into_iter().collect(); assert_eq!(v, &[Some(0), None, Some(1), Some(2)]); } diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index 26009bcfc2fc..110e280a72cd 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -9,7 +9,6 @@ use polars_compute::min_max::MinMaxKernel; use polars_compute::rolling::QuantileMethod; use polars_compute::sum::{WrappingSum, wrapping_sum_arr}; use polars_utils::min_max::MinMax; -use polars_utils::sync::SyncPtr; pub use quantile::*; pub use var::*; @@ -438,137 +437,76 @@ impl ChunkAggSeries for StringChunked { } #[cfg(feature = "dtype-categorical")] -impl CategoricalChunked { - fn min_categorical(&self) -> Option { +impl CategoricalChunked +where + ChunkedArray: ChunkAgg, +{ + fn min_categorical(&self) -> Option { if self.is_empty() || self.null_count() == self.len() { return None; } if self.uses_lexical_ordering() { - let rev_map = self.get_rev_map(); - // Fast path where all categories are used - let c = if self._can_fast_unique() { - rev_map.get_categories().min_ignore_nan_kernel() - } else { - // SAFETY: - // Indices are in bounds - self.physical() - .iter() - .flat_map(|opt_el: Option| { - opt_el.map(|el| unsafe { rev_map.get_unchecked(el) }) - }) - .min() - }; - rev_map.find(c.unwrap()) + let mapping = self.get_mapping(); + let s = self + .physical() + .iter() + .flat_map(|opt_cat| { + Some(unsafe { mapping.cat_to_str_unchecked(opt_cat?.as_cat()) }) + }) + .min(); + mapping.get_cat(s.unwrap()) } else { - self.physical().min() + Some(self.physical().min()?.as_cat()) } } - fn max_categorical(&self) -> Option { + fn max_categorical(&self) -> Option { if self.is_empty() || self.null_count() == self.len() { return None; } if self.uses_lexical_ordering() { - let rev_map = self.get_rev_map(); - // Fast path where all categories are used - let c = if self._can_fast_unique() { - rev_map.get_categories().max_ignore_nan_kernel() - } else { - // SAFETY: - // Indices are in bounds - self.physical() - .iter() - .flat_map(|opt_el: Option| { - opt_el.map(|el| unsafe { rev_map.get_unchecked(el) }) - }) - .max() - }; - rev_map.find(c.unwrap()) + let mapping = self.get_mapping(); + let s = self + .physical() + .iter() + .flat_map(|opt_cat| { + Some(unsafe { mapping.cat_to_str_unchecked(opt_cat?.as_cat()) }) + }) + .max(); + mapping.get_cat(s.unwrap()) } else { - self.physical().max() + Some(self.physical().max()?.as_cat()) } } } #[cfg(feature = "dtype-categorical")] -impl ChunkAggSeries for CategoricalChunked { +impl ChunkAggSeries for CategoricalChunked +where + ChunkedArray: ChunkAgg, +{ fn min_reduce(&self) -> Scalar { - match self.dtype() { - DataType::Enum(r, _) => match self.physical().min() { - None => Scalar::new(self.dtype().clone(), AnyValue::Null), - Some(v) => { - let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else { - unreachable!() - }; - Scalar::new( - self.dtype().clone(), - AnyValue::EnumOwned( - v, - r.as_ref().unwrap().clone(), - SyncPtr::from_const(arr as *const _), - ), - ) - }, - }, - DataType::Categorical(r, _) => match self.min_categorical() { - None => Scalar::new(self.dtype().clone(), AnyValue::Null), - Some(v) => { - let r = r.as_ref().unwrap(); - let arr = match &**r { - RevMapping::Local(arr, _) => arr, - RevMapping::Global(_, arr, _) => arr, - }; - Scalar::new( - self.dtype().clone(), - AnyValue::CategoricalOwned( - v, - r.clone(), - SyncPtr::from_const(arr as *const _), - ), - ) - }, - }, + let Some(min) = self.min_categorical() else { + return Scalar::new(self.dtype().clone(), AnyValue::Null); + }; + let av = match self.dtype() { + DataType::Enum(_, mapping) => AnyValue::EnumOwned(min, mapping.clone()), + DataType::Categorical(_, mapping) => AnyValue::CategoricalOwned(min, mapping.clone()), _ => unreachable!(), - } + }; + Scalar::new(self.dtype().clone(), av) } + fn max_reduce(&self) -> Scalar { - match self.dtype() { - DataType::Enum(r, _) => match self.physical().max() { - None => Scalar::new(self.dtype().clone(), AnyValue::Null), - Some(v) => { - let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else { - unreachable!() - }; - Scalar::new( - self.dtype().clone(), - AnyValue::EnumOwned( - v, - r.as_ref().unwrap().clone(), - SyncPtr::from_const(arr as *const _), - ), - ) - }, - }, - DataType::Categorical(r, _) => match self.max_categorical() { - None => Scalar::new(self.dtype().clone(), AnyValue::Null), - Some(v) => { - let r = r.as_ref().unwrap(); - let arr = match &**r { - RevMapping::Local(arr, _) => arr, - RevMapping::Global(_, arr, _) => arr, - }; - Scalar::new( - self.dtype().clone(), - AnyValue::CategoricalOwned( - v, - r.clone(), - SyncPtr::from_const(arr as *const _), - ), - ) - }, - }, + let Some(max) = self.max_categorical() else { + return Scalar::new(self.dtype().clone(), AnyValue::Null); + }; + let av = match self.dtype() { + DataType::Enum(_, mapping) => AnyValue::EnumOwned(max, mapping.clone()), + DataType::Categorical(_, mapping) => AnyValue::CategoricalOwned(max, mapping.clone()), _ => unreachable!(), - } + }; + Scalar::new(self.dtype().clone(), av) } } diff --git a/crates/polars-core/src/chunked_array/ops/any_value.rs b/crates/polars-core/src/chunked_array/ops/any_value.rs index a9cfc4e08c27..754973f25d32 100644 --- a/crates/polars-core/src/chunked_array/ops/any_value.rs +++ b/crates/polars-core/src/chunked_array/ops/any_value.rs @@ -1,6 +1,4 @@ #![allow(unsafe_op_in_unsafe_fn)] -#[cfg(feature = "dtype-categorical")] -use polars_utils::sync::SyncPtr; #[cfg(feature = "object")] use crate::chunked_array::object::extension::polars_extension::PolarsExtension; @@ -82,16 +80,22 @@ pub(crate) unsafe fn arr_to_any_value<'a>( } }, #[cfg(feature = "dtype-categorical")] - DataType::Categorical(rev_map, _) => { - let arr = &*(arr as *const dyn Array as *const UInt32Array); - let v = arr.value_unchecked(idx); - AnyValue::Categorical(v, rev_map.as_ref().unwrap().as_ref(), SyncPtr::new_null()) + DataType::Categorical(cats, mapping) => { + with_match_categorical_physical_type!(cats.physical(), |$C| { + type A = <$C as PolarsDataType>::Array; + let arr = &*(arr as *const dyn Array as *const A); + let cat_id = arr.value_unchecked(idx).as_cat(); + AnyValue::Categorical(cat_id, mapping) + }) }, #[cfg(feature = "dtype-categorical")] - DataType::Enum(rev_map, _) => { - let arr = &*(arr as *const dyn Array as *const UInt32Array); - let v = arr.value_unchecked(idx); - AnyValue::Enum(v, rev_map.as_ref().unwrap().as_ref(), SyncPtr::new_null()) + DataType::Enum(fcats, mapping) => { + with_match_categorical_physical_type!(fcats.physical(), |$C| { + type A = <$C as PolarsDataType>::Array; + let arr = &*(arr as *const dyn Array as *const A); + let cat_id = arr.value_unchecked(idx).as_cat(); + AnyValue::Enum(cat_id, mapping) + }) }, #[cfg(feature = "dtype-struct")] DataType::Struct(flds) => { @@ -144,54 +148,14 @@ pub(crate) unsafe fn arr_to_any_value<'a>( #[cfg(feature = "dtype-struct")] impl<'a> AnyValue<'a> { pub fn _iter_struct_av(&self) -> impl Iterator> { - match self { - AnyValue::Struct(idx, arr, flds) => { - let idx = *idx; - unsafe { - arr.values().iter().zip(*flds).map(move |(arr, fld)| { - // The dictionary arrays categories don't have to map to the rev-map in the dtype - // so we set the array pointer with values of the dictionary array. - #[cfg(feature = "dtype-categorical")] - { - use arrow::legacy::is_valid::IsValid as _; - if let Some(arr) = arr.as_any().downcast_ref::>() { - let keys = arr.keys(); - let values = arr.values(); - let values = - values.as_any().downcast_ref::().unwrap(); - let arr = &*(keys as *const dyn Array as *const UInt32Array); - - if arr.is_valid_unchecked(idx) { - let v = arr.value_unchecked(idx); - match fld.dtype() { - DataType::Categorical(Some(rev_map), _) => { - AnyValue::Categorical( - v, - rev_map, - SyncPtr::from_const(values), - ) - }, - DataType::Enum(Some(rev_map), _) => { - AnyValue::Enum(v, rev_map, SyncPtr::from_const(values)) - }, - _ => unimplemented!(), - } - } else { - AnyValue::Null - } - } else { - arr_to_any_value(&**arr, idx, fld.dtype()) - } - } - - #[cfg(not(feature = "dtype-categorical"))] - { - arr_to_any_value(&**arr, idx, fld.dtype()) - } - }) - } - }, - _ => unreachable!(), + let AnyValue::Struct(idx, arr, flds) = self else { + unreachable!() + }; + unsafe { + arr.values() + .iter() + .zip(*flds) + .map(move |(arr, fld)| arr_to_any_value(&**arr, *idx, fld.dtype())) } } diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index 70e6bcdbe2a8..beb97e4fe36d 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -254,6 +254,20 @@ impl StructChunked { } } +#[cfg(feature = "dtype-categorical")] +#[doc(hidden)] +impl CategoricalChunked { + pub fn append(&mut self, other: &Self) -> PolarsResult<()> { + assert!(self.dtype() == other.dtype()); + self.phys.append(&other.phys) + } + + pub fn append_owned(&mut self, other: Self) -> PolarsResult<()> { + assert!(self.dtype() == other.dtype()); + self.phys.append_owned(other.phys) + } +} + #[cfg(feature = "object")] #[doc(hidden)] impl ObjectChunked { diff --git a/crates/polars-core/src/chunked_array/ops/chunkops.rs b/crates/polars-core/src/chunked_array/ops/chunkops.rs index 7b5c0f81ed78..52bad12325e4 100644 --- a/crates/polars-core/src/chunked_array/ops/chunkops.rs +++ b/crates/polars-core/src/chunked_array/ops/chunkops.rs @@ -368,11 +368,11 @@ mod test { fn test_categorical_map_after_rechunk() { let s = Series::new(PlSmallStr::EMPTY, &["foo", "bar", "spam"]); let mut a = s - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); a.append(&a.slice(0, 2)).unwrap(); let a = a.rechunk(); - assert!(a.categorical().unwrap().get_rev_map().len() > 0); + assert!(a.cat32().unwrap().get_mapping().num_cats_upper_bound() > 0); } } diff --git a/crates/polars-core/src/chunked_array/ops/compare_inner.rs b/crates/polars-core/src/chunked_array/ops/compare_inner.rs index ce10855cb342..d079018cd2b5 100644 --- a/crates/polars-core/src/chunked_array/ops/compare_inner.rs +++ b/crates/polars-core/src/chunked_array/ops/compare_inner.rs @@ -203,44 +203,30 @@ impl<'a> IntoTotalOrdInner<'a> for &'a NullChunked { } #[cfg(feature = "dtype-categorical")] -struct LocalCategorical<'a> { - rev_map: &'a Utf8ViewArray, - cats: &'a UInt32Chunked, +struct LexicalCategorical<'a, T: PolarsCategoricalType> { + mapping: &'a CategoricalMapping, + cats: &'a ChunkedArray, } #[cfg(feature = "dtype-categorical")] -impl<'a> GetInner for LocalCategorical<'a> { +impl<'a, T: PolarsCategoricalType> GetInner for LexicalCategorical<'a, T> { type Item = Option<&'a str>; unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { let cat = self.cats.get_unchecked(idx)?; - Some(self.rev_map.value_unchecked(cat as usize)) + Some(self.mapping.cat_to_str_unchecked(cat.as_cat())) } } #[cfg(feature = "dtype-categorical")] -struct GlobalCategorical<'a> { - p1: &'a PlHashMap, - p2: &'a Utf8ViewArray, - cats: &'a UInt32Chunked, -} - -#[cfg(feature = "dtype-categorical")] -impl<'a> GetInner for GlobalCategorical<'a> { - type Item = Option<&'a str>; - unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { - let cat = self.cats.get_unchecked(idx)?; - let idx = self.p1.get(&cat).unwrap(); - Some(self.p2.value_unchecked(*idx as usize)) - } -} - -#[cfg(feature = "dtype-categorical")] -impl<'a> IntoTotalOrdInner<'a> for &'a CategoricalChunked { +impl<'a, T: PolarsCategoricalType> IntoTotalOrdInner<'a> for &'a CategoricalChunked { fn into_total_ord_inner(self) -> Box { - let cats = self.physical(); - match &**self.get_rev_map() { - RevMapping::Global(p1, p2, _) => Box::new(GlobalCategorical { p1, p2, cats }), - RevMapping::Local(rev_map, _) => Box::new(LocalCategorical { rev_map, cats }), + if self.uses_lexical_ordering() { + Box::new(LexicalCategorical:: { + mapping: self.get_mapping(), + cats: &self.phys, + }) + } else { + self.phys.into_total_ord_inner() } } } diff --git a/crates/polars-core/src/chunked_array/ops/extend.rs b/crates/polars-core/src/chunked_array/ops/extend.rs index 8111b4a764e0..c00f5e9afbf9 100644 --- a/crates/polars-core/src/chunked_array/ops/extend.rs +++ b/crates/polars-core/src/chunked_array/ops/extend.rs @@ -185,6 +185,15 @@ impl StructChunked { } } +#[cfg(feature = "dtype-categorical")] +#[doc(hidden)] +impl CategoricalChunked { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { + assert!(self.dtype() == other.dtype()); + self.phys.extend(&other.phys) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/crates/polars-core/src/chunked_array/ops/fill_null.rs b/crates/polars-core/src/chunked_array/ops/fill_null.rs index 3ca3aae9279f..b591a0062a5b 100644 --- a/crates/polars-core/src/chunked_array/ops/fill_null.rs +++ b/crates/polars-core/src/chunked_array/ops/fill_null.rs @@ -89,7 +89,7 @@ impl Series { let precision = ca.precision(); let scale = ca.scale(); let fill_value = 10i128.pow(scale as u32); - let phys = ca.as_ref().fill_null_with_values(fill_value)?; + let phys = ca.physical().fill_null_with_values(fill_value)?; Ok(phys.into_decimal_unchecked(precision, scale).into_series()) }, _ => { diff --git a/crates/polars-core/src/chunked_array/ops/row_encode.rs b/crates/polars-core/src/chunked_array/ops/row_encode.rs index 279fd7eb77f9..3fb61528e948 100644 --- a/crates/polars-core/src/chunked_array/ops/row_encode.rs +++ b/crates/polars-core/src/chunked_array/ops/row_encode.rs @@ -1,11 +1,7 @@ use std::borrow::Cow; use arrow::compute::utils::combine_validities_and_many; -use polars_row::{ - RowEncodingCategoricalContext, RowEncodingContext, RowEncodingOptions, RowsEncoded, - convert_columns, -}; -use polars_utils::itertools::Itertools; +use polars_row::{RowEncodingContext, RowEncodingOptions, RowsEncoded, convert_columns}; use rayon::prelude::*; use crate::POOL; @@ -76,7 +72,7 @@ pub fn encode_rows_vertical_par_unordered_broadcast_nulls( /// /// This should be given the logical type in order to communicate Polars datatype information down /// into the row encoding / decoding. -pub fn get_row_encoding_context(dtype: &DataType, ordered: bool) -> Option { +pub fn get_row_encoding_context(dtype: &DataType) -> Option { match dtype { DataType::Boolean | DataType::UInt8 @@ -99,6 +95,18 @@ pub fn get_row_encoding_context(dtype: &DataType, ordered: bool) -> Option None, + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, mapping) | DataType::Enum(_, mapping) => { + use polars_row::NewRowEncodingCategoricalContext; + + Some(RowEncodingContext::Categorical( + NewRowEncodingCategoricalContext { + is_enum: matches!(dtype, DataType::Enum(_, _)), + mapping: mapping.clone(), + }, + )) + }, + DataType::Unknown(_) => panic!("Unsupported in row encoding"), #[cfg(feature = "object")] @@ -110,86 +118,14 @@ pub fn get_row_encoding_context(dtype: &DataType, ordered: bool) -> Option get_row_encoding_context(dtype, ordered), - DataType::List(dtype) => get_row_encoding_context(dtype, ordered), - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(revmap, ordering) | DataType::Enum(revmap, ordering) => { - let is_enum = dtype.is_enum(); - let ctx = match revmap { - Some(revmap) => { - let (num_known_categories, lexical_sort_idxs) = match revmap.as_ref() { - RevMapping::Global(map, _, _) => { - let num_known_categories = - map.keys().max().copied().map_or(0, |m| m + 1); - - // @TODO: This should probably be cached. - let lexical_sort_idxs = (ordered - && matches!(ordering, CategoricalOrdering::Lexical)) - .then(|| { - let read_map = crate::STRING_CACHE.read_map(); - let payloads = read_map.get_current_payloads(); - assert!(payloads.len() >= num_known_categories as usize); - - let mut idxs = (0..num_known_categories).collect::>(); - idxs.sort_by_key(|&k| payloads[k as usize].as_str()); - let mut sort_idxs = vec![0; num_known_categories as usize]; - for (i, idx) in idxs.into_iter().enumerate_u32() { - sort_idxs[idx as usize] = i; - } - sort_idxs - }); - - (num_known_categories, lexical_sort_idxs) - }, - RevMapping::Local(values, _) => { - // @TODO: This should probably be cached. - let lexical_sort_idxs = (ordered - && matches!(ordering, CategoricalOrdering::Lexical)) - .then(|| { - assert_eq!(values.null_count(), 0); - let values: Vec<&str> = values.values_iter().collect(); - - let mut idxs = (0..values.len() as u32).collect::>(); - idxs.sort_by_key(|&k| values[k as usize]); - let mut sort_idxs = vec![0; values.len()]; - for (i, idx) in idxs.into_iter().enumerate_u32() { - sort_idxs[idx as usize] = i; - } - sort_idxs - }); - - (values.len() as u32, lexical_sort_idxs) - }, - }; - - RowEncodingCategoricalContext { - num_known_categories, - is_enum, - lexical_sort_idxs, - } - }, - None => { - let num_known_categories = u32::MAX; - - if matches!(ordering, CategoricalOrdering::Lexical) && ordered { - panic!("lexical ordering not yet supported if rev-map not given"); - } - RowEncodingCategoricalContext { - num_known_categories, - is_enum, - lexical_sort_idxs: None, - } - }, - }; - - Some(RowEncodingContext::Categorical(ctx)) - }, + DataType::Array(dtype, _) => get_row_encoding_context(dtype), + DataType::List(dtype) => get_row_encoding_context(dtype), #[cfg(feature = "dtype-struct")] DataType::Struct(fs) => { let mut ctxts = Vec::new(); for (i, f) in fs.iter().enumerate() { - if let Some(ctxt) = get_row_encoding_context(f.dtype(), ordered) { + if let Some(ctxt) = get_row_encoding_context(f.dtype()) { ctxts.reserve(fs.len()); ctxts.extend(std::iter::repeat_n(None, i)); ctxts.push(Some(ctxt)); @@ -204,7 +140,7 @@ pub fn get_row_encoding_context(dtype: &DataType, ordered: bool) -> Option PolarsResult { let by = by.as_materialized_series(); let arr = by.to_physical_repr().rechunk().chunks()[0].to_boxed(); let opt = RowEncodingOptions::new_unsorted(); - let ctxt = get_row_encoding_context(by.dtype(), false); + let ctxt = get_row_encoding_context(by.dtype()); cols.push(arr); opts.push(opt); @@ -274,7 +210,7 @@ pub fn _get_rows_encoded( let by = by.as_materialized_series(); let arr = by.to_physical_repr().rechunk().chunks()[0].to_boxed(); let opt = RowEncodingOptions::new_sorted(*desc, *null_last); - let ctxt = get_row_encoding_context(by.dtype(), true); + let ctxt = get_row_encoding_context(by.dtype()); cols.push(arr); opts.push(opt); diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs index 865cb0cb3b26..b0d3a962e7ff 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -1,79 +1,66 @@ +use num_traits::Zero; + use super::*; -impl CategoricalChunked { +impl CategoricalChunked { #[must_use] - pub fn sort_with(&self, options: SortOptions) -> CategoricalChunked { - if self.uses_lexical_ordering() { - let mut vals = self - .physical() - .into_iter() - .zip(self.iter_str()) - .collect_trusted::>(); - - sort_unstable_by_branch(vals.as_mut_slice(), options, |a, b| a.1.cmp(&b.1)); - - let mut cats = Vec::with_capacity(self.len()); - let mut validity = - (self.null_count() > 0).then(|| BitmapBuilder::with_capacity(self.len())); - - if self.null_count() > 0 && !options.nulls_last { - cats.resize(self.null_count(), 0); - if let Some(validity) = &mut validity { - validity.extend_constant(self.null_count(), false); - } - } - - let valid_slice = if options.descending { - &vals[..self.len() - self.null_count()] - } else { - &vals[self.null_count()..] + pub fn sort_with(&self, options: SortOptions) -> CategoricalChunked { + if !self.uses_lexical_ordering() { + let cats = self.physical().sort_with(options); + // SAFETY: we only reordered the indexes so we are still in bounds. + return unsafe { + CategoricalChunked::::from_cats_and_dtype_unchecked(cats, self.dtype().clone()) }; - cats.extend(valid_slice.iter().map(|(idx, _v)| idx.unwrap())); + } + + let mut vals = self + .physical() + .into_iter() + .zip(self.iter_str()) + .collect_trusted::>(); + + sort_unstable_by_branch(vals.as_mut_slice(), options, |a, b| a.1.cmp(&b.1)); + + let mut cats = Vec::with_capacity(self.len()); + let mut validity = + (self.null_count() > 0).then(|| BitmapBuilder::with_capacity(self.len())); + + if self.null_count() > 0 && !options.nulls_last { + cats.resize(self.null_count(), T::Native::zero()); if let Some(validity) = &mut validity { - validity.extend_constant(self.len() - self.null_count(), true); + validity.extend_constant(self.null_count(), false); } + } + + let valid_slice = if options.descending { + &vals[..self.len() - self.null_count()] + } else { + &vals[self.null_count()..] + }; + cats.extend(valid_slice.iter().map(|(idx, _v)| idx.unwrap())); + if let Some(validity) = &mut validity { + validity.extend_constant(self.len() - self.null_count(), true); + } - if self.null_count() > 0 && options.nulls_last { - cats.resize(self.len(), 0); - if let Some(validity) = &mut validity { - validity.extend_constant(self.null_count(), false); - } + if self.null_count() > 0 && options.nulls_last { + cats.resize(self.len(), T::Native::zero()); + if let Some(validity) = &mut validity { + validity.extend_constant(self.null_count(), false); } + } - let cats = PrimitiveArray::::new( - ArrowDataType::UInt32, - cats.into(), - validity.map(|v| v.freeze()), - ); - let cats = UInt32Chunked::from_chunk_iter(self.name().clone(), Some(cats)); + let arr = PrimitiveArray::from_vec(cats).with_validity(validity.map(|v| v.freeze())); + let cats = ChunkedArray::with_chunk(self.name().clone(), arr); - // SAFETY: - // we only reordered the indexes so we are still in bounds - return unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - cats, - self.get_rev_map().clone(), - self.is_enum(), - self.get_ordering(), - ) - }; - } - let cats = self.physical().sort_with(options); - // SAFETY: - // we only reordered the indexes so we are still in bounds + // SAFETY: we only reordered the indexes so we are still in bounds. unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - cats, - self.get_rev_map().clone(), - self.is_enum(), - self.get_ordering(), - ) + CategoricalChunked::::from_cats_and_dtype_unchecked(cats, self.dtype().clone()) } } /// Returned a sorted `ChunkedArray`. #[must_use] - pub fn sort(&self, descending: bool) -> CategoricalChunked { + pub fn sort(&self, descending: bool) -> CategoricalChunked { self.sort_with(SortOptions { nulls_last: false, descending, @@ -133,9 +120,8 @@ impl CategoricalChunked { #[cfg(test)] mod test { use crate::prelude::*; - use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache}; - fn assert_order(ca: &CategoricalChunked, cmp: &[&str]) { + fn assert_order(ca: &Categorical8Chunked, cmp: &[&str]) { let s = ca.cast(&DataType::String).unwrap(); let ca = s.str().unwrap(); assert_eq!(ca.into_no_null_iter().collect::>(), cmp); @@ -145,34 +131,23 @@ mod test { fn test_cat_lexical_sort() -> PolarsResult<()> { let init = &["c", "b", "a", "d"]; - let _lock = SINGLE_LOCK.lock(); - for use_string_cache in [true, false] { - disable_string_cache(); - if use_string_cache { - enable_string_cache(); - } + let cats = Categories::new( + PlSmallStr::EMPTY, + PlSmallStr::EMPTY, + CategoricalPhysical::U8, + ); + let s = + Series::new(PlSmallStr::EMPTY, init).cast(&DataType::from_categories(cats.clone()))?; + let ca = s.cat8()?; - let s = Series::new(PlSmallStr::EMPTY, init) - .cast(&DataType::Categorical(None, CategoricalOrdering::Lexical))?; - let ca = s.categorical()?; - let ca_lexical = ca.clone(); + let out = ca.sort(false); + assert_order(&out, &["a", "b", "c", "d"]); - let out = ca_lexical.sort(false); - assert_order(&out, &["a", "b", "c", "d"]); - - let s = Series::new(PlSmallStr::EMPTY, init) - .cast(&DataType::Categorical(None, Default::default()))?; - let ca = s.categorical()?; - - let out = ca.sort(false); - assert_order(&out, init); - - let out = ca_lexical.arg_sort(SortOptions { - descending: false, - ..Default::default() - }); - assert_eq!(out.into_no_null_iter().collect::>(), &[2, 1, 0, 3]); - } + let out = ca.arg_sort(SortOptions { + descending: false, + ..Default::default() + }); + assert_eq!(out.into_no_null_iter().collect::>(), &[2, 1, 0, 3]); Ok(()) } @@ -181,41 +156,35 @@ mod test { fn test_cat_lexical_sort_multiple() -> PolarsResult<()> { let init = &["c", "b", "a", "a"]; - let _lock = SINGLE_LOCK.lock(); - for use_string_cache in [true, false] { - disable_string_cache(); - if use_string_cache { - enable_string_cache(); - } + let cats = Categories::new( + PlSmallStr::EMPTY, + PlSmallStr::EMPTY, + CategoricalPhysical::U8, + ); + let series = + Series::new(PlSmallStr::EMPTY, init).cast(&DataType::from_categories(cats.clone()))?; + + let df = df![ + "cat" => &series, + "vals" => [1, 1, 2, 2] + ]?; + + let out = df.sort( + ["cat", "vals"], + SortMultipleOptions::default().with_order_descending_multi([false, false]), + )?; + let out = out.column("cat")?; + let cat = out.as_materialized_series().cat8()?; + assert_order(cat, &["a", "a", "b", "c"]); + + let out = df.sort( + ["vals", "cat"], + SortMultipleOptions::default().with_order_descending_multi([false, false]), + )?; + let out = out.column("cat")?; + let cat = out.as_materialized_series().cat8()?; + assert_order(cat, &["b", "c", "a", "a"]); - let s = Series::new(PlSmallStr::EMPTY, init) - .cast(&DataType::Categorical(None, CategoricalOrdering::Lexical))?; - let ca = s.categorical()?; - let ca_lexical: CategoricalChunked = ca.clone(); - - let series = ca_lexical.into_series(); - - let df = df![ - "cat" => &series, - "vals" => [1, 1, 2, 2] - ]?; - - let out = df.sort( - ["cat", "vals"], - SortMultipleOptions::default().with_order_descending_multi([false, false]), - )?; - let out = out.column("cat")?; - let cat = out.as_materialized_series().categorical()?; - assert_order(cat, &["a", "a", "b", "c"]); - - let out = df.sort( - ["vals", "cat"], - SortMultipleOptions::default().with_order_descending_multi([false, false]), - )?; - let out = out.column("cat")?; - let cat = out.as_materialized_series().categorical()?; - assert_order(cat, &["b", "c", "a", "a"]); - } Ok(()) } } diff --git a/crates/polars-core/src/chunked_array/temporal/date.rs b/crates/polars-core/src/chunked_array/temporal/date.rs index 69d7fad4c6b1..58fa74fc92cf 100644 --- a/crates/polars-core/src/chunked_array/temporal/date.rs +++ b/crates/polars-core/src/chunked_array/temporal/date.rs @@ -15,7 +15,8 @@ impl DateChunked { pub fn as_date_iter(&self) -> impl TrustedLen> + '_ { // SAFETY: we know the iterators len unsafe { - self.downcast_iter() + self.physical() + .downcast_iter() .flat_map(|iter| { iter.into_iter() .map(|opt_v| opt_v.copied().map(date32_to_date)) @@ -39,11 +40,12 @@ impl DateChunked { format }; let datefmt_f = |ndt: NaiveDate| ndt.format(format); - self.try_apply_into_string_amortized(|val, buf| { - let ndt = date32_to_date(val); - write!(buf, "{}", datefmt_f(ndt)) - }) - .map_err(|_| polars_err!(ComputeError: "cannot format Date with format '{}'", format)) + self.physical() + .try_apply_into_string_amortized(|val, buf| { + let ndt = date32_to_date(val); + write!(buf, "{}", datefmt_f(ndt)) + }) + .map_err(|_| polars_err!(ComputeError: "cannot format Date with format '{}'", format)) } /// Convert from Date into String with the given format. diff --git a/crates/polars-core/src/chunked_array/temporal/datetime.rs b/crates/polars-core/src/chunked_array/temporal/datetime.rs index 5512c3f0ce38..ff50c6d1c1f3 100644 --- a/crates/polars-core/src/chunked_array/temporal/datetime.rs +++ b/crates/polars-core/src/chunked_array/temporal/datetime.rs @@ -19,7 +19,8 @@ impl DatetimeChunked { }; // we know the iterators len unsafe { - self.downcast_iter() + self.physical() + .downcast_iter() .flat_map(move |iter| iter.into_iter().map(move |opt_v| opt_v.copied().map(func))) .trust_my_length(self.len()) } @@ -53,7 +54,7 @@ impl DatetimeChunked { Some(time_zone) => { let parsed_time_zone = time_zone.parse::().expect("already validated"); let datefmt_f = |ndt| parsed_time_zone.from_utc_datetime(&ndt).format(&format); - self.try_apply_into_string_amortized(|val, buf| { + self.physical().try_apply_into_string_amortized(|val, buf| { let ndt = conversion_f(val); write!(buf, "{}", datefmt_f(ndt)) } @@ -63,7 +64,7 @@ impl DatetimeChunked { }, _ => { let datefmt_f = |ndt: NaiveDateTime| ndt.format(&format); - self.try_apply_into_string_amortized(|val, buf| { + self.physical().try_apply_into_string_amortized(|val, buf| { let ndt = conversion_f(val); write!(buf, "{}", datefmt_f(ndt)) } @@ -215,7 +216,7 @@ mod test { 1_441_497_364_000_000_000, 1_356_048_000_000_000_000 ], - dt.cont_slice().unwrap() + dt.physical().cont_slice().unwrap() ); } } diff --git a/crates/polars-core/src/chunked_array/temporal/time.rs b/crates/polars-core/src/chunked_array/temporal/time.rs index fe0ca0270c1d..a255c5737c59 100644 --- a/crates/polars-core/src/chunked_array/temporal/time.rs +++ b/crates/polars-core/src/chunked_array/temporal/time.rs @@ -21,7 +21,7 @@ impl TimeChunked { /// Convert from Time into String with the given format. /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). pub fn to_string(&self, format: &str) -> StringChunked { - let mut ca: StringChunked = self.apply_kernel_cast(&|arr| { + let mut ca: StringChunked = self.physical().apply_kernel_cast(&|arr| { let mut buf = String::new(); let format = if format == "iso" || format == "iso:strict" { "%T%.9f" @@ -60,7 +60,8 @@ impl TimeChunked { pub fn as_time_iter(&self) -> impl TrustedLen> + '_ { // we know the iterators len unsafe { - self.downcast_iter() + self.physical() + .downcast_iter() .flat_map(|iter| { iter.into_iter() .map(|opt_v| opt_v.copied().map(time64ns_to_time)) diff --git a/crates/polars-core/src/datatypes/_serde.rs b/crates/polars-core/src/datatypes/_serde.rs index 29de814b4ac5..985e2764ab33 100644 --- a/crates/polars-core/src/datatypes/_serde.rs +++ b/crates/polars-core/src/datatypes/_serde.rs @@ -4,6 +4,7 @@ //! We could use [serde_1712](https://github.com/serde-rs/serde/issues/1712), but that gave problems caused by //! [rust_96956](https://github.com/rust-lang/rust/issues/96956), so we make a dummy type without static +use polars_dtype::categorical::CategoricalPhysical; #[cfg(feature = "dtype-categorical")] use serde::de::SeqAccess; use serde::{Deserialize, Serialize}; @@ -127,11 +128,17 @@ enum SerializableDataType { // some logical types we cannot know statically, e.g. Datetime Unknown(UnknownKind), #[cfg(feature = "dtype-categorical")] - Categorical(Option, CategoricalOrdering), + Categorical { + name: String, + namespace: String, + physical: CategoricalPhysical, + }, + #[cfg(feature = "dtype-categorical")] + Enum { + strings: Series, + }, #[cfg(feature = "dtype-decimal")] Decimal(Option, Option), - #[cfg(feature = "dtype-categorical")] - Enum(Option, CategoricalOrdering), #[cfg(feature = "object")] Object(String), } @@ -167,25 +174,19 @@ impl From<&DataType> for SerializableDataType { #[cfg(feature = "dtype-struct")] Struct(flds) => Self::Struct(flds.clone()), #[cfg(feature = "dtype-categorical")] - Categorical(Some(rev_map), ordering) => Self::Categorical( - Some( - StringChunked::with_chunk(PlSmallStr::EMPTY, rev_map.get_categories().clone()) - .into_series(), - ), - *ordering, - ), - #[cfg(feature = "dtype-categorical")] - Categorical(None, ordering) => Self::Categorical(None, *ordering), + Categorical(cats, _) => Self::Categorical { + name: cats.name().to_string(), + namespace: cats.namespace().to_string(), + physical: cats.physical(), + }, #[cfg(feature = "dtype-categorical")] - Enum(Some(rev_map), ordering) => Self::Enum( - Some( - StringChunked::with_chunk(PlSmallStr::EMPTY, rev_map.get_categories().clone()) - .into_series(), - ), - *ordering, - ), - #[cfg(feature = "dtype-categorical")] - Enum(None, ordering) => Self::Enum(None, *ordering), + Enum(fcats, _) => Self::Enum { + strings: StringChunked::with_chunk( + PlSmallStr::from_static("categories"), + fcats.categories().clone(), + ) + .into_series(), + }, #[cfg(feature = "dtype-decimal")] Decimal(precision, scale) => Self::Decimal(*precision, *scale), #[cfg(feature = "object")] @@ -224,28 +225,26 @@ impl From for DataType { #[cfg(feature = "dtype-struct")] Struct(flds) => Self::Struct(flds), #[cfg(feature = "dtype-categorical")] - Categorical(Some(categories), ordering) => Self::Categorical( - Some(Arc::new(RevMapping::build_local( - categories.0.rechunk().chunks()[0] - .as_any() - .downcast_ref::() - .unwrap() - .clone(), - ))), - ordering, - ), - #[cfg(feature = "dtype-categorical")] - Categorical(None, ordering) => Self::Categorical(None, ordering), - #[cfg(feature = "dtype-categorical")] - Enum(Some(categories), _) => create_enum_dtype( - categories.rechunk().chunks()[0] - .as_any() - .downcast_ref::() - .unwrap() - .clone(), - ), + Categorical { + name, + namespace, + physical, + } => { + let cats = Categories::new( + PlSmallStr::from(name), + PlSmallStr::from(namespace), + physical, + ); + let mapping = cats.mapping(); + Self::Categorical(cats, mapping) + }, #[cfg(feature = "dtype-categorical")] - Enum(None, ordering) => Self::Enum(None, ordering), + Enum { strings } => { + let ca = strings.str().unwrap(); + let fcats = FrozenCategories::new(ca.iter().flatten()).unwrap(); + let mapping = fcats.mapping().clone(); + Self::Enum(fcats, mapping) + }, #[cfg(feature = "dtype-decimal")] Decimal(precision, scale) => Self::Decimal(precision, scale), #[cfg(feature = "object")] diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index 1f9c3d931acd..7a3907c8387a 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -4,8 +4,6 @@ use std::borrow::Cow; use arrow::types::PrimitiveType; use polars_compute::cast::SerPrimitive; use polars_error::feature_gated; -#[cfg(feature = "dtype-categorical")] -use polars_utils::sync::SyncPtr; use polars_utils::total_ord::ToTotalOrd; use super::*; @@ -72,18 +70,14 @@ pub enum AnyValue<'a> { /// A 64-bit time representing the elapsed time since midnight in nanoseconds #[cfg(feature = "dtype-time")] Time(i64), - // If syncptr is_null the data is in the rev-map - // otherwise it is in the array pointer #[cfg(feature = "dtype-categorical")] - Categorical(u32, &'a RevMapping, SyncPtr), - // If syncptr is_null the data is in the rev-map - // otherwise it is in the array pointer + Categorical(CatSize, &'a Arc), #[cfg(feature = "dtype-categorical")] - CategoricalOwned(u32, Arc, SyncPtr), + CategoricalOwned(CatSize, Arc), #[cfg(feature = "dtype-categorical")] - Enum(u32, &'a RevMapping, SyncPtr), + Enum(CatSize, &'a Arc), #[cfg(feature = "dtype-categorical")] - EnumOwned(u32, Arc, SyncPtr), + EnumOwned(CatSize, Arc), /// Nested type, contains arrays that are filled with one of the datatypes. List(Series), #[cfg(feature = "dtype-array")] @@ -391,11 +385,11 @@ impl<'a> AnyValue<'a> { #[cfg(feature = "dtype-duration")] Duration(_, tu) => DataType::Duration(*tu), #[cfg(feature = "dtype-categorical")] - Categorical(_, _, _) | CategoricalOwned(_, _, _) => { - DataType::Categorical(None, Default::default()) + Categorical(_, _) | CategoricalOwned(_, _) => { + unimplemented!("can not get dtype of Categorical AnyValue") }, #[cfg(feature = "dtype-categorical")] - Enum(_, _, _) | EnumOwned(_, _, _) => DataType::Enum(None, Default::default()), + Enum(_, _) | EnumOwned(_, _) => unimplemented!("can not get dtype of Enum AnyValue"), List(s) => DataType::List(Box::new(s.dtype().clone())), #[cfg(feature = "dtype-array")] Array(s, size) => DataType::Array(Box::new(s.dtype().clone()), *size), @@ -557,6 +551,67 @@ impl<'a> AnyValue<'a> { (AnyValue::Float32(v), DataType::Boolean) => AnyValue::Boolean(*v != f32::default()), (AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()), + // Categorical casts. + #[cfg(feature = "dtype-categorical")] + ( + &AnyValue::Categorical(cat, &ref lmap) | &AnyValue::CategoricalOwned(cat, ref lmap), + DataType::Categorical(_, rmap), + ) => { + if Arc::ptr_eq(lmap, rmap) { + self.clone() + } else { + let s = unsafe { lmap.cat_to_str_unchecked(cat) }; + let new_cat = rmap.insert_cat(s).unwrap(); + AnyValue::CategoricalOwned(new_cat, rmap.clone()) + } + }, + + #[cfg(feature = "dtype-categorical")] + ( + &AnyValue::Enum(cat, &ref lmap) | &AnyValue::EnumOwned(cat, ref lmap), + DataType::Enum(_, rmap), + ) => { + if Arc::ptr_eq(lmap, rmap) { + self.clone() + } else { + let s = unsafe { lmap.cat_to_str_unchecked(cat) }; + let new_cat = rmap.get_cat(s)?; + AnyValue::EnumOwned(new_cat, rmap.clone()) + } + }, + + #[cfg(feature = "dtype-categorical")] + ( + &AnyValue::Categorical(cat, &ref map) + | &AnyValue::CategoricalOwned(cat, ref map) + | &AnyValue::Enum(cat, &ref map) + | &AnyValue::EnumOwned(cat, ref map), + DataType::String, + ) => { + let s = unsafe { map.cat_to_str_unchecked(cat) }; + AnyValue::StringOwned(PlSmallStr::from(s)) + }, + + #[cfg(feature = "dtype-categorical")] + (AnyValue::String(s), DataType::Categorical(_, map)) => { + AnyValue::CategoricalOwned(map.insert_cat(s).unwrap(), map.clone()) + }, + + #[cfg(feature = "dtype-categorical")] + (AnyValue::StringOwned(s), DataType::Categorical(_, map)) => { + AnyValue::CategoricalOwned(map.insert_cat(s).unwrap(), map.clone()) + }, + + #[cfg(feature = "dtype-categorical")] + (AnyValue::String(s), DataType::Enum(_, map)) => { + AnyValue::CategoricalOwned(map.get_cat(s)?, map.clone()) + }, + + #[cfg(feature = "dtype-categorical")] + (AnyValue::StringOwned(s), DataType::Enum(_, map)) => { + AnyValue::CategoricalOwned(map.get_cat(s)?, map.clone()) + }, + // to string (AnyValue::String(v), DataType::String) => AnyValue::String(v), (AnyValue::StringOwned(v), DataType::String) => AnyValue::StringOwned(v.clone()), @@ -737,20 +792,12 @@ impl<'a> AnyValue<'a> { Self::StringOwned(s) => Cow::Owned(s.to_string()), Self::Null => Cow::Borrowed("null"), #[cfg(feature = "dtype-categorical")] - Self::Categorical(idx, rev, arr) | AnyValue::Enum(idx, rev, arr) => { - if arr.is_null() { - Cow::Borrowed(rev.get(*idx)) - } else { - unsafe { Cow::Borrowed(arr.deref_unchecked().value(*idx as usize)) } - } + Self::Categorical(cat, map) | Self::Enum(cat, map) => { + Cow::Borrowed(unsafe { map.cat_to_str_unchecked(*cat) }) }, #[cfg(feature = "dtype-categorical")] - Self::CategoricalOwned(idx, rev, arr) | AnyValue::EnumOwned(idx, rev, arr) => { - if arr.is_null() { - Cow::Owned(rev.get(*idx).to_string()) - } else { - unsafe { Cow::Borrowed(arr.deref_unchecked().value(*idx as usize)) } - } + Self::CategoricalOwned(cat, map) | Self::EnumOwned(cat, map) => { + Cow::Owned(unsafe { map.cat_to_str_unchecked(*cat) }.to_owned()) }, av => Cow::Owned(av.to_string()), } @@ -790,10 +837,10 @@ impl<'a> AnyValue<'a> { Self::Time(v) => Self::Int64(v), #[cfg(feature = "dtype-categorical")] - Self::Categorical(v, _, _) - | Self::CategoricalOwned(v, _, _) - | Self::Enum(v, _, _) - | Self::EnumOwned(v, _, _) => Self::UInt32(v), + Self::Categorical(v, _) + | Self::CategoricalOwned(v, _) + | Self::Enum(v, _) + | Self::EnumOwned(v, _) => Self::UInt32(v), Self::List(series) => Self::List(series.to_physical_repr().into_owned()), #[cfg(feature = "dtype-array")] @@ -915,10 +962,9 @@ impl AnyValue<'_> { #[cfg(feature = "dtype-time")] Time(v) => v.hash(state), #[cfg(feature = "dtype-categorical")] - Categorical(v, _, _) - | CategoricalOwned(v, _, _) - | Enum(v, _, _) - | EnumOwned(v, _, _) => v.hash(state), + Categorical(v, _) | CategoricalOwned(v, _) | Enum(v, _) | EnumOwned(v, _) => { + v.hash(state) + }, #[cfg(feature = "object")] Object(_) => {}, #[cfg(feature = "object")] @@ -1066,11 +1112,9 @@ impl<'a> AnyValue<'a> { AnyValue::Datetime(*v, *tu, tz.as_ref().map(AsRef::as_ref)) }, #[cfg(feature = "dtype-categorical")] - AnyValue::CategoricalOwned(v, rev, arr) => { - AnyValue::Categorical(*v, rev.as_ref(), *arr) - }, + AnyValue::CategoricalOwned(cat, map) => AnyValue::Categorical(*cat, map), #[cfg(feature = "dtype-categorical")] - AnyValue::EnumOwned(v, rev, arr) => AnyValue::Enum(*v, rev.as_ref(), *arr), + AnyValue::EnumOwned(cat, map) => AnyValue::Enum(*cat, map), av => av.clone(), } } @@ -1133,13 +1177,13 @@ impl<'a> AnyValue<'a> { #[cfg(feature = "dtype-decimal")] Decimal(val, scale) => Decimal(val, scale), #[cfg(feature = "dtype-categorical")] - Categorical(v, rev, arr) => CategoricalOwned(v, Arc::new(rev.clone()), arr), + Categorical(cat, map) => CategoricalOwned(cat, map.clone()), #[cfg(feature = "dtype-categorical")] - CategoricalOwned(v, rev, arr) => CategoricalOwned(v, rev, arr), + CategoricalOwned(cat, map) => CategoricalOwned(cat, map), #[cfg(feature = "dtype-categorical")] - Enum(v, rev, arr) => EnumOwned(v, Arc::new(rev.clone()), arr), + Enum(cat, map) => EnumOwned(cat, map.clone()), #[cfg(feature = "dtype-categorical")] - EnumOwned(v, rev, arr) => EnumOwned(v, rev, arr), + EnumOwned(cat, map) => EnumOwned(cat, map), } } @@ -1149,22 +1193,12 @@ impl<'a> AnyValue<'a> { AnyValue::String(s) => Some(s), AnyValue::StringOwned(s) => Some(s.as_str()), #[cfg(feature = "dtype-categorical")] - AnyValue::Categorical(idx, rev, arr) | AnyValue::Enum(idx, rev, arr) => { - let s = if arr.is_null() { - rev.get(*idx) - } else { - unsafe { arr.deref_unchecked().value(*idx as usize) } - }; - Some(s) + Self::Categorical(cat, map) | Self::Enum(cat, map) => { + Some(unsafe { map.cat_to_str_unchecked(*cat) }) }, #[cfg(feature = "dtype-categorical")] - AnyValue::CategoricalOwned(idx, rev, arr) | AnyValue::EnumOwned(idx, rev, arr) => { - let s = if arr.is_null() { - rev.get(*idx) - } else { - unsafe { arr.deref_unchecked().value(*idx as usize) } - }; - Some(s) + Self::CategoricalOwned(cat, map) | Self::EnumOwned(cat, map) => { + Some(unsafe { map.cat_to_str_unchecked(*cat) }) }, _ => None, } @@ -1238,13 +1272,13 @@ impl AnyValue<'_> { *l == Datetime(*rv, *rtu, rtz.as_ref().map(|v| v.as_ref())) }, #[cfg(feature = "dtype-categorical")] - (CategoricalOwned(lv, lrev, larr), r) => Categorical(*lv, lrev.as_ref(), *larr) == *r, + (CategoricalOwned(cat, map), r) => Categorical(*cat, map) == *r, #[cfg(feature = "dtype-categorical")] - (l, CategoricalOwned(rv, rrev, rarr)) => *l == Categorical(*rv, rrev.as_ref(), *rarr), + (l, CategoricalOwned(cat, map)) => *l == Categorical(*cat, map), #[cfg(feature = "dtype-categorical")] - (EnumOwned(lv, lrev, larr), r) => Enum(*lv, lrev.as_ref(), *larr) == *r, + (EnumOwned(cat, map), r) => Enum(*cat, map) == *r, #[cfg(feature = "dtype-categorical")] - (l, EnumOwned(rv, rrev, rarr)) => *l == Enum(*rv, rrev.as_ref(), *rarr), + (l, EnumOwned(cat, map)) => *l == Enum(*cat, map), // Comparison with null. (Null, Null) => null_equal, @@ -1276,26 +1310,28 @@ impl AnyValue<'_> { }, (List(l), List(r)) => l == r, #[cfg(feature = "dtype-categorical")] - (Categorical(idx_l, rev_l, ptr_l), Categorical(idx_r, rev_r, ptr_r)) => { - if !same_revmap(rev_l, *ptr_l, rev_r, *ptr_r) { + (Categorical(cat_l, map_l), Categorical(cat_r, map_r)) => { + if !Arc::ptr_eq(map_l, map_r) { // We can't support this because our Hash impl directly hashes the index. If you // add support for this we must change the Hash impl. unimplemented!( - "comparing categoricals with different revmaps is not supported" + "comparing categoricals with different Categories is not supported through AnyValue" ); } - idx_l == idx_r + cat_l == cat_r }, #[cfg(feature = "dtype-categorical")] - (Enum(idx_l, rev_l, ptr_l), Enum(idx_r, rev_r, ptr_r)) => { - // We can't support this because our Hash impl directly hashes the index. If you - // add support for this we must change the Hash impl. - if !same_revmap(rev_l, *ptr_l, rev_r, *ptr_r) { - unimplemented!("comparing enums with different revmaps is not supported"); + (Enum(cat_l, map_l), Enum(cat_r, map_r)) => { + if !Arc::ptr_eq(map_l, map_r) { + // We can't support this because our Hash impl directly hashes the index. If you + // add support for this we must change the Hash impl. + unimplemented!( + "comparing enums with different FrozenCategories is not supported through AnyValue" + ); } - idx_l == idx_r + cat_l == cat_r }, #[cfg(feature = "dtype-duration")] (Duration(l, tu_l), Duration(r, tu_r)) => l == r && tu_l == tu_r, @@ -1416,17 +1452,13 @@ impl PartialOrd for AnyValue<'_> { l.partial_cmp(&Datetime(*rv, *rtu, rtz.as_ref().map(|v| v.as_ref()))) }, #[cfg(feature = "dtype-categorical")] - (CategoricalOwned(lv, lrev, larr), r) => { - Categorical(*lv, lrev.as_ref(), *larr).partial_cmp(r) - }, + (CategoricalOwned(cat, map), r) => Categorical(*cat, map).partial_cmp(r), #[cfg(feature = "dtype-categorical")] - (l, CategoricalOwned(rv, rrev, rarr)) => { - l.partial_cmp(&Categorical(*rv, rrev.as_ref(), *rarr)) - }, + (l, CategoricalOwned(cat, map)) => l.partial_cmp(&Categorical(*cat, map)), #[cfg(feature = "dtype-categorical")] - (EnumOwned(lv, lrev, larr), r) => Enum(*lv, lrev.as_ref(), *larr).partial_cmp(r), + (EnumOwned(cat, map), r) => Enum(*cat, map).partial_cmp(r), #[cfg(feature = "dtype-categorical")] - (l, EnumOwned(rv, rrev, rarr)) => l.partial_cmp(&Enum(*rv, rrev.as_ref(), *rarr)), + (l, EnumOwned(cat, map)) => l.partial_cmp(&Enum(*cat, map)), // Comparison with null. (Null, Null) => Some(Ordering::Equal), @@ -1471,14 +1503,17 @@ impl PartialOrd for AnyValue<'_> { #[cfg(feature = "dtype-time")] (Time(l), Time(r)) => l.partial_cmp(r), #[cfg(feature = "dtype-categorical")] - (Categorical(..), Categorical(..)) => { - unimplemented!( - "can't order categoricals as AnyValues, dtype for ordering is needed" - ) + (Categorical(l_cat, l_map), Categorical(r_cat, r_map)) => unsafe { + let l_str = l_map.cat_to_str_unchecked(*l_cat); + let r_str = r_map.cat_to_str_unchecked(*r_cat); + l_str.partial_cmp(r_str) }, #[cfg(feature = "dtype-categorical")] - (Enum(..), Enum(..)) => { - unimplemented!("can't order enums as AnyValues, dtype for ordering is needed") + (Enum(l_cat, l_map), Enum(r_cat, r_map)) => { + if !Arc::ptr_eq(l_map, r_map) { + unimplemented!("can't order enums from different FrozenCategories") + } + l_cat.partial_cmp(r_cat) }, (List(_), List(_)) => { unimplemented!("ordering for List dtype is not supported") @@ -1561,24 +1596,6 @@ fn struct_to_avs_static(idx: usize, arr: &StructArray, fields: &[Field]) -> Vec< .collect() } -#[cfg(feature = "dtype-categorical")] -fn same_revmap( - rev_l: &RevMapping, - ptr_l: SyncPtr, - rev_r: &RevMapping, - ptr_r: SyncPtr, -) -> bool { - if ptr_l.is_null() && ptr_r.is_null() { - match (rev_l, rev_r) { - (RevMapping::Global(_, _, id_l), RevMapping::Global(_, _, id_r)) => id_l == id_r, - (RevMapping::Local(_, id_l), RevMapping::Local(_, id_r)) => id_l == id_r, - _ => false, - } - } else { - ptr_l == ptr_r - } -} - pub trait GetAnyValue { /// # Safety /// diff --git a/crates/polars-core/src/datatypes/categories/mod.rs b/crates/polars-core/src/datatypes/categories/mod.rs deleted file mode 100644 index 3a45c136be0a..000000000000 --- a/crates/polars-core/src/datatypes/categories/mod.rs +++ /dev/null @@ -1,275 +0,0 @@ -use std::hash::{BuildHasher, Hasher}; -use std::sync::{Arc, LazyLock, Mutex, Weak}; - -use arrow::array::builder::StaticArrayBuilder; -use arrow::array::{Utf8ViewArray, Utf8ViewArrayBuilder}; -use hashbrown::HashTable; -use hashbrown::hash_table::Entry; -use polars_error::{PolarsResult, polars_ensure}; -use polars_utils::pl_str::PlSmallStr; -use uuid::Uuid; - -use crate::prelude::*; - -mod mapping; - -pub use mapping::CategoricalMapping; - -/// The physical datatype backing a categorical / enum. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum CategoricalPhysical { - U8, - U16, - U32, -} - -impl CategoricalPhysical { - pub fn dtype(&self) -> DataType { - match self { - CategoricalPhysical::U8 => DataType::UInt8, - CategoricalPhysical::U16 => DataType::UInt16, - CategoricalPhysical::U32 => DataType::UInt32, - } - } - - pub fn max_categories(&self) -> usize { - // We might use T::MAX as an indicator, so the maximum number of categories is T::MAX - // (giving T::MAX - 1 as the largest category). - match self { - CategoricalPhysical::U8 => u8::MAX as usize, - CategoricalPhysical::U16 => u16::MAX as usize, - CategoricalPhysical::U32 => u32::MAX as usize, - } - } -} - -// Used to maintain a 1:1 mapping between Categories' UUID and the Categories objects themselves. -// This is important for serialization. -static CATEGORIES_REGISTRY: LazyLock>>> = - LazyLock::new(|| Mutex::new(PlHashMap::new())); - -// Used to make FrozenCategories unique based on their content. This allows comparison of datatypes -// in constant time by comparing pointers. -#[expect(clippy::type_complexity)] -static FROZEN_CATEGORIES_REGISTRY: LazyLock)>>> = - LazyLock::new(|| Mutex::new(HashTable::new())); - -static FROZEN_CATEGORIES_HASHER: LazyLock = - LazyLock::new(PlSeedableRandomStateQuality::random); - -static GLOBAL_CATEGORIES: LazyLock> = LazyLock::new(|| { - let categories = Arc::new(Categories { - name: PlSmallStr::from_static("__POLARS_GLOBAL_CATEGORIES"), - physical: CategoricalPhysical::U32, - uuid: Uuid::nil(), - mapping: MaybeGcMapping::Gc(Mutex::new(Weak::new())), - }); - CATEGORIES_REGISTRY - .lock() - .unwrap() - .insert(Uuid::nil(), Arc::downgrade(&categories)); - categories -}); - -/// A (named) object which is used to indicate which categorical data types -/// have the same mapping. The underlying mapping is dynamic, and if gc is true -/// may be automatically cleared when the last reference to it goes away. -pub struct Categories { - name: PlSmallStr, - physical: CategoricalPhysical, - uuid: Uuid, - mapping: MaybeGcMapping, -} - -enum MaybeGcMapping { - Gc(Mutex>), - Persistent(Arc), -} - -impl Categories { - /// Creates a new Categories object with the given name and physical type. - /// - /// If gc is true the underlying categories will automatically get cleaned - /// up when the last CategoricalMapping reference goes away, otherwise they - /// are persistent. - pub fn new(name: PlSmallStr, physical: CategoricalPhysical, gc: bool) -> Arc { - Self::new_with_registry(name, physical, gc, &mut CATEGORIES_REGISTRY.lock().unwrap()) - } - - /// Returns the Categories object with the given UUID. If the UUID is unknown a new one is created. - pub fn from_uuid( - name: PlSmallStr, - physical: CategoricalPhysical, - gc: bool, - uuid: Uuid, - ) -> Arc { - if uuid.is_nil() { - return Self::global(); - } - - let mut registry = CATEGORIES_REGISTRY.lock().unwrap(); - if let Some(cats_ref) = registry.get(&uuid) { - if let Some(cats) = cats_ref.upgrade() { - assert!( - cats.name == name, - "UUID already exists with a different name" - ); - assert!( - cats.physical == physical, - "UUID already exists with a different physical type" - ); - return cats; - } - } - Self::new_with_registry(name, physical, gc, &mut registry) - } - - /// Returns the global Categories. - pub fn global() -> Arc { - GLOBAL_CATEGORIES.clone() - } - - fn new_with_registry( - name: PlSmallStr, - physical: CategoricalPhysical, - gc: bool, - registry: &mut PlHashMap>, - ) -> Arc { - let uuid = Uuid::new_v4(); - - let mapping = if gc { - MaybeGcMapping::Gc(Mutex::new(Weak::new())) - } else { - MaybeGcMapping::Persistent(Arc::new(CategoricalMapping::new(physical.max_categories()))) - }; - - let slf = Arc::new(Self { - name, - physical, - uuid, - mapping, - }); - registry.insert(uuid, Arc::downgrade(&slf)); - slf - } - - /// The name of this Categories object (not unique). - pub fn name(&self) -> &PlSmallStr { - &self.name - } - - /// The mapping for this Categories object. If no mapping currently exists - /// it creates a new empty mapping. - pub fn mapping(&self) -> Arc { - match &self.mapping { - MaybeGcMapping::Gc(weak) => { - let mut guard = weak.lock().unwrap(); - if let Some(arc) = guard.upgrade() { - return arc; - } - let arc = Arc::new(CategoricalMapping::new(self.physical.max_categories())); - *guard = Arc::downgrade(&arc); - arc - }, - MaybeGcMapping::Persistent(arc) => arc.clone(), - } - } - - pub fn freeze(&self, physical: CategoricalPhysical) -> Arc { - let mapping = self.mapping(); - let n = mapping.num_cats_upper_bound(); - FrozenCategories::new(physical, (0..n).flat_map(|i| mapping.cat_to_str(i as u32))).unwrap() - } -} - -impl Drop for Categories { - fn drop(&mut self) { - CATEGORIES_REGISTRY.lock().unwrap().remove(&self.uuid); - } -} - -/// An ordered collection of unique strings with an associated pre-computed -/// mapping to go from string <-> index. -/// -/// FrozenCategories are globally unique to facilitate constant-time comparison. -pub struct FrozenCategories { - physical: CategoricalPhysical, - combined_hash: u64, - categories: Utf8ViewArray, - mapping: Arc, -} - -impl FrozenCategories { - /// Creates a new FrozenCategories object (or returns a reference to an existing one - /// in case these are already known). Returns an error if the categories are not unique. - /// It is guaranteed that the nth string ends up with category n (0-indexed). - pub fn new<'a, I: Iterator>( - physical: CategoricalPhysical, - strings: I, - ) -> PolarsResult> { - let hasher = *FROZEN_CATEGORIES_HASHER; - let mut mapping = CategoricalMapping::with_hasher(physical.max_categories(), hasher); - let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8); - builder.reserve(strings.size_hint().0); - - let hasher = PlFixedStateQuality::default(); - let mut combined_hasher = hasher.build_hasher(); - for s in strings { - let hash = hasher.hash_one(s); - combined_hasher.write_u64(hash); - mapping.insert_cat_with_hash(s, hash)?; - builder.push_value_ignore_validity(s); - polars_ensure!(mapping.len() == builder.len(), ComputeError: "FrozenCategories must contain unique strings; found duplicate '{s}'"); - } - - let combined_hash = combined_hasher.finish(); - let categories = builder.freeze(); - - let mut registry = FROZEN_CATEGORIES_REGISTRY.lock().unwrap(); - let mut last_compared = None; // We have to store the strong reference to avoid a race condition. - match registry.entry( - combined_hash, - |(hash, weak)| { - *hash == combined_hash && { - if let Some(frozen_cats) = weak.upgrade() { - let cmp = frozen_cats.categories == categories - && frozen_cats.physical == physical; - last_compared = Some(frozen_cats); - cmp - } else { - false - } - } - }, - |(hash, _weak)| *hash, - ) { - Entry::Occupied(_) => Ok(last_compared.unwrap()), - Entry::Vacant(v) => { - let slf = Arc::new(Self { - physical, - combined_hash, - categories, - mapping: Arc::new(mapping), - }); - v.insert((combined_hash, Arc::downgrade(&slf))); - Ok(slf) - }, - } - } - - /// The mapping for this FrozenCategories object. - pub fn mapping(&self) -> &Arc { - &self.mapping - } -} - -impl Drop for FrozenCategories { - fn drop(&mut self) { - let mut registry = FROZEN_CATEGORIES_REGISTRY.lock().unwrap(); - while let Ok(entry) = - registry.find_entry(self.combined_hash, |(_, weak)| weak.strong_count() == 0) - { - entry.remove(); - } - } -} diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index aff0080723f7..9151147aeb62 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -1,12 +1,13 @@ use std::collections::BTreeMap; -use arrow::datatypes::{DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, Metadata}; +use arrow::datatypes::{ + DTYPE_CATEGORICAL_NEW, DTYPE_ENUM_VALUES_LEGACY, DTYPE_ENUM_VALUES_NEW, Metadata, +}; #[cfg(feature = "dtype-array")] use polars_utils::format_tuple; use polars_utils::itertools::Itertools; #[cfg(any(feature = "serde-lazy", feature = "serde"))] use serde::{Deserialize, Serialize}; -use strum_macros::IntoStaticStr; pub use temporal::time_zone::TimeZone; use super::*; @@ -18,18 +19,20 @@ static MAINTAIN_PL_TYPE: &str = "maintain_type"; static PL_KEY: &str = "pl"; pub trait MetaDataExt: IntoMetadata { - fn is_enum(&self) -> bool { - let metadata = self.into_metadata_ref(); - metadata.get(DTYPE_ENUM_VALUES).is_some() - } - - fn categorical(&self) -> Option { - let metadata = self.into_metadata_ref(); - match metadata.get(DTYPE_CATEGORICAL)?.as_str() { - "lexical" => Some(CategoricalOrdering::Lexical), - // Default is Physical - _ => Some(CategoricalOrdering::Physical), - } + fn pl_enum_metadata(&self) -> Option<&str> { + let md = self.into_metadata_ref(); + let values = md + .get(DTYPE_ENUM_VALUES_NEW) + .or_else(|| md.get(DTYPE_ENUM_VALUES_LEGACY)); + Some(values?.as_str()) + } + + fn pl_categorical_metadata(&self) -> Option<&str> { + Some( + self.into_metadata_ref() + .get(DTYPE_CATEGORICAL_NEW)? + .as_str(), + ) } fn maintain_type(&self) -> bool { @@ -78,19 +81,6 @@ impl UnknownKind { } } -#[derive(Debug, Copy, Clone, PartialEq, Default, IntoStaticStr)] -#[cfg_attr( - any(feature = "serde-lazy", feature = "serde"), - derive(Serialize, Deserialize) -)] -#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))] -#[strum(serialize_all = "snake_case")] -pub enum CategoricalOrdering { - #[default] - Physical, - Lexical, -} - #[derive(Clone, Debug)] pub enum DataType { Boolean, @@ -134,13 +124,11 @@ pub enum DataType { #[cfg(feature = "object")] Object(&'static str), Null, - // The RevMapping has the internal state. - // This is ignored with comparisons, hashing etc. #[cfg(feature = "dtype-categorical")] - Categorical(Option>, CategoricalOrdering), + Categorical(Arc, Arc), // It is an Option, so that matching Enum/Categoricals can take the same guards. #[cfg(feature = "dtype-categorical")] - Enum(Option>, CategoricalOrdering), + Enum(Arc, Arc), #[cfg(feature = "dtype-struct")] Struct(Vec), // some logical types we cannot know statically, e.g. Datetime @@ -169,17 +157,9 @@ impl PartialEq for DataType { { match (self, other) { #[cfg(feature = "dtype-categorical")] - // Don't include rev maps in comparisons - (Categorical(_, ordering_l), Categorical(_, ordering_r)) => { - ordering_l == ordering_r - }, - #[cfg(feature = "dtype-categorical")] - // None means select all Enum dtypes. This is for operation `pl.col(pl.Enum)` - (Enum(None, _), Enum(_, _)) | (Enum(_, _), Enum(None, _)) => true, + (Categorical(cats_l, _), Categorical(cats_r, _)) => Arc::ptr_eq(cats_l, cats_r), #[cfg(feature = "dtype-categorical")] - (Enum(Some(cat_lhs), _), Enum(Some(cat_rhs), _)) => { - cat_lhs.get_categories() == cat_rhs.get_categories() - }, + (Enum(fcats_l, _), Enum(fcats_r, _)) => Arc::ptr_eq(fcats_l, fcats_r), (Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) => tu_l == tu_r && tz_l == tz_r, (List(left_inner), List(right_inner)) => left_inner == right_inner, #[cfg(feature = "dtype-duration")] @@ -377,7 +357,7 @@ impl DataType { Some(match (self, to) { #[cfg(feature = "dtype-categorical")] (D::Categorical(_, _) | D::Enum(_, _), D::Binary) - | (D::Binary, D::Categorical(_, _) | D::Enum(_, _)) => false, + | (D::Binary, D::Categorical(_, _) | D::Enum(_, _)) => false, // TODO @ cat-rework: why can we not cast to Binary? #[cfg(feature = "object")] (D::Object(_), D::Object(_)) => true, @@ -437,7 +417,9 @@ impl DataType { #[cfg(feature = "dtype-decimal")] Decimal(_, _) => Int128, #[cfg(feature = "dtype-categorical")] - Categorical(_, _) | Enum(_, _) => UInt32, + Categorical(cats, _) => cats.physical().dtype(), + #[cfg(feature = "dtype-categorical")] + Enum(fcats, _) => fcats.physical().dtype(), #[cfg(feature = "dtype-array")] Array(dt, width) => Array(Box::new(dt.to_physical()), *width), List(dt) => List(Box::new(dt.to_physical())), @@ -552,7 +534,7 @@ impl DataType { match self { Binary | String => true, #[cfg(feature = "dtype-categorical")] - Categorical(_, _) | Enum(_, _) => true, + Categorical(_, _) | Enum(_, _) => true, // TODO @ cat-rework: is this right? List(inner) => inner.contains_views(), #[cfg(feature = "dtype-array")] Array(inner, _) => inner.contains_views(), @@ -622,7 +604,7 @@ impl DataType { /// Check if type is sortable pub fn is_ord(&self) -> bool { #[cfg(feature = "dtype-categorical")] - let is_cat = matches!(self, DataType::Categorical(_, _) | DataType::Enum(_, _)); + let is_cat = matches!(self, DataType::Categorical(_, _) | DataType::Enum(_, _)); // TODO @ cat-rework: is this right? Why not sortable? #[cfg(not(feature = "dtype-categorical"))] let is_cat = false; @@ -712,28 +694,44 @@ impl DataType { } } - /// Convert to an Arrow Field + /// Convert to an Arrow Field. pub fn to_arrow_field(&self, name: PlSmallStr, compat_level: CompatLevel) -> ArrowField { let metadata = match self { #[cfg(feature = "dtype-categorical")] - DataType::Enum(Some(revmap), _) => { - let cats = revmap.get_categories(); - let mut encoded = String::with_capacity(cats.len() * 10); + DataType::Enum(fcats, _map) => { + let cats = fcats.categories(); + let strings_size: usize = cats + .values_iter() + .map(|s| (s.len() + 1).ilog10() as usize + 1 + s.len()) + .sum(); + let mut encoded = String::with_capacity(strings_size); for cat in cats.values_iter() { encoded.push_str(itoa::Buffer::new().format(cat.len())); encoded.push(';'); encoded.push_str(cat); } Some(BTreeMap::from([( - PlSmallStr::from_static(DTYPE_ENUM_VALUES), + PlSmallStr::from_static(DTYPE_ENUM_VALUES_NEW), PlSmallStr::from_string(encoded), )])) }, #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, ordering) => Some(BTreeMap::from([( - PlSmallStr::from_static(DTYPE_CATEGORICAL), - PlSmallStr::from_static(ordering.into()), - )])), + DataType::Categorical(cats, _) => { + let mut encoded = String::new(); + encoded.push_str(itoa::Buffer::new().format(cats.name().len())); + encoded.push(';'); + encoded.push_str(cats.name()); + encoded.push_str(itoa::Buffer::new().format(cats.namespace().len())); + encoded.push(';'); + encoded.push_str(cats.namespace()); + encoded.push_str(cats.physical().as_str()); + encoded.push(';'); + + Some(BTreeMap::from([( + PlSmallStr::from_static(DTYPE_CATEGORICAL_NEW), + PlSmallStr::from_string(encoded), + )])) + }, DataType::BinaryOffset => Some(BTreeMap::from([( PlSmallStr::from_static(PL_KEY), PlSmallStr::from_static(MAINTAIN_PL_TYPE), @@ -861,13 +859,20 @@ impl DataType { Object(_) => Ok(get_object_physical_type()), #[cfg(feature = "dtype-categorical")] Categorical(_, _) | Enum(_, _) => { + let arrow_phys = match self.cat_physical().unwrap() { + CategoricalPhysical::U8 => IntegerType::UInt8, + CategoricalPhysical::U16 => IntegerType::UInt16, + CategoricalPhysical::U32 => IntegerType::UInt32, + }; + let values = if compat_level.0 >= 1 { ArrowDataType::Utf8View } else { ArrowDataType::LargeUtf8 }; + Ok(ArrowDataType::Dictionary( - IntegerType::UInt32, + arrow_phys, Box::new(values), false, )) @@ -938,6 +943,17 @@ impl DataType { // We don't allow the other way around, only if our current type is // null and the schema isn't we allow it. (DataType::Null, _) => Ok(true), + #[cfg(feature = "dtype-categorical")] + (DataType::Categorical(l, _), DataType::Categorical(r, _)) => { + ensure_same_categories(l, r)?; + Ok(false) + }, + #[cfg(feature = "dtype-categorical")] + (DataType::Enum(l, _), DataType::Enum(r, _)) => { + ensure_same_frozen_categories(l, r)?; + Ok(false) + }, + (l, r) if l == r => Ok(false), (l, r) => { polars_bail!(SchemaMismatch: "type {:?} is incompatible with expected type {:?}", l, r) @@ -959,6 +975,41 @@ impl DataType { } level } + + /// If this dtype is a Categorical or Enum, returns the physical backing type. + #[cfg(feature = "dtype-categorical")] + pub fn cat_physical(&self) -> PolarsResult { + match self { + DataType::Categorical(cats, _) => Ok(cats.physical()), + DataType::Enum(fcats, _) => Ok(fcats.physical()), + _ => { + polars_bail!(SchemaMismatch: "invalid dtype: expected an Enum or Categorical type, received '{:?}'", self) + }, + } + } + + /// If this dtype is a Categorical or Enum, returns the underlying mapping. + #[cfg(feature = "dtype-categorical")] + pub fn cat_mapping(&self) -> PolarsResult<&Arc> { + match self { + DataType::Categorical(_, mapping) | DataType::Enum(_, mapping) => Ok(mapping), + _ => { + polars_bail!(SchemaMismatch: "invalid dtype: expected an Enum or Categorical type, received '{:?}'", self) + }, + } + } + + #[cfg(feature = "dtype-categorical")] + pub fn from_categories(cats: Arc) -> Self { + let mapping = cats.mapping(); + Self::Categorical(cats, mapping) + } + + #[cfg(feature = "dtype-categorical")] + pub fn from_frozen_categories(fcats: Arc) -> Self { + let mapping = fcats.mapping().clone(); + Self::Enum(fcats, mapping) + } } impl Display for DataType { @@ -1036,27 +1087,14 @@ pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult use DataType::*; Ok(match (left, right) { #[cfg(feature = "dtype-categorical")] - (Categorical(Some(rev_map_l), ordering), Categorical(Some(rev_map_r), _)) => { - match (&**rev_map_l, &**rev_map_r) { - (RevMapping::Global(_, _, idl), RevMapping::Global(_, _, idr)) if idl == idr => { - let mut merger = GlobalRevMapMerger::new(rev_map_l.clone()); - merger.merge_map(rev_map_r)?; - Categorical(Some(merger.finish()), *ordering) - }, - (RevMapping::Local(_, idl), RevMapping::Local(_, idr)) if idl == idr => { - left.clone() - }, - _ => polars_bail!(string_cache_mismatch), - } + (Categorical(cats_l, map), Categorical(cats_r, _)) => { + ensure_same_categories(cats_l, cats_r)?; + Categorical(cats_l.clone(), map.clone()) }, #[cfg(feature = "dtype-categorical")] - (Enum(Some(rev_map_l), _), Enum(Some(rev_map_r), _)) => { - match (&**rev_map_l, &**rev_map_r) { - (RevMapping::Local(_, idl), RevMapping::Local(_, idr)) if idl == idr => { - left.clone() - }, - _ => polars_bail!(ComputeError: "can not combine with different categories"), - } + (Enum(fcats_l, map), Enum(fcats_r, _)) => { + ensure_same_frozen_categories(fcats_l, fcats_r)?; + Enum(fcats_l.clone(), map.clone()) }, (List(inner_l), List(inner_r)) => { let merged = merge_dtypes(inner_l, inner_r)?; @@ -1123,12 +1161,6 @@ pub fn unpack_dtypes(dtype: &DataType, include_compound_types: bool) -> PlHashSe result } -#[cfg(feature = "dtype-categorical")] -pub fn create_enum_dtype(categories: Utf8ViewArray) -> DataType { - let rev_map = RevMapping::build_local(categories); - DataType::Enum(Some(Arc::new(rev_map)), Default::default()) -} - #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))] diff --git a/crates/polars-core/src/datatypes/field.rs b/crates/polars-core/src/datatypes/field.rs index f3541af450bd..bcf64749ee1a 100644 --- a/crates/polars-core/src/datatypes/field.rs +++ b/crates/polars-core/src/datatypes/field.rs @@ -1,4 +1,5 @@ -use arrow::datatypes::{DTYPE_ENUM_VALUES, Metadata}; +use arrow::datatypes::Metadata; +use polars_dtype::categorical::CategoricalPhysical; use polars_utils::pl_str::PlSmallStr; use super::*; @@ -178,42 +179,58 @@ impl DataType { ArrowDataType::Duration(tu) => DataType::Duration(tu.into()), ArrowDataType::Date64 => DataType::Datetime(TimeUnit::Milliseconds, None), ArrowDataType::Time64(_) | ArrowDataType::Time32(_) => DataType::Time, + #[cfg(feature = "dtype-categorical")] ArrowDataType::Dictionary(_, value_type, _) => { - if md.map(|md| md.is_enum()).unwrap_or(false) { - let md = md.unwrap(); - let encoded = md.get(DTYPE_ENUM_VALUES).unwrap(); - let mut encoded = encoded.as_str(); - let mut cats = MutableBinaryViewArray::::new(); + // The metadata encoding here must match DataType::to_arrow_field. + if let Some(mut enum_md) = md.and_then(|md| md.pl_enum_metadata()) { + let cats = move || { + if enum_md.is_empty() { + return None; + } - // Data is encoded as - // We know thus that len is only [0-9] and the first ';' doesn't belong to the - // payload. - while let Some(pos) = encoded.find(';') { - let (len, remainder) = encoded.split_at(pos); - // Split off ';' - encoded = &remainder[1..]; + let len; + (len, enum_md) = enum_md.split_once(';').unwrap(); let len = len.parse::().unwrap(); + let cat; + (cat, enum_md) = enum_md.split_at(len); + Some(cat) + }; + + let fcats = FrozenCategories::new(std::iter::from_fn(cats)).unwrap(); + DataType::from_frozen_categories(fcats) + } else if let Some(mut cat_md) = md.and_then(|md| md.pl_categorical_metadata()) { + let name_len; + (name_len, cat_md) = cat_md.split_once(';').unwrap(); + let name_len = name_len.parse::().unwrap(); + let name; + (name, cat_md) = cat_md.split_at(name_len); + + let namespace_len; + (namespace_len, cat_md) = cat_md.split_once(';').unwrap(); + let namespace_len = namespace_len.parse::().unwrap(); + let namespace; + (namespace, cat_md) = cat_md.split_at(namespace_len); - let (value, remainder) = encoded.split_at(len); - cats.push_value(value); - encoded = remainder; - } - DataType::Enum( - Some(Arc::new(RevMapping::build_local(cats.into()))), - Default::default(), - ) - } else if let Some(ordering) = md.and_then(|md| md.categorical()) { - DataType::Categorical(None, ordering) + let (physical, _rest) = cat_md.split_once(';').unwrap(); + + let physical: CategoricalPhysical = physical.parse().ok().unwrap(); + let cats = Categories::new( + PlSmallStr::from_str(name), + PlSmallStr::from_str(namespace), + physical, + ); + DataType::from_categories(cats) } else if matches!( value_type.as_ref(), ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View ) { - DataType::Categorical(None, Default::default()) + DataType::from_categories(Categories::global()) } else { Self::from_arrow(value_type, None) } }, + #[cfg(feature = "dtype-struct")] ArrowDataType::Struct(fields) => { DataType::Struct(fields.iter().map(|fld| fld.into()).collect()) diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index 84d2486b5139..f5e7d5da436d 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -10,8 +10,6 @@ mod _serde; mod aliases; mod any_value; -#[cfg(feature = "dtype-categorical")] -mod categories; mod dtype; mod field; mod into_scalar; @@ -35,14 +33,17 @@ pub use arrow::datatypes::reshape::*; pub use arrow::datatypes::{ArrowDataType, TimeUnit as ArrowTimeUnit}; use arrow::types::NativeType; use bytemuck::Zeroable; -#[cfg(feature = "dtype-categorical")] -pub use categories::{CategoricalMapping, Categories, FrozenCategories}; pub use dtype::*; pub use field::*; pub use into_scalar::*; use num_traits::{AsPrimitive, Bounded, FromPrimitive, Num, NumCast, One, Zero}; use polars_compute::arithmetic::HasPrimitiveArithmeticKernel; use polars_compute::float_sum::FloatSum; +#[cfg(feature = "dtype-categorical")] +pub use polars_dtype::categorical::{ + CatNative, CatSize, CategoricalMapping, CategoricalPhysical, Categories, FrozenCategories, + ensure_same_categories, ensure_same_frozen_categories, +}; use polars_utils::abs_diff::AbsDiff; use polars_utils::float::IsFloat; use polars_utils::min_max::MinMax; @@ -117,6 +118,16 @@ where pub trait PolarsIntegerType: PolarsNumericType {} pub trait PolarsFloatType: PolarsNumericType {} +/// # Safety +/// The physical() return type must be correct for Native. +#[cfg(feature = "dtype-categorical")] +pub unsafe trait PolarsCategoricalType: PolarsDataType { + type Native: NumericNative + CatNative + DictionaryKey + PartialEq + Eq + Hash; + type PolarsPhysical: PolarsIntegerType; + + fn physical() -> CategoricalPhysical; +} + macro_rules! impl_polars_num_datatype { ($trait: ident, $pdt:ident, $variant:ident, $physical:ty, $owned_phys:ty) => { #[derive(Clone, Copy)] @@ -169,6 +180,31 @@ macro_rules! impl_polars_datatype { }; } +macro_rules! impl_polars_categorical_datatype { + ($pdt:ident, $phys:ty, $native:ty, $phys_variant:ident) => { + impl_polars_datatype!( + $pdt, + unimplemented!(), + PrimitiveArray<$native>, + 'a, + $native, + $native, + $native, + FalseT + ); + + #[cfg(feature = "dtype-categorical")] + unsafe impl PolarsCategoricalType for $pdt { + type Native = $native; + type PolarsPhysical = $phys; + + fn physical() -> CategoricalPhysical { + CategoricalPhysical::$phys_variant + } + } + } +} + impl_polars_num_datatype!(PolarsIntegerType, UInt8Type, UInt8, u8, u8); impl_polars_num_datatype!(PolarsIntegerType, UInt16Type, UInt16, u16, u16); impl_polars_num_datatype!(PolarsIntegerType, UInt32Type, UInt32, u32, u32); @@ -195,6 +231,10 @@ impl_polars_datatype!(CategoricalType, unimplemented!(), PrimitiveArray, 'a impl_polars_datatype!(DateType, DataType::Date, PrimitiveArray, 'a, i32, i32, i32, FalseT); impl_polars_datatype!(TimeType, DataType::Time, PrimitiveArray, 'a, i64, i64, i64, FalseT); +impl_polars_categorical_datatype!(Categorical8Type, UInt8Type, u8, U8); +impl_polars_categorical_datatype!(Categorical16Type, UInt16Type, u16, U16); +impl_polars_categorical_datatype!(Categorical32Type, UInt32Type, u32, U32); + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct ListType {} unsafe impl PolarsDataType for ListType { diff --git a/crates/polars-core/src/fmt.rs b/crates/polars-core/src/fmt.rs index 8b676fbbd1bc..deb071ffa3d5 100644 --- a/crates/polars-core/src/fmt.rs +++ b/crates/polars-core/src/fmt.rs @@ -402,18 +402,18 @@ impl Debug for Series { #[cfg(feature = "object")] DataType::Object(_) => format_object_array(f, self, self.name(), "Series"), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) => { - format_array!(f, self.categorical().unwrap(), "cat", self.name(), "Series") + DataType::Categorical(cats, _) => { + with_match_categorical_physical_type!(cats.physical(), |$C| { + format_array!(f, self.cat::<$C>().unwrap(), "cat", self.name(), "Series") + }) }, #[cfg(feature = "dtype-categorical")] - DataType::Enum(_, _) => format_array!( - f, - self.categorical().unwrap(), - "enum", - self.name(), - "Series" - ), + DataType::Enum(fcats, _) => { + with_match_categorical_physical_type!(fcats.physical(), |$C| { + format_array!(f, self.cat::<$C>().unwrap(), "enum", self.name(), "Series") + }) + }, #[cfg(feature = "dtype-struct")] dt @ DataType::Struct(_) => format_array!( f, @@ -1182,10 +1182,10 @@ impl Display for AnyValue<'_> { write!(f, "{nt}") }, #[cfg(feature = "dtype-categorical")] - AnyValue::Categorical(_, _, _) - | AnyValue::CategoricalOwned(_, _, _) - | AnyValue::Enum(_, _, _) - | AnyValue::EnumOwned(_, _, _) => { + AnyValue::Categorical(_, _) + | AnyValue::CategoricalOwned(_, _) + | AnyValue::Enum(_, _) + | AnyValue::EnumOwned(_, _) => { let s = self.get_str().unwrap(); write!(f, "\"{s}\"") }, diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 009a6080e074..2fd9c8e40319 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -321,8 +321,20 @@ impl Column { self.as_materialized_series().try_array() } #[cfg(feature = "dtype-categorical")] - pub fn try_categorical(&self) -> Option<&CategoricalChunked> { - self.as_materialized_series().try_categorical() + pub fn try_cat(&self) -> Option<&CategoricalChunked> { + self.as_materialized_series().try_cat::() + } + #[cfg(feature = "dtype-categorical")] + pub fn try_cat8(&self) -> Option<&Categorical8Chunked> { + self.as_materialized_series().try_cat8() + } + #[cfg(feature = "dtype-categorical")] + pub fn try_cat16(&self) -> Option<&Categorical16Chunked> { + self.as_materialized_series().try_cat16() + } + #[cfg(feature = "dtype-categorical")] + pub fn try_cat32(&self) -> Option<&Categorical32Chunked> { + self.as_materialized_series().try_cat32() } #[cfg(feature = "dtype-date")] pub fn try_date(&self) -> Option<&DateChunked> { @@ -403,8 +415,20 @@ impl Column { self.as_materialized_series().array() } #[cfg(feature = "dtype-categorical")] - pub fn categorical(&self) -> PolarsResult<&CategoricalChunked> { - self.as_materialized_series().categorical() + pub fn cat(&self) -> PolarsResult<&CategoricalChunked> { + self.as_materialized_series().cat::() + } + #[cfg(feature = "dtype-categorical")] + pub fn cat8(&self) -> PolarsResult<&Categorical8Chunked> { + self.as_materialized_series().cat8() + } + #[cfg(feature = "dtype-categorical")] + pub fn cat16(&self) -> PolarsResult<&Categorical16Chunked> { + self.as_materialized_series().cat16() + } + #[cfg(feature = "dtype-categorical")] + pub fn cat32(&self) -> PolarsResult<&Categorical32Chunked> { + self.as_materialized_series().cat32() } #[cfg(feature = "dtype-date")] pub fn date(&self) -> PolarsResult<&DateChunked> { diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index e4340307a086..3507c4d6c847 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -130,9 +130,9 @@ where } #[cfg(all(feature = "dtype-categorical", feature = "performant"))] -impl IntoGroupsType for CategoricalChunked { +impl IntoGroupsType for CategoricalChunked { fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - Ok(self.group_tuples_perfect(multithreaded, sorted)) + self.phys.group_tuples(multithreaded, sorted) } } diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index bd61eaeebc45..a557572f5108 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -16,7 +16,6 @@ pub mod aggregations; pub mod expr; pub(crate) mod hashing; mod into_groups; -mod perfect; mod position; pub use into_groups::*; @@ -1118,7 +1117,7 @@ mod test { .unwrap(); df.apply("foo", |s| { - s.cast(&DataType::Categorical(None, Default::default())) + s.cast(&DataType::from_categories(Categories::global())) .unwrap() }) .unwrap(); @@ -1199,7 +1198,7 @@ mod test { ]?; df.try_apply("g", |s| { - s.cast(&DataType::Categorical(None, Default::default())) + s.cast(&DataType::from_categories(Categories::global())) })?; // Use of deprecated `sum()` for testing purposes diff --git a/crates/polars-core/src/frame/group_by/perfect.rs b/crates/polars-core/src/frame/group_by/perfect.rs deleted file mode 100644 index 04ce4c54970e..000000000000 --- a/crates/polars-core/src/frame/group_by/perfect.rs +++ /dev/null @@ -1,197 +0,0 @@ -use std::fmt::Debug; -use std::mem::MaybeUninit; - -use num_traits::{FromPrimitive, ToPrimitive}; -use polars_utils::idx_vec::IdxVec; -use polars_utils::sync::SyncPtr; -use rayon::prelude::*; - -use crate::POOL; -#[cfg(all(feature = "dtype-categorical", feature = "performant"))] -use crate::config::verbose; -use crate::datatypes::*; -use crate::prelude::*; - -impl ChunkedArray -where - T: PolarsIntegerType, - T::Native: ToPrimitive + FromPrimitive + Debug, -{ - /// Use the indexes as perfect groups. - /// - /// # Safety - /// This ChunkedArray must contain each value in [0..num_groups) at least - /// once, and nothing outside this range. - pub unsafe fn group_tuples_perfect( - &self, - num_groups: usize, - mut multithreaded: bool, - group_capacity: usize, - ) -> GroupsType { - multithreaded &= POOL.current_num_threads() > 1; - // The latest index will be used for the null sentinel. - let len = if self.null_count() > 0 { - // We add one to store the null sentinel group. - num_groups + 1 - } else { - num_groups - }; - let null_idx = len.saturating_sub(1); - - let n_threads = POOL.current_num_threads(); - let chunk_size = len / n_threads; - - let (groups, first) = if multithreaded && chunk_size > 1 { - let mut groups: Vec = Vec::new(); - groups.resize_with(len, || IdxVec::with_capacity(group_capacity)); - let mut first: Vec = Vec::with_capacity(len); - - // Round up offsets to nearest cache line for groups to reduce false sharing. - let groups_start = groups.as_ptr(); - let mut per_thread_offsets = Vec::with_capacity(n_threads + 1); - per_thread_offsets.push(0); - for t in 0..n_threads { - let ideal_offset = (t + 1) * chunk_size; - let cache_aligned_offset = - ideal_offset + groups_start.wrapping_add(ideal_offset).align_offset(128); - if t == n_threads - 1 { - per_thread_offsets.push(len); - } else { - per_thread_offsets.push(std::cmp::min(cache_aligned_offset, len)); - } - } - - let groups_ptr = unsafe { SyncPtr::new(groups.as_mut_ptr()) }; - let first_ptr = unsafe { SyncPtr::new(first.as_mut_ptr()) }; - POOL.install(|| { - (0..n_threads).into_par_iter().for_each(|thread_no| { - // We use raw pointers because the slices would overlap. - // However, each thread has its own range it is responsible for. - let groups = groups_ptr.get(); - let first = first_ptr.get(); - let start = per_thread_offsets[thread_no]; - let start = T::Native::from_usize(start).unwrap(); - let end = per_thread_offsets[thread_no + 1]; - let end = T::Native::from_usize(end).unwrap(); - - if start == end && thread_no != n_threads - 1 { - return; - }; - - let push_to_group = |cat, row_nr| unsafe { - debug_assert!(cat < len); - let buf = &mut *groups.add(cat); - buf.push(row_nr); - if buf.len() == 1 { - *first.add(cat) = row_nr; - } - }; - - let mut row_nr = 0 as IdxSize; - for arr in self.downcast_iter() { - if arr.null_count() == 0 { - for &cat in arr.values().as_slice() { - if cat >= start && cat < end { - push_to_group(cat.to_usize().unwrap(), row_nr); - } - - row_nr += 1; - } - } else { - for opt_cat in arr.iter() { - if let Some(&cat) = opt_cat { - if cat >= start && cat < end { - push_to_group(cat.to_usize().unwrap(), row_nr); - } - } else if thread_no == n_threads - 1 { - // Last thread handles null values. - push_to_group(null_idx, row_nr); - } - - row_nr += 1; - } - } - } - }); - }); - unsafe { - first.set_len(len); - } - (groups, first) - } else { - let mut groups = Vec::with_capacity(len); - let mut first = Vec::with_capacity(len); - let first_out = first.spare_capacity_mut(); - groups.resize_with(len, || IdxVec::with_capacity(group_capacity)); - - let mut push_to_group = |cat, row_nr| unsafe { - let buf: &mut IdxVec = groups.get_unchecked_mut(cat); - buf.push(row_nr); - if buf.len() == 1 { - *first_out.get_unchecked_mut(cat) = MaybeUninit::new(row_nr); - } - }; - - let mut row_nr = 0 as IdxSize; - for arr in self.downcast_iter() { - for opt_cat in arr.iter() { - if let Some(cat) = opt_cat { - push_to_group(cat.to_usize().unwrap(), row_nr); - } else { - push_to_group(null_idx, row_nr); - } - - row_nr += 1; - } - } - unsafe { - first.set_len(len); - } - (groups, first) - }; - - // NOTE! we set sorted here! - // this happens to be true for `fast_unique` categoricals - GroupsType::Idx(GroupsIdx::new(first, groups, true)) - } -} - -#[cfg(all(feature = "dtype-categorical", feature = "performant"))] -// Special implementation so that cats can be processed in a single pass -impl CategoricalChunked { - // Use the indexes as perfect groups - pub fn group_tuples_perfect(&self, multithreaded: bool, sorted: bool) -> GroupsType { - let rev_map = self.get_rev_map(); - if self.is_empty() { - return GroupsType::Idx(GroupsIdx::new(vec![], vec![], true)); - } - let cats = self.physical(); - - let mut out = match &**rev_map { - RevMapping::Local(cached, _) => { - if self._can_fast_unique() { - assert!(cached.len() <= self.len(), "invalid invariant"); - if verbose() { - eprintln!("grouping categoricals, run perfect hash function"); - } - // on relative small tables this isn't much faster than the default strategy - // but on huge tables, this can be > 2x faster - unsafe { cats.group_tuples_perfect(cached.len(), multithreaded, 0) } - } else { - self.physical().group_tuples(multithreaded, sorted).unwrap() - } - }, - RevMapping::Global(_mapping, _cached, _) => { - // TODO! see if we can optimize this - // the problem is that the global categories are not guaranteed packed together - // so we might need to deref them first to local ones, but that might be more - // expensive than just hashing (benchmark first) - self.physical().group_tuples(multithreaded, sorted).unwrap() - }, - }; - if sorted { - out.sort() - } - out - } -} diff --git a/crates/polars-core/src/frame/row/transpose.rs b/crates/polars-core/src/frame/row/transpose.rs index a90a721dde6e..e53a7e5b8315 100644 --- a/crates/polars-core/src/frame/row/transpose.rs +++ b/crates/polars-core/src/frame/row/transpose.rs @@ -142,27 +142,6 @@ impl DataFrame { NoData: "unable to transpose an empty DataFrame" ); let dtype = df.get_supertype().unwrap()?; - match dtype { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - let mut valid = true; - let mut rev_map: Option<&Arc> = None; - for s in self.columns.iter() { - if let DataType::Categorical(Some(col_rev_map), _) - | DataType::Enum(Some(col_rev_map), _) = &s.dtype() - { - match rev_map { - Some(rev_map) => valid = valid && rev_map.same_src(col_rev_map), - None => { - rev_map = Some(col_rev_map); - }, - } - } - } - polars_ensure!(valid, string_cache_mismatch); - }, - _ => {}, - } df.transpose_from_dtype(&dtype, keep_names_as.map(PlSmallStr::from_str), &names_out) } } diff --git a/crates/polars-core/src/lib.rs b/crates/polars-core/src/lib.rs index f004f64d1162..b949f62159d8 100644 --- a/crates/polars-core/src/lib.rs +++ b/crates/polars-core/src/lib.rs @@ -34,9 +34,6 @@ pub use datatypes::SchemaExtPl; pub use hashing::IdBuildHasher; use rayon::{ThreadPool, ThreadPoolBuilder}; -#[cfg(feature = "dtype-categorical")] -pub use crate::chunked_array::logical::categorical::string_cache::*; - pub static PROCESS_ID: LazyLock = LazyLock::new(|| { SystemTime::now() .duration_since(UNIX_EPOCH) diff --git a/crates/polars-core/src/prelude.rs b/crates/polars-core/src/prelude.rs index ced824db3c54..240b75ff41dc 100644 --- a/crates/polars-core/src/prelude.rs +++ b/crates/polars-core/src/prelude.rs @@ -36,8 +36,6 @@ pub use crate::chunked_array::ops::rolling_window::RollingOptionsFixedWindow; pub use crate::chunked_array::ops::*; #[cfg(feature = "temporal")] pub use crate::chunked_array::temporal::conversion::*; -#[cfg(feature = "dtype-categorical")] -pub use crate::datatypes::string_cache::StringCacheHolder; pub use crate::datatypes::{ArrayCollectIterExt, *}; pub use crate::error::signals::try_raise_keyboard_interrupt; pub use crate::error::{ @@ -60,4 +58,4 @@ pub use crate::series::arithmetic::{LhsNumOps, NumOpsDispatch}; pub use crate::series::{IntoSeries, Series, SeriesTrait}; pub(crate) use crate::utils::CustomIterTools; pub use crate::utils::IntoVec; -pub use crate::{datatypes, df}; +pub use crate::{datatypes, df, with_match_categorical_physical_type}; diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs index 237c96af6736..b5bd1232623b 100644 --- a/crates/polars-core/src/series/any_value.rs +++ b/crates/polars-core/src/series/any_value.rs @@ -1,8 +1,9 @@ -use std::borrow::Cow; use std::fmt::Write; use arrow::bitmap::MutableBitmap; +#[cfg(feature = "dtype-categorical")] +use crate::chunked_array::builder::CategoricalChunkedBuilder; use crate::chunked_array::builder::{AnonymousOwnedListBuilder, get_list_builder}; use crate::prelude::*; use crate::utils::any_values_to_supertype; @@ -22,48 +23,6 @@ impl<'a, T: AsRef<[AnyValue<'a>]>> NamedFrom]> for Series { } } -fn initialize_empty_categorical_revmap_rec(dtype: &DataType) -> Cow<'_, DataType> { - use DataType as T; - match dtype { - #[cfg(feature = "dtype-categorical")] - T::Categorical(None, o) => { - Cow::Owned(T::Categorical(Some(Arc::new(RevMapping::default())), *o)) - }, - T::List(inner_dtype) => match initialize_empty_categorical_revmap_rec(inner_dtype) { - Cow::Owned(inner_dtype) => Cow::Owned(T::List(Box::new(inner_dtype))), - _ => Cow::Borrowed(dtype), - }, - #[cfg(feature = "dtype-array")] - T::Array(inner_dtype, width) => { - match initialize_empty_categorical_revmap_rec(inner_dtype) { - Cow::Owned(inner_dtype) => Cow::Owned(T::Array(Box::new(inner_dtype), *width)), - _ => Cow::Borrowed(dtype), - } - }, - #[cfg(feature = "dtype-struct")] - T::Struct(fields) => { - for (i, field) in fields.iter().enumerate() { - if let Cow::Owned(field_dtype) = - initialize_empty_categorical_revmap_rec(field.dtype()) - { - let mut new_fields = Vec::with_capacity(fields.len()); - new_fields.extend(fields[..i].iter().cloned()); - new_fields.push(Field::new(field.name().clone(), field_dtype)); - new_fields.extend(fields[i + 1..].iter().map(|field| { - let field_dtype = - initialize_empty_categorical_revmap_rec(field.dtype()).into_owned(); - Field::new(field.name().clone(), field_dtype) - })); - return Cow::Owned(T::Struct(new_fields)); - } - } - - Cow::Borrowed(dtype) - }, - _ => Cow::Borrowed(dtype), - } -} - impl Series { /// Construct a new [`Series`] from a slice of AnyValues. /// @@ -132,12 +91,7 @@ impl Series { strict: bool, ) -> PolarsResult { if values.is_empty() { - return Ok(Self::new_empty( - name, - // This is given categoricals with empty revmaps, but we need to always return - // categoricals with non-empty revmaps. - initialize_empty_categorical_revmap_rec(dtype).as_ref(), - )); + return Ok(Self::new_empty(name, dtype)); } let mut s = match dtype { @@ -172,9 +126,9 @@ impl Series { #[cfg(feature = "dtype-duration")] DataType::Duration(tu) => any_values_to_duration(values, *tu, strict)?.into_series(), #[cfg(feature = "dtype-categorical")] - dt @ DataType::Categorical(_, _) => any_values_to_categorical(values, dt, strict)?, - #[cfg(feature = "dtype-categorical")] - dt @ DataType::Enum(_, _) => any_values_to_enum(values, dt, strict)?, + dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => { + any_values_to_categorical(values, dt, strict)? + }, #[cfg(feature = "dtype-decimal")] DataType::Decimal(precision, scale) => { any_values_to_decimal(values, *precision, *scale, strict)?.into_series() @@ -498,91 +452,52 @@ fn any_values_to_categorical( dtype: &DataType, strict: bool, ) -> PolarsResult { - let ordering = match dtype { - DataType::Categorical(_, ordering) => ordering, - _ => panic!("any_values_to_categorical with dtype={dtype:?}"), - }; - - let mut builder = CategoricalChunkedBuilder::new(PlSmallStr::EMPTY, values.len(), *ordering); - - let mut owned = String::new(); // Amortize allocations. - for av in values { - match av { - AnyValue::String(s) => builder.append_value(s), - AnyValue::StringOwned(s) => builder.append_value(s), - - AnyValue::Enum(s, rev, _) => builder.append_value(rev.get(*s)), - AnyValue::EnumOwned(s, rev, _) => builder.append_value(rev.get(*s)), + with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| { + let mut builder = CategoricalChunkedBuilder::<$C>::new(PlSmallStr::EMPTY, dtype.clone()); - AnyValue::Categorical(s, rev, _) => builder.append_value(rev.get(*s)), - AnyValue::CategoricalOwned(s, rev, _) => builder.append_value(rev.get(*s)), - - AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => builder.append_null(), - AnyValue::Null => builder.append_null(), - - av => { - if strict { - return Err(invalid_value_error(&DataType::String, av)); + let mut owned = String::new(); // Amortize allocations. + for av in values { + let ret = match av { + AnyValue::String(s) => builder.append_str(s), + AnyValue::StringOwned(s) => builder.append_str(s), + + &AnyValue::Enum(cat, &ref map) | + &AnyValue::EnumOwned(cat, ref map) | + &AnyValue::Categorical(cat, &ref map) | + &AnyValue::CategoricalOwned(cat, ref map) => builder.append_cat(cat, map), + + AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => { + builder.append_null(); + Ok(()) + }, + AnyValue::Null => { + builder.append_null(); + Ok(()) } - owned.clear(); - write!(owned, "{av}").unwrap(); - builder.append_value(&owned); - }, - } - } - - let ca = builder.finish(); - - Ok(ca.into_series()) -} - -#[cfg(feature = "dtype-categorical")] -fn any_values_to_enum(values: &[AnyValue], dtype: &DataType, strict: bool) -> PolarsResult { - use self::enum_::EnumChunkedBuilder; - - let (rev, ordering) = match dtype { - DataType::Enum(rev, ordering) => (rev.clone(), ordering), - _ => panic!("any_values_to_categorical with dtype={dtype:?}"), - }; - - let Some(rev) = rev else { - polars_bail!(nyi = "Not yet possible to create enum series without a rev-map"); - }; - - let mut builder = - EnumChunkedBuilder::new(PlSmallStr::EMPTY, values.len(), rev, *ordering, strict); - - let mut owned = String::new(); // Amortize allocations. - for av in values { - match av { - AnyValue::String(s) => builder.append_str(s)?, - AnyValue::StringOwned(s) => builder.append_str(s)?, - - AnyValue::Enum(s, rev, _) => builder.append_enum(*s, rev)?, - AnyValue::EnumOwned(s, rev, _) => builder.append_enum(*s, rev)?, - - AnyValue::Categorical(s, rev, _) => builder.append_str(rev.get(*s))?, - AnyValue::CategoricalOwned(s, rev, _) => builder.append_str(rev.get(*s))?, + av => { + if strict { + return Err(invalid_value_error(&DataType::String, av)); + } - AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => builder.append_null(), - AnyValue::Null => builder.append_null(), + owned.clear(); + write!(owned, "{av}").unwrap(); + builder.append_str(&owned) + }, + }; - av => { + if let Err(e) = ret { if strict { - return Err(invalid_value_error(&DataType::String, av)); + return Err(e); + } else { + builder.append_null(); } + } + } - owned.clear(); - write!(owned, "{av}").unwrap(); - builder.append_str(&owned)? - }, - }; - } - - let ca = builder.finish(); - - Ok(ca.into_series()) + let ca = builder.finish(); + Ok(ca.into_series()) + }) } #[cfg(feature = "dtype-decimal")] @@ -701,20 +616,8 @@ fn any_values_to_list( DataType::Object(_) => polars_bail!(nyi = "Nested object types"), _ => { - let list_inner_type = match inner_type { - // Categoricals may not have a revmap yet. We just give them an empty one here and - // the list builder takes care of the rest. - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(None, ordering) => { - DataType::Categorical(Some(Arc::new(RevMapping::default())), *ordering) - }, - - _ => inner_type.clone(), - }; - let mut builder = - get_list_builder(&list_inner_type, capacity * 5, capacity, PlSmallStr::EMPTY); - + get_list_builder(inner_type, capacity * 5, capacity, PlSmallStr::EMPTY); for av in avs { match av { AnyValue::List(b) => match b.cast(inner_type) { diff --git a/crates/polars-core/src/series/builder.rs b/crates/polars-core/src/series/builder.rs index e19e1ca5736f..df08f1828ff7 100644 --- a/crates/polars-core/src/series/builder.rs +++ b/crates/polars-core/src/series/builder.rs @@ -6,29 +6,10 @@ use crate::chunked_array::object::registry::get_object_builder; use crate::prelude::*; use crate::utils::Container; -#[cfg(feature = "dtype-categorical")] -#[inline(always)] -fn fill_rev_map(dtype: &DataType, rev_map_merger: &mut Option>) { - if let DataType::Categorical(Some(rev_map), _) = dtype { - assert!( - rev_map.is_active_global(), - "{}", - polars_err!(string_cache_mismatch) - ); - if let Some(merger) = rev_map_merger { - merger.merge_map(rev_map).unwrap(); - } else { - *rev_map_merger = Some(Box::new(GlobalRevMapMerger::new(rev_map.clone()))); - } - } -} - /// A type-erased wrapper around ArrayBuilder. pub struct SeriesBuilder { dtype: DataType, builder: Box, - #[cfg(feature = "dtype-categorical")] - rev_map_merger: Option>, } impl SeriesBuilder { @@ -37,21 +18,11 @@ impl SeriesBuilder { #[cfg(feature = "object")] if matches!(dtype, DataType::Object(_)) { let builder = get_object_builder(PlSmallStr::EMPTY, 0).as_array_builder(); - return Self { - dtype, - builder, - #[cfg(feature = "dtype-categorical")] - rev_map_merger: None, - }; + return Self { dtype, builder }; } let builder = make_builder(&dtype.to_physical().to_arrow(CompatLevel::newest())); - Self { - dtype, - builder, - #[cfg(feature = "dtype-categorical")] - rev_map_merger: None, - } + Self { dtype, builder } } #[inline(always)] @@ -59,22 +30,9 @@ impl SeriesBuilder { self.builder.reserve(additional); } - fn freeze_dtype(&mut self) -> DataType { - #[cfg(feature = "dtype-categorical")] - if let Some(rev_map_merger) = self.rev_map_merger.take() { - let DataType::Categorical(_, order) = self.dtype else { - unreachable!() - }; - return DataType::Categorical(Some(rev_map_merger.finish()), order); - } - - self.dtype.clone() - } - - pub fn freeze(mut self, name: PlSmallStr) -> Series { + pub fn freeze(self, name: PlSmallStr) -> Series { unsafe { - let dtype = self.freeze_dtype(); - Series::from_chunks_and_dtype_unchecked(name, vec![self.builder.freeze()], &dtype) + Series::from_chunks_and_dtype_unchecked(name, vec![self.builder.freeze()], &self.dtype) } } @@ -83,7 +41,7 @@ impl SeriesBuilder { Series::from_chunks_and_dtype_unchecked( name, vec![self.builder.freeze_reset()], - &self.freeze_dtype(), + &self.dtype, ) } } @@ -100,11 +58,6 @@ impl SeriesBuilder { /// other does not match the dtype of this builder. #[inline(always)] pub fn extend(&mut self, other: &Series, share: ShareStrategy) { - #[cfg(feature = "dtype-categorical")] - { - fill_rev_map(other.dtype(), &mut self.rev_map_merger); - } - self.subslice_extend(other, 0, other.len(), share); } @@ -117,11 +70,6 @@ impl SeriesBuilder { mut length: usize, share: ShareStrategy, ) { - #[cfg(feature = "dtype-categorical")] - { - fill_rev_map(other.dtype(), &mut self.rev_map_merger); - } - if length == 0 || other.is_empty() { return; } @@ -151,11 +99,6 @@ impl SeriesBuilder { repeats: usize, share: ShareStrategy, ) { - #[cfg(feature = "dtype-categorical")] - { - fill_rev_map(other.dtype(), &mut self.rev_map_merger); - } - if length == 0 || other.is_empty() { return; } @@ -179,11 +122,6 @@ impl SeriesBuilder { repeats: usize, share: ShareStrategy, ) { - #[cfg(feature = "dtype-categorical")] - { - fill_rev_map(other.dtype(), &mut self.rev_map_merger); - } - if length == 0 || repeats == 0 || other.is_empty() { return; } @@ -218,22 +156,12 @@ impl SeriesBuilder { /// # Safety /// The indices must be in-bounds. pub unsafe fn gather_extend(&mut self, other: &Series, idxs: &[IdxSize], share: ShareStrategy) { - #[cfg(feature = "dtype-categorical")] - { - fill_rev_map(other.dtype(), &mut self.rev_map_merger); - } - let chunks = other.chunks(); assert!(chunks.len() == 1); self.builder.gather_extend(&*chunks[0], idxs, share); } pub fn opt_gather_extend(&mut self, other: &Series, idxs: &[IdxSize], share: ShareStrategy) { - #[cfg(feature = "dtype-categorical")] - { - fill_rev_map(other.dtype(), &mut self.rev_map_merger); - } - let chunks = other.chunks(); assert!(chunks.len() == 1); self.builder.opt_gather_extend(&*chunks[0], idxs, share); diff --git a/crates/polars-core/src/series/comparison.rs b/crates/polars-core/src/series/comparison.rs index c4ed9b5769f1..6c7a8f5c113c 100644 --- a/crates/polars-core/src/series/comparison.rs +++ b/crates/polars-core/src/series/comparison.rs @@ -25,26 +25,27 @@ macro_rules! impl_eq_compare { #[cfg(feature = "dtype-categorical")] match (lhs.dtype(), rhs.dtype()) { - (Categorical(_, _) | Enum(_, _), Categorical(_, _) | Enum(_, _)) => { - return Ok(lhs - .categorical() - .unwrap() - .$method(rhs.categorical().unwrap())? - .with_name(lhs.name().clone())); + (Categorical(lcats, _), Categorical(rcats, _)) => { + ensure_same_categories(lcats, rcats)?; + return with_match_categorical_physical_type!(lcats.physical(), |$C| { + lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap()) + }) + }, + (Enum(lfcats, _), Enum(rfcats, _)) => { + ensure_same_frozen_categories(lfcats, rfcats)?; + return with_match_categorical_physical_type!(lfcats.physical(), |$C| { + lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap()) + }) }, (Categorical(_, _) | Enum(_, _), String) => { - return Ok(lhs - .categorical() - .unwrap() - .$method(rhs.str().unwrap())? - .with_name(lhs.name().clone())); + return with_match_categorical_physical_type!(lhs.dtype().cat_physical().unwrap(), |$C| { + Ok(lhs.cat::<$C>().unwrap().$method(rhs.str().unwrap())) + }) }, (String, Categorical(_, _) | Enum(_, _)) => { - return Ok(rhs - .categorical() - .unwrap() - .$method(lhs.str().unwrap())? - .with_name(lhs.name().clone())); + return with_match_categorical_physical_type!(rhs.dtype().cat_physical().unwrap(), |$C| { + Ok(rhs.cat::<$C>().unwrap().$method(lhs.str().unwrap())) + }) }, _ => (), }; @@ -97,7 +98,7 @@ macro_rules! bail_invalid_ineq { } macro_rules! impl_ineq_compare { - ($self:expr, $rhs:expr, $method:ident, $op:literal) => {{ + ($self:expr, $rhs:expr, $method:ident, $op:literal, $rev_method:ident) => {{ use DataType::*; let (lhs, rhs) = ($self, $rhs); validate_types(lhs.dtype(), rhs.dtype())?; @@ -117,26 +118,28 @@ macro_rules! impl_ineq_compare { #[cfg(feature = "dtype-categorical")] match (lhs.dtype(), rhs.dtype()) { - (Categorical(_, _) | Enum(_, _), Categorical(_, _) | Enum(_, _)) => { - return Ok(lhs - .categorical() - .unwrap() - .$method(rhs.categorical().unwrap())? - .with_name(lhs.name().clone())); + (Categorical(lcats, _), Categorical(rcats, _)) => { + ensure_same_categories(lcats, rcats)?; + return with_match_categorical_physical_type!(lcats.physical(), |$C| { + lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap()) + }) + }, + (Enum(lfcats, _), Enum(rfcats, _)) => { + ensure_same_frozen_categories(lfcats, rfcats)?; + return with_match_categorical_physical_type!(lfcats.physical(), |$C| { + lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap()) + }) }, (Categorical(_, _) | Enum(_, _), String) => { - return Ok(lhs - .categorical() - .unwrap() - .$method(rhs.str().unwrap())? - .with_name(lhs.name().clone())); + return with_match_categorical_physical_type!(lhs.dtype().cat_physical().unwrap(), |$C| { + lhs.cat::<$C>().unwrap().$method(rhs.str().unwrap()) + }) }, (String, Categorical(_, _) | Enum(_, _)) => { - return Ok(rhs - .categorical() - .unwrap() - .$method(lhs.str().unwrap())? - .with_name(lhs.name().clone())); + return with_match_categorical_physical_type!(rhs.dtype().cat_physical().unwrap(), |$C| { + // We use the reverse method as string <-> enum comparisons are only implemented one-way. + rhs.cat::<$C>().unwrap().$rev_method(lhs.str().unwrap()) + }) }, _ => (), }; @@ -227,22 +230,22 @@ impl ChunkCompareIneq<&Series> for Series { /// Create a boolean mask by checking if self > rhs. fn gt(&self, rhs: &Series) -> Self::Item { - impl_ineq_compare!(self, rhs, gt, ">") + impl_ineq_compare!(self, rhs, gt, ">", lt) } /// Create a boolean mask by checking if self >= rhs. fn gt_eq(&self, rhs: &Series) -> Self::Item { - impl_ineq_compare!(self, rhs, gt_eq, ">=") + impl_ineq_compare!(self, rhs, gt_eq, ">=", lt_eq) } /// Create a boolean mask by checking if self < rhs. fn lt(&self, rhs: &Series) -> Self::Item { - impl_ineq_compare!(self, rhs, lt, "<") + impl_ineq_compare!(self, rhs, lt, "<", gt) } /// Create a boolean mask by checking if self <= rhs. fn lt_eq(&self, rhs: &Series) -> Self::Item { - impl_ineq_compare!(self, rhs, lt_eq, "<=") + impl_ineq_compare!(self, rhs, lt_eq, "<=", gt_eq) } } @@ -316,9 +319,11 @@ impl ChunkCompareEq<&str> for Series { match self.dtype() { DataType::String => Ok(self.str().unwrap().equal(rhs)), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - self.categorical().unwrap().equal(rhs) - }, + DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok( + with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| { + self.cat::<$C>().unwrap().equal(rhs) + }), + ), _ => Ok(BooleanChunked::full(self.name().clone(), false, self.len())), } } @@ -328,9 +333,11 @@ impl ChunkCompareEq<&str> for Series { match self.dtype() { DataType::String => Ok(self.str().unwrap().equal_missing(rhs)), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - self.categorical().unwrap().equal_missing(rhs) - }, + DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok( + with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| { + self.cat::<$C>().unwrap().equal_missing(rhs) + }), + ), _ => Ok(replace_non_null( self.name().clone(), self.0.chunks(), @@ -344,9 +351,11 @@ impl ChunkCompareEq<&str> for Series { match self.dtype() { DataType::String => Ok(self.str().unwrap().not_equal(rhs)), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - self.categorical().unwrap().not_equal(rhs) - }, + DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok( + with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| { + self.cat::<$C>().unwrap().not_equal(rhs) + }), + ), _ => Ok(BooleanChunked::full(self.name().clone(), true, self.len())), } } @@ -356,9 +365,11 @@ impl ChunkCompareEq<&str> for Series { match self.dtype() { DataType::String => Ok(self.str().unwrap().not_equal_missing(rhs)), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - self.categorical().unwrap().not_equal_missing(rhs) - }, + DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok( + with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| { + self.cat::<$C>().unwrap().not_equal_missing(rhs) + }), + ), _ => Ok(replace_non_null(self.name().clone(), self.0.chunks(), true)), } } @@ -372,9 +383,11 @@ impl ChunkCompareIneq<&str> for Series { match self.dtype() { DataType::String => Ok(self.str().unwrap().gt(rhs)), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - self.categorical().unwrap().gt(rhs) - }, + DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok( + with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| { + self.cat::<$C>().unwrap().gt(rhs) + }), + ), _ => polars_bail!( ComputeError: "cannot compare str value to series of type {}", self.dtype(), ), @@ -386,9 +399,11 @@ impl ChunkCompareIneq<&str> for Series { match self.dtype() { DataType::String => Ok(self.str().unwrap().gt_eq(rhs)), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - self.categorical().unwrap().gt_eq(rhs) - }, + DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok( + with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| { + self.cat::<$C>().unwrap().gt_eq(rhs) + }), + ), _ => polars_bail!( ComputeError: "cannot compare str value to series of type {}", self.dtype(), ), @@ -400,9 +415,11 @@ impl ChunkCompareIneq<&str> for Series { match self.dtype() { DataType::String => Ok(self.str().unwrap().lt(rhs)), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - self.categorical().unwrap().lt(rhs) - }, + DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok( + with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| { + self.cat::<$C>().unwrap().lt(rhs) + }), + ), _ => polars_bail!( ComputeError: "cannot compare str value to series of type {}", self.dtype(), ), @@ -414,9 +431,11 @@ impl ChunkCompareIneq<&str> for Series { match self.dtype() { DataType::String => Ok(self.str().unwrap().lt_eq(rhs)), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - self.categorical().unwrap().lt_eq(rhs) - }, + DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok( + with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| { + self.cat::<$C>().unwrap().lt_eq(rhs) + }), + ), _ => polars_bail!( ComputeError: "cannot compare str value to series of type {}", self.dtype(), ), diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index 5b43a9fd7220..c224af065eca 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -98,20 +98,11 @@ impl Series { String => StringChunked::from_chunks(name, chunks).into_series(), Binary => BinaryChunked::from_chunks(name, chunks).into_series(), #[cfg(feature = "dtype-categorical")] - dt @ (Categorical(rev_map, ordering) | Enum(rev_map, ordering)) => { - let cats = UInt32Chunked::from_chunks(name, chunks); - let rev_map = rev_map.clone().unwrap_or_else(|| { - assert!(cats.is_empty()); - Arc::new(RevMapping::default()) - }); - let mut ca = CategoricalChunked::from_cats_and_rev_map_unchecked( - cats, - rev_map, - matches!(dt, Enum(_, _)), - *ordering, - ); - ca.set_fast_unique(false); - ca.into_series() + dt @ (Categorical(_, _) | Enum(_, _)) => { + with_match_categorical_physical_type!(dt.cat_physical().unwrap(), |$C| { + let phys = ChunkedArray::from_chunks(name, chunks); + CategoricalChunked::<$C>::from_cats_and_dtype_unchecked(phys, dt.clone()).into_series() + }) }, Boolean => BooleanChunked::from_chunks(name, chunks).into_series(), Float32 => Float32Chunked::from_chunks(name, chunks).into_series(), @@ -390,35 +381,71 @@ impl Series { panic!("activate dtype-categorical to convert dictionary arrays") }, #[cfg(feature = "dtype-categorical")] - ArrowDataType::Dictionary(key_type, value_type, _) => { - use arrow::datatypes::IntegerType; - // don't spuriously call this; triggers a read on mmapped data + ArrowDataType::Dictionary(key_type, _, _) => { + use arrow::datatypes::IntegerType as I; + + // Don't spuriously call this; triggers a read on mmapped data. let arr = if chunks.len() > 1 { concatenate_unchecked(&chunks)? } else { chunks[0].clone() }; - // If the value type is a string, they are converted to Categoricals or Enums + let polars_dtype = DataType::from_arrow(dtype, md); if matches!( - value_type.as_ref(), - ArrowDataType::Utf8 - | ArrowDataType::LargeUtf8 - | ArrowDataType::Utf8View - | ArrowDataType::Null + polars_dtype, + DataType::Categorical(_, _) | DataType::Enum(_, _) ) { - macro_rules! unpack_keys_values { + macro_rules! unpack_categorical_chunked { ($dt:ty) => {{ let arr = arr.as_any().downcast_ref::>().unwrap(); let keys = arr.keys(); - let keys = cast(keys, &ArrowDataType::UInt32).unwrap(); let values = arr.values(); let values = cast(&**values, &ArrowDataType::Utf8View)?; - (keys, values) + let values = values.as_any().downcast_ref::().unwrap(); + with_match_categorical_physical_type!( + polars_dtype.cat_physical().unwrap(), + |$C| { + let ca = CategoricalChunked::<$C>::from_str_iter( + name, + polars_dtype, + keys.iter().map(|k| { + let k: usize = (*k?).try_into().ok()?; + values.get(k) + }), + )?; + Ok(ca.into_series()) + } + ) + }}; + } + + match key_type { + I::Int8 => unpack_categorical_chunked!(i8), + I::UInt8 => unpack_categorical_chunked!(u8), + I::Int16 => unpack_categorical_chunked!(i16), + I::UInt16 => unpack_categorical_chunked!(u16), + I::Int32 => unpack_categorical_chunked!(i32), + I::UInt32 => unpack_categorical_chunked!(u32), + I::Int64 => unpack_categorical_chunked!(i64), + I::UInt64 => unpack_categorical_chunked!(u64), + _ => polars_bail!( + ComputeError: "unsupported arrow key type: {key_type:?}" + ), + } + } else { + macro_rules! unpack_keys_values { + ($dt:ty) => {{ + let arr = arr.as_any().downcast_ref::>().unwrap(); + let keys = arr.keys(); + let keys = polars_compute::cast::primitive_to_primitive::< + $dt, + ::Native, + >(keys, &IDX_DTYPE.to_arrow(CompatLevel::newest())); + (keys, arr.values()) }}; } - use IntegerType as I; let (keys, values) = match key_type { I::Int8 => unpack_keys_values!(i8), I::UInt8 => unpack_keys_values!(u8), @@ -427,83 +454,25 @@ impl Series { I::Int32 => unpack_keys_values!(i32), I::UInt32 => unpack_keys_values!(u32), I::Int64 => unpack_keys_values!(i64), + I::UInt64 => unpack_keys_values!(u64), _ => polars_bail!( - ComputeError: "dictionaries with unsigned 64-bit keys are not supported" + ComputeError: "unsupported arrow key type: {key_type:?}" ), }; - let keys = keys.as_any().downcast_ref::>().unwrap(); - let values = values.as_any().downcast_ref::().unwrap(); - - // Categoricals and Enums expect the RevMap values to not contain any nulls - let (keys, values) = - polars_compute::propagate_dictionary::propagate_dictionary_value_nulls( - keys, values, - ); - - let mut ordering = CategoricalOrdering::default(); - if let Some(metadata) = md { - if metadata.is_enum() { - // SAFETY: - // the invariants of an Arrow Dictionary guarantee the keys are in bounds - return Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( - UInt32Chunked::with_chunk(name, keys), - Arc::new(RevMapping::build_local(values)), - true, - CategoricalOrdering::Physical, // Enum always uses physical ordering - ) - .into_series()); - } else if let Some(o) = metadata.categorical() { - ordering = o; - } - } - - return Ok(CategoricalChunked::from_keys_and_values( - name, &keys, &values, ordering, - ) - .into_series()); - } - - macro_rules! unpack_keys_values { - ($dt:ty) => {{ - let arr = arr.as_any().downcast_ref::>().unwrap(); - let keys = arr.keys(); - let keys = polars_compute::cast::primitive_as_primitive::< - $dt, - ::Native, - >(keys, &IDX_DTYPE.to_arrow(CompatLevel::newest())); - (arr.values(), keys) - }}; + let values = Series::_try_from_arrow_unchecked_with_md( + name, + vec![values.clone()], + values.dtype(), + None, + )?; + + values.take(&IdxCa::from_chunks_and_dtype( + PlSmallStr::EMPTY, + vec![keys.to_boxed()], + IDX_DTYPE, + )) } - - use IntegerType as I; - let (values, keys) = match key_type { - I::Int8 => unpack_keys_values!(i8), - I::UInt8 => unpack_keys_values!(u8), - I::Int16 => unpack_keys_values!(i16), - I::UInt16 => unpack_keys_values!(u16), - I::Int32 => unpack_keys_values!(i32), - I::UInt32 => unpack_keys_values!(u32), - I::Int64 => unpack_keys_values!(i64), - _ => polars_bail!( - ComputeError: "dictionaries with unsigned 64-bit keys are not supported" - ), - }; - - // Convert the dictionary to a flat array - let values = Series::_try_from_arrow_unchecked_with_md( - name, - vec![values.clone()], - values.dtype(), - None, - )?; - let values = values.take_unchecked(&IdxCa::from_chunks_and_dtype( - PlSmallStr::EMPTY, - vec![keys.to_boxed()], - IDX_DTYPE, - )); - - Ok(values) }, #[cfg(feature = "object")] ArrowDataType::Extension(ext) diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index cbf004d9e22d..0c63c944c1f6 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -2,335 +2,319 @@ use super::*; use crate::chunked_array::comparison::*; use crate::prelude::*; -unsafe impl IntoSeries for CategoricalChunked { +unsafe impl IntoSeries for CategoricalChunked { fn into_series(self) -> Series { - Series(Arc::new(SeriesWrap(self))) + // We do this hack to go from generic T to concrete T to avoid adding bounds on IntoSeries. + with_match_categorical_physical_type!(T::physical(), |$C| { + unsafe { + Series(Arc::new(SeriesWrap(core::mem::transmute::>(self)))) + } + }) } } -impl SeriesWrap { - fn finish_with_state(&self, keep_fast_unique: bool, cats: UInt32Chunked) -> CategoricalChunked { - let mut out = unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - cats, - self.0.get_rev_map().clone(), - self.0.is_enum(), - self.0.get_ordering(), - ) - }; - if keep_fast_unique && self.0._can_fast_unique() { - out.set_fast_unique(true) - } - out - } - - fn with_state(&self, keep_fast_unique: bool, apply: F) -> CategoricalChunked +impl SeriesWrap> { + unsafe fn apply_on_phys(&self, apply: F) -> CategoricalChunked where - F: Fn(&UInt32Chunked) -> UInt32Chunked, + F: Fn(&ChunkedArray) -> ChunkedArray, { let cats = apply(self.0.physical()); - self.finish_with_state(keep_fast_unique, cats) + unsafe { CategoricalChunked::from_cats_and_dtype_unchecked(cats, self.0.dtype().clone()) } } - fn try_with_state<'a, F>( - &'a self, - keep_fast_unique: bool, - apply: F, - ) -> PolarsResult + unsafe fn try_apply_on_phys(&self, apply: F) -> PolarsResult> where - F: for<'b> Fn(&'a UInt32Chunked) -> PolarsResult, + F: Fn(&ChunkedArray) -> PolarsResult>, { let cats = apply(self.0.physical())?; - Ok(self.finish_with_state(keep_fast_unique, cats)) - } -} - -impl private::PrivateSeries for SeriesWrap { - fn compute_len(&mut self) { - self.0.physical_mut().compute_len() - } - fn _field(&self) -> Cow<'_, Field> { - Cow::Owned(self.0.field()) - } - fn _dtype(&self) -> &DataType { - self.0.dtype() - } - fn _get_flags(&self) -> StatisticsFlags { - self.0.get_flags() - } - fn _set_flags(&mut self, flags: StatisticsFlags) { - self.0.set_flags(flags) - } - - unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { - self.0.physical().equal_element(idx_self, idx_other, other) - } - - #[cfg(feature = "zip_with")] - fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { - self.0 - .zip_with(mask, other.categorical()?) - .map(|ca| ca.into_series()) - } - fn into_total_ord_inner<'a>(&'a self) -> Box { - if self.0.uses_lexical_ordering() { - (&self.0).into_total_ord_inner() - } else { - self.0.physical().into_total_ord_inner() - } - } - fn into_total_eq_inner<'a>(&'a self) -> Box { - invalid_operation_panic!(into_total_eq_inner, self) - } - - fn vec_hash( - &self, - random_state: PlSeedableRandomStateQuality, - buf: &mut Vec, - ) -> PolarsResult<()> { - self.0.physical().vec_hash(random_state, buf)?; - Ok(()) - } - - fn vec_hash_combine( - &self, - build_hasher: PlSeedableRandomStateQuality, - hashes: &mut [u64], - ) -> PolarsResult<()> { - self.0.physical().vec_hash_combine(build_hasher, hashes)?; - Ok(()) - } - - #[cfg(feature = "algorithm_group_by")] - unsafe fn agg_list(&self, groups: &GroupsType) -> Series { - // we cannot cast and dispatch as the inner type of the list would be incorrect - let list = self.0.physical().agg_list(groups); - let mut list = list.list().unwrap().clone(); - unsafe { list.to_logical(self.dtype().clone()) }; - list.into_series() - } - - #[cfg(feature = "algorithm_group_by")] - fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - #[cfg(feature = "performant")] - { - Ok(self.0.group_tuples_perfect(multithreaded, sorted)) - } - #[cfg(not(feature = "performant"))] - { - self.0.physical().group_tuples(multithreaded, sorted) + unsafe { + Ok(CategoricalChunked::from_cats_and_dtype_unchecked( + cats, + self.0.dtype().clone(), + )) } } - - fn arg_sort_multiple( - &self, - by: &[Column], - options: &SortMultipleOptions, - ) -> PolarsResult { - self.0.arg_sort_multiple(by, options) - } } -impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: PlSmallStr) { - self.0.physical_mut().rename(name); - } - - fn chunk_lengths(&self) -> ChunkLenIter<'_> { - self.0.physical().chunk_lengths() - } - fn name(&self) -> &PlSmallStr { - self.0.physical().name() - } - - fn chunks(&self) -> &Vec { - self.0.physical().chunks() - } - unsafe fn chunks_mut(&mut self) -> &mut Vec { - self.0.physical_mut().chunks_mut() - } - fn shrink_to_fit(&mut self) { - self.0.physical_mut().shrink_to_fit() - } - - fn slice(&self, offset: i64, length: usize) -> Series { - self.with_state(false, |cats| cats.slice(offset, length)) - .into_series() - } - fn split_at(&self, offset: i64) -> (Series, Series) { - let (a, b) = self.0.physical().split_at(offset); - let a = self.finish_with_state(false, a).into_series(); - let b = self.finish_with_state(false, b).into_series(); - (a, b) - } - - fn append(&mut self, other: &Series) -> PolarsResult<()> { - polars_ensure!(self.0.dtype() == other.dtype(), append); - self.0.append(other.categorical().unwrap()) - } - fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { - polars_ensure!(self.0.dtype() == other.dtype(), append); - let other = other - ._get_inner_mut() - .as_any_mut() - .downcast_mut::() - .unwrap(); - self.0.append_owned(other.take()) - } - - fn extend(&mut self, other: &Series) -> PolarsResult<()> { - polars_ensure!(self.0.dtype() == other.dtype(), extend); - let other_ca = other.categorical().unwrap(); - // Fast path for globals of the same source - let rev_map_self = self.0.get_rev_map(); - let rev_map_other = other_ca.get_rev_map(); - match (&**rev_map_self, &**rev_map_other) { - (RevMapping::Global(_, _, idl), RevMapping::Global(_, _, idr)) if idl == idr => { - let mut rev_map_merger = GlobalRevMapMerger::new(rev_map_self.clone()); - rev_map_merger.merge_map(rev_map_other)?; - self.0.physical_mut().extend(other_ca.physical())?; - // SAFETY: rev_maps are merged - unsafe { self.0.set_rev_map(rev_map_merger.finish(), false) }; +macro_rules! impl_cat_series { + ($ca: ident, $pdt:ty) => { + impl private::PrivateSeries for SeriesWrap<$ca> { + fn compute_len(&mut self) { + self.0.physical_mut().compute_len() + } + fn _field(&self) -> Cow<'_, Field> { + Cow::Owned(self.0.field()) + } + fn _dtype(&self) -> &DataType { + self.0.dtype() + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + self.0.physical().equal_element(idx_self, idx_other, other) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + polars_ensure!(self.dtype() == other.dtype(), SchemaMismatch: "expected '{}' found '{}'", self.dtype(), other.dtype()); + let other = other.to_physical_repr().into_owned(); + unsafe { + Ok(self.try_apply_on_phys(|ca| { + ca.zip_with(mask, other.as_ref().as_ref()) + })?.into_series()) + } + } + + fn into_total_ord_inner<'a>(&'a self) -> Box { + if self.0.uses_lexical_ordering() { + (&self.0).into_total_ord_inner() + } else { + self.0.physical().into_total_ord_inner() + } + } + fn into_total_eq_inner<'a>(&'a self) -> Box { + invalid_operation_panic!(into_total_eq_inner, self) + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.physical().vec_hash(random_state, buf)?; Ok(()) - }, - _ => self.0.append(other_ca), + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.physical().vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + // we cannot cast and dispatch as the inner type of the list would be incorrect + let list = self.0.physical().agg_list(groups); + let mut list = list.list().unwrap().clone(); + unsafe { list.to_logical(self.dtype().clone()) }; + list.into_series() + } + + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + self.0.physical().group_tuples(multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) + } } - } - - fn filter(&self, filter: &BooleanChunked) -> PolarsResult { - self.try_with_state(false, |cats| cats.filter(filter)) - .map(|ca| ca.into_series()) - } - - fn take(&self, indices: &IdxCa) -> PolarsResult { - self.try_with_state(false, |cats| cats.take(indices)) - .map(|ca| ca.into_series()) - } - - unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { - self.with_state(false, |cats| cats.take_unchecked(indices)) - .into_series() - } - - fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { - self.try_with_state(false, |cats| cats.take(indices)) - .map(|ca| ca.into_series()) - } - - unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { - self.with_state(false, |cats| cats.take_unchecked(indices)) - .into_series() - } - - fn len(&self) -> usize { - self.0.len() - } - - fn rechunk(&self) -> Series { - self.with_state(true, |ca| ca.rechunk().into_owned()) - .into_series() - } - - fn new_from_index(&self, index: usize, length: usize) -> Series { - self.with_state(false, |cats| cats.new_from_index(index, length)) - .into_series() - } - - fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { - self.0.cast_with_options(dtype, options) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> AnyValue<'_> { - self.0.get_any_value_unchecked(index) - } - - fn sort_with(&self, options: SortOptions) -> PolarsResult { - Ok(self.0.sort_with(options).into_series()) - } - - fn arg_sort(&self, options: SortOptions) -> IdxCa { - self.0.arg_sort(options) - } - - fn null_count(&self) -> usize { - self.0.physical().null_count() - } - - fn has_nulls(&self) -> bool { - self.0.physical().has_nulls() - } - - #[cfg(feature = "algorithm_group_by")] - fn unique(&self) -> PolarsResult { - self.0.unique().map(|ca| ca.into_series()) - } - - #[cfg(feature = "algorithm_group_by")] - fn n_unique(&self) -> PolarsResult { - self.0.n_unique() - } - - #[cfg(feature = "algorithm_group_by")] - fn arg_unique(&self) -> PolarsResult { - self.0.physical().arg_unique() - } - fn is_null(&self) -> BooleanChunked { - self.0.physical().is_null() - } - - fn is_not_null(&self) -> BooleanChunked { - self.0.physical().is_not_null() - } - - fn reverse(&self) -> Series { - self.with_state(true, |cats| cats.reverse()).into_series() - } - - fn as_single_ptr(&mut self) -> PolarsResult { - self.0.physical_mut().as_single_ptr() - } - - fn shift(&self, periods: i64) -> Series { - self.with_state(false, |ca| ca.shift(periods)).into_series() - } - - fn clone_inner(&self) -> Arc { - Arc::new(SeriesWrap(Clone::clone(&self.0))) - } - - fn min_reduce(&self) -> PolarsResult { - Ok(ChunkAggSeries::min_reduce(&self.0)) - } - - fn max_reduce(&self) -> PolarsResult { - Ok(ChunkAggSeries::max_reduce(&self.0)) - } - - fn find_validity_mismatch(&self, other: &Series, idxs: &mut Vec) { - self.0.physical().find_validity_mismatch(other, idxs) - } - - fn as_any(&self) -> &dyn Any { - &self.0 - } - - fn as_any_mut(&mut self) -> &mut dyn Any { - &mut self.0 - } - - fn as_phys_any(&self) -> &dyn Any { - self.0.physical() - } + impl SeriesTrait for SeriesWrap<$ca> { + fn rename(&mut self, name: PlSmallStr) { + self.0.physical_mut().rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter<'_> { + self.0.physical().chunk_lengths() + } + + fn name(&self) -> &PlSmallStr { + self.0.physical().name() + } + + fn chunks(&self) -> &Vec { + self.0.physical().chunks() + } + + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.physical_mut().chunks_mut() + } + + fn shrink_to_fit(&mut self) { + self.0.physical_mut().shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + unsafe { self.apply_on_phys(|cats| cats.slice(offset, length)).into_series() } + } + + fn split_at(&self, offset: i64) -> (Series, Series) { + unsafe { + let (a, b) = self.0.physical().split_at(offset); + let a = <$ca>::from_cats_and_dtype_unchecked(a, self.0.dtype().clone()).into_series(); + let b = <$ca>::from_cats_and_dtype_unchecked(b, self.0.dtype().clone()).into_series(); + (a, b) + } + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append(other.cat::<$pdt>().unwrap()) + } + + fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.physical_mut().append_owned(std::mem::take( + other + ._get_inner_mut() + .as_any_mut() + .downcast_mut::<$ca>() + .unwrap() + .physical_mut(), + )) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + self.0.extend(other.cat::<$pdt>().unwrap()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + unsafe { Ok(self.try_apply_on_phys(|cats| cats.filter(filter))?.into_series()) } + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + unsafe { Ok(self.try_apply_on_phys(|cats| cats.take(indices))?.into_series() ) } + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + unsafe { self.apply_on_phys(|cats| cats.take_unchecked(indices)).into_series() } + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + unsafe { Ok(self.try_apply_on_phys(|cats| cats.take(indices))?.into_series()) } + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + unsafe { self.apply_on_phys(|cats| cats.take_unchecked(indices)).into_series() } + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + unsafe { self.apply_on_phys(|cats| cats.rechunk().into_owned()).into_series() } + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + unsafe { self.apply_on_phys(|cats| cats.new_from_index(index, length)).into_series() } + } + + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue<'_> { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self.0.sort_with(options).into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.0.arg_sort(options) + } + + fn null_count(&self) -> usize { + self.0.physical().null_count() + } + + fn has_nulls(&self) -> bool { + self.0.physical().has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + unsafe { Ok(self.try_apply_on_phys(|cats| cats.unique())?.into_series()) } + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + self.0.physical().n_unique() + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + self.0.physical().arg_unique() + } + + fn is_null(&self) -> BooleanChunked { + self.0.physical().is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.physical().is_not_null() + } + + fn reverse(&self) -> Series { + unsafe { self.apply_on_phys(|cats| cats.reverse()).into_series() } + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.physical_mut().as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + unsafe { self.apply_on_phys(|ca| ca.shift(periods)).into_series() } + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn min_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::min_reduce(&self.0)) + } + + fn max_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::max_reduce(&self.0)) + } + + fn find_validity_mismatch(&self, other: &Series, idxs: &mut Vec) { + self.0.physical().find_validity_mismatch(other, idxs) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + self.0.physical() + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } + } - fn as_arc_any(self: Arc) -> Arc { - self as _ + impl private::PrivateSeriesNumeric for SeriesWrap<$ca> { + fn bit_repr(&self) -> Option { + Some(self.0.physical().to_bit_repr()) + } + } } } -impl private::PrivateSeriesNumeric for SeriesWrap { - fn bit_repr(&self) -> Option { - Some(BitRepr::Small(self.0.physical().clone())) - } -} +impl_cat_series!(Categorical8Chunked, Categorical8Type); +impl_cat_series!(Categorical16Chunked, Categorical16Type); +impl_cat_series!(Categorical32Chunked, Categorical32Type); diff --git a/crates/polars-core/src/series/implementations/date.rs b/crates/polars-core/src/series/implementations/date.rs index 87722c77968d..6ecf0adc89e3 100644 --- a/crates/polars-core/src/series/implementations/date.rs +++ b/crates/polars-core/src/series/implementations/date.rs @@ -20,7 +20,7 @@ unsafe impl IntoSeries for DateChunked { impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { - self.0.compute_len() + self.0.physical_mut().compute_len() } fn _field(&self) -> Cow<'_, Field> { @@ -32,17 +32,18 @@ impl private::PrivateSeries for SeriesWrap { } fn _get_flags(&self) -> StatisticsFlags { - self.0.get_flags() + self.0.physical().get_flags() } fn _set_flags(&mut self, flags: StatisticsFlags) { - self.0.set_flags(flags) + self.0.physical_mut().set_flags(flags) } #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { let other = other.to_physical_repr().into_owned(); self.0 + .physical() .zip_with(mask, other.as_ref().as_ref()) .map(|ca| ca.into_date().into_series()) } @@ -50,6 +51,7 @@ impl private::PrivateSeries for SeriesWrap { fn into_total_eq_inner<'a>(&'a self) -> Box { self.0.physical().into_total_eq_inner() } + fn into_total_ord_inner<'a>(&'a self) -> Box { self.0.physical().into_total_ord_inner() } @@ -59,7 +61,7 @@ impl private::PrivateSeries for SeriesWrap { random_state: PlSeedableRandomStateQuality, buf: &mut Vec, ) -> PolarsResult<()> { - self.0.vec_hash(random_state, buf)?; + self.0.physical().vec_hash(random_state, buf)?; Ok(()) } @@ -68,24 +70,25 @@ impl private::PrivateSeries for SeriesWrap { build_hasher: PlSeedableRandomStateQuality, hashes: &mut [u64], ) -> PolarsResult<()> { - self.0.vec_hash_combine(build_hasher, hashes)?; + self.0.physical().vec_hash_combine(build_hasher, hashes)?; Ok(()) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsType) -> Series { - self.0.agg_min(groups).into_date().into_series() + self.0.physical().agg_min(groups).into_date().into_series() } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsType) -> Series { - self.0.agg_max(groups).into_date().into_series() + self.0.physical().agg_max(groups).into_date().into_series() } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsType) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 + .physical() .agg_list(groups) .cast(&DataType::List(Box::new(self.dtype().clone()))) .unwrap() @@ -138,7 +141,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - self.0.group_tuples(multithreaded, sorted) + self.0.physical().group_tuples(multithreaded, sorted) } fn arg_sort_multiple( @@ -146,7 +149,7 @@ impl private::PrivateSeries for SeriesWrap { by: &[Column], options: &SortMultipleOptions, ) -> PolarsResult { - self.0.deref().arg_sort_multiple(by, options) + self.0.physical().arg_sort_multiple(by, options) } } @@ -156,52 +159,57 @@ impl SeriesTrait for SeriesWrap { } fn chunk_lengths(&self) -> ChunkLenIter<'_> { - self.0.chunk_lengths() + self.0.physical().chunk_lengths() } + fn name(&self) -> &PlSmallStr { self.0.name() } fn chunks(&self) -> &Vec { - self.0.chunks() + self.0.physical().chunks() } + unsafe fn chunks_mut(&mut self) -> &mut Vec { - self.0.chunks_mut() + self.0.physical_mut().chunks_mut() } fn shrink_to_fit(&mut self) { - self.0.shrink_to_fit() + self.0.physical_mut().shrink_to_fit() } fn slice(&self, offset: i64, length: usize) -> Series { - self.0.slice(offset, length).into_date().into_series() + self.0.slice(offset, length).into_series() } + fn split_at(&self, offset: i64) -> (Series, Series) { let (a, b) = self.0.split_at(offset); - (a.into_date().into_series(), b.into_date().into_series()) + (a.into_series(), b.into_series()) } fn _sum_as_f64(&self) -> f64 { - self.0._sum_as_f64() + self.0.physical()._sum_as_f64() } fn mean(&self) -> Option { - self.0.mean() + self.0.physical().mean() } fn median(&self) -> Option { - self.0.median() + self.0.physical().median() } fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); let mut other = other.to_physical_repr().into_owned(); self.0 + .physical_mut() .append_owned(std::mem::take(other._get_inner_mut().as_mut())) } + fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); - self.0.append_owned(std::mem::take( + self.0.physical_mut().append_owned(std::mem::take( &mut other ._get_inner_mut() .as_any_mut() @@ -210,6 +218,7 @@ impl SeriesTrait for SeriesWrap { .phys, )) } + fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); // 3 refs @@ -217,28 +226,41 @@ impl SeriesTrait for SeriesWrap { // ref SeriesTrait // ref ChunkedArray let other = other.to_physical_repr(); - self.0.extend(other.as_ref().as_ref().as_ref())?; + self.0 + .physical_mut() + .extend(other.as_ref().as_ref().as_ref())?; Ok(()) } fn filter(&self, filter: &BooleanChunked) -> PolarsResult { - self.0.filter(filter).map(|ca| ca.into_date().into_series()) + self.0 + .physical() + .filter(filter) + .map(|ca| ca.into_date().into_series()) } fn take(&self, indices: &IdxCa) -> PolarsResult { - Ok(self.0.take(indices)?.into_date().into_series()) + Ok(self.0.physical().take(indices)?.into_date().into_series()) } unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { - self.0.take_unchecked(indices).into_date().into_series() + self.0 + .physical() + .take_unchecked(indices) + .into_date() + .into_series() } fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { - Ok(self.0.take(indices)?.into_date().into_series()) + Ok(self.0.physical().take(indices)?.into_date().into_series()) } unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { - self.0.take_unchecked(indices).into_date().into_series() + self.0 + .physical() + .take_unchecked(indices) + .into_date() + .into_series() } fn len(&self) -> usize { @@ -246,11 +268,17 @@ impl SeriesTrait for SeriesWrap { } fn rechunk(&self) -> Series { - self.0.rechunk().into_owned().into_date().into_series() + self.0 + .physical() + .rechunk() + .into_owned() + .into_date() + .into_series() } fn new_from_index(&self, index: usize, length: usize) -> Series { self.0 + .physical() .new_from_index(index, length) .into_date() .into_series() @@ -269,7 +297,7 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "dtype-datetime")] DataType::Datetime(_, _) => { let mut out = self.0.cast_with_options(dtype, CastOptions::NonStrict)?; - out.set_sorted_flag(self.0.is_sorted_flag()); + out.set_sorted_flag(self.0.physical().is_sorted_flag()); Ok(out) }, _ => self.0.cast_with_options(dtype, cast_options), @@ -282,11 +310,16 @@ impl SeriesTrait for SeriesWrap { } fn sort_with(&self, options: SortOptions) -> PolarsResult { - Ok(self.0.sort_with(options).into_date().into_series()) + Ok(self + .0 + .physical() + .sort_with(options) + .into_date() + .into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { - self.0.arg_sort(options) + self.0.physical().arg_sort(options) } fn null_count(&self) -> usize { @@ -299,17 +332,20 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { - self.0.unique().map(|ca| ca.into_date().into_series()) + self.0 + .physical() + .unique() + .map(|ca| ca.into_date().into_series()) } #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { - self.0.n_unique() + self.0.physical().n_unique() } #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { - self.0.arg_unique() + self.0.physical().arg_unique() } fn is_null(&self) -> BooleanChunked { @@ -321,25 +357,25 @@ impl SeriesTrait for SeriesWrap { } fn reverse(&self) -> Series { - self.0.reverse().into_date().into_series() + self.0.physical().reverse().into_date().into_series() } fn as_single_ptr(&mut self) -> PolarsResult { - self.0.as_single_ptr() + self.0.physical_mut().as_single_ptr() } fn shift(&self, periods: i64) -> Series { - self.0.shift(periods).into_date().into_series() + self.0.physical().shift(periods).into_date().into_series() } fn max_reduce(&self) -> PolarsResult { - let sc = self.0.max_reduce(); + let sc = self.0.physical().max_reduce(); let av = sc.value().cast(self.dtype()).into_static(); Ok(Scalar::new(self.dtype().clone(), av)) } fn min_reduce(&self) -> PolarsResult { - let sc = self.0.min_reduce(); + let sc = self.0.physical().min_reduce(); let av = sc.value().cast(self.dtype()).into_static(); Ok(Scalar::new(self.dtype().clone(), av)) } @@ -360,7 +396,7 @@ impl SeriesTrait for SeriesWrap { } fn find_validity_mismatch(&self, other: &Series, idxs: &mut Vec) { - self.0.find_validity_mismatch(other, idxs) + self.0.physical().find_validity_mismatch(other, idxs) } fn as_any(&self) -> &dyn Any { @@ -382,6 +418,6 @@ impl SeriesTrait for SeriesWrap { impl private::PrivateSeriesNumeric for SeriesWrap { fn bit_repr(&self) -> Option { - Some(self.0.to_bit_repr()) + Some(self.0.physical().to_bit_repr()) } } diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index 8947d1b07102..e143053ffa8e 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -13,13 +13,13 @@ unsafe impl IntoSeries for DatetimeChunked { impl private::PrivateSeriesNumeric for SeriesWrap { fn bit_repr(&self) -> Option { - Some(self.0.to_bit_repr()) + Some(self.0.physical().to_bit_repr()) } } impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { - self.0.compute_len() + self.0.physical_mut().compute_len() } fn _field(&self) -> Cow<'_, Field> { Cow::Owned(self.0.field()) @@ -28,19 +28,22 @@ impl private::PrivateSeries for SeriesWrap { self.0.dtype() } fn _get_flags(&self) -> StatisticsFlags { - self.0.get_flags() + self.0.physical().get_flags() } fn _set_flags(&mut self, flags: StatisticsFlags) { - self.0.set_flags(flags) + self.0.physical_mut().set_flags(flags) } #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { let other = other.to_physical_repr().into_owned(); - self.0.zip_with(mask, other.as_ref().as_ref()).map(|ca| { - ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - }) + self.0 + .physical() + .zip_with(mask, other.as_ref().as_ref()) + .map(|ca| { + ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + }) } fn into_total_eq_inner<'a>(&'a self) -> Box { @@ -55,7 +58,7 @@ impl private::PrivateSeries for SeriesWrap { random_state: PlSeedableRandomStateQuality, buf: &mut Vec, ) -> PolarsResult<()> { - self.0.vec_hash(random_state, buf)?; + self.0.physical().vec_hash(random_state, buf)?; Ok(()) } @@ -64,13 +67,14 @@ impl private::PrivateSeries for SeriesWrap { build_hasher: PlSeedableRandomStateQuality, hashes: &mut [u64], ) -> PolarsResult<()> { - self.0.vec_hash_combine(build_hasher, hashes)?; + self.0.physical().vec_hash_combine(build_hasher, hashes)?; Ok(()) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0 + .physical() .agg_min(groups) .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() @@ -79,6 +83,7 @@ impl private::PrivateSeries for SeriesWrap { #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0 + .physical() .agg_max(groups) .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() @@ -87,6 +92,7 @@ impl private::PrivateSeries for SeriesWrap { unsafe fn agg_list(&self, groups: &GroupsType) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 + .physical() .agg_list(groups) .cast(&DataType::List(Box::new(self.dtype().clone()))) .unwrap() @@ -138,7 +144,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - self.0.group_tuples(multithreaded, sorted) + self.0.physical().group_tuples(multithreaded, sorted) } fn arg_sort_multiple( @@ -146,7 +152,7 @@ impl private::PrivateSeries for SeriesWrap { by: &[Column], options: &SortMultipleOptions, ) -> PolarsResult { - self.0.deref().arg_sort_multiple(by, options) + self.0.physical().arg_sort_multiple(by, options) } } @@ -156,60 +162,53 @@ impl SeriesTrait for SeriesWrap { } fn chunk_lengths(&self) -> ChunkLenIter<'_> { - self.0.chunk_lengths() + self.0.physical().chunk_lengths() } fn name(&self) -> &PlSmallStr { self.0.name() } fn chunks(&self) -> &Vec { - self.0.chunks() + self.0.physical().chunks() } unsafe fn chunks_mut(&mut self) -> &mut Vec { - self.0.chunks_mut() + self.0.physical_mut().chunks_mut() } fn shrink_to_fit(&mut self) { - self.0.shrink_to_fit() + self.0.physical_mut().shrink_to_fit() } fn slice(&self, offset: i64, length: usize) -> Series { - self.0 - .slice(offset, length) - .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() + self.0.slice(offset, length).into_series() } fn split_at(&self, offset: i64) -> (Series, Series) { let (a, b) = self.0.split_at(offset); - ( - a.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series(), - b.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series(), - ) + (a.into_series(), b.into_series()) } fn _sum_as_f64(&self) -> f64 { - self.0._sum_as_f64() + self.0.physical()._sum_as_f64() } fn mean(&self) -> Option { - self.0.mean() + self.0.physical().mean() } fn median(&self) -> Option { - self.0.median() + self.0.physical().median() } fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); let mut other = other.to_physical_repr().into_owned(); self.0 + .physical_mut() .append_owned(std::mem::take(other._get_inner_mut().as_mut())) } fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); - self.0.append_owned(std::mem::take( + self.0.physical_mut().append_owned(std::mem::take( &mut other ._get_inner_mut() .as_any_mut() @@ -222,39 +221,41 @@ impl SeriesTrait for SeriesWrap { fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); let other = other.to_physical_repr(); - self.0.extend(other.as_ref().as_ref().as_ref())?; + self.0 + .physical_mut() + .extend(other.as_ref().as_ref().as_ref())?; Ok(()) } fn filter(&self, filter: &BooleanChunked) -> PolarsResult { - self.0.filter(filter).map(|ca| { + self.0.physical().filter(filter).map(|ca| { ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() }) } fn take(&self, indices: &IdxCa) -> PolarsResult { - let ca = self.0.take(indices)?; + let ca = self.0.physical().take(indices)?; Ok(ca .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series()) } unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { - let ca = self.0.take_unchecked(indices); + let ca = self.0.physical().take_unchecked(indices); ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { - let ca = self.0.take(indices)?; + let ca = self.0.physical().take(indices)?; Ok(ca .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series()) } unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { - let ca = self.0.take_unchecked(indices); + let ca = self.0.physical().take_unchecked(indices); ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } @@ -265,6 +266,7 @@ impl SeriesTrait for SeriesWrap { fn rechunk(&self) -> Series { self.0 + .physical() .rechunk() .into_owned() .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) @@ -273,6 +275,7 @@ impl SeriesTrait for SeriesWrap { fn new_from_index(&self, index: usize, length: usize) -> Series { self.0 + .physical() .new_from_index(index, length) .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() @@ -293,13 +296,14 @@ impl SeriesTrait for SeriesWrap { fn sort_with(&self, options: SortOptions) -> PolarsResult { Ok(self .0 + .physical() .sort_with(options) .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { - self.0.arg_sort(options) + self.0.physical().arg_sort(options) } fn null_count(&self) -> usize { @@ -312,7 +316,7 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { - self.0.unique().map(|ca| { + self.0.physical().unique().map(|ca| { ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() }) @@ -320,12 +324,12 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { - self.0.n_unique() + self.0.physical().n_unique() } #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { - self.0.arg_unique() + self.0.physical().arg_unique() } fn is_null(&self) -> BooleanChunked { @@ -338,30 +342,32 @@ impl SeriesTrait for SeriesWrap { fn reverse(&self) -> Series { self.0 + .physical() .reverse() .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } fn as_single_ptr(&mut self) -> PolarsResult { - self.0.as_single_ptr() + self.0.physical_mut().as_single_ptr() } fn shift(&self, periods: i64) -> Series { self.0 + .physical() .shift(periods) .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } fn max_reduce(&self) -> PolarsResult { - let sc = self.0.max_reduce(); + let sc = self.0.physical().max_reduce(); Ok(Scalar::new(self.dtype().clone(), sc.value().clone())) } fn min_reduce(&self) -> PolarsResult { - let sc = self.0.min_reduce(); + let sc = self.0.physical().min_reduce(); Ok(Scalar::new(self.dtype().clone(), sc.value().clone())) } @@ -380,7 +386,7 @@ impl SeriesTrait for SeriesWrap { } fn find_validity_mismatch(&self, other: &Series, idxs: &mut Vec) { - self.0.find_validity_mismatch(other, idxs) + self.0.physical().find_validity_mismatch(other, idxs) } fn as_any(&self) -> &dyn Any { diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index b9aee41bb3ac..f052f9cbf3f5 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -17,13 +17,13 @@ impl private::PrivateSeriesNumeric for SeriesWrap { impl SeriesWrap { fn apply_physical_to_s Int128Chunked>(&self, f: F) -> Series { - f(&self.0) + f(self.0.physical()) .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series() } fn apply_physical T>(&self, f: F) -> T { - f(&self.0) + f(self.0.physical()) } fn scale_factor(&self) -> u128 { @@ -45,7 +45,7 @@ impl SeriesWrap { } fn agg_helper Series>(&self, f: F) -> Series { - let agg_s = f(&self.0); + let agg_s = f(self.0.physical()); match agg_s.dtype() { DataType::Int128 => { let ca = agg_s.i128().unwrap(); @@ -96,7 +96,7 @@ impl SeriesWrap { impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { - self.0.compute_len() + self.0.physical_mut().compute_len() } fn _field(&self) -> Cow<'_, Field> { @@ -107,10 +107,10 @@ impl private::PrivateSeries for SeriesWrap { self.0.dtype() } fn _get_flags(&self) -> StatisticsFlags { - self.0.get_flags() + self.0.physical().get_flags() } fn _set_flags(&mut self, flags: StatisticsFlags) { - self.0.set_flags(flags) + self.0.physical_mut().set_flags(flags) } #[cfg(feature = "zip_with")] @@ -125,10 +125,10 @@ impl private::PrivateSeries for SeriesWrap { .into_series()) } fn into_total_eq_inner<'a>(&'a self) -> Box { - (&self.0).into_total_eq_inner() + self.0.physical().into_total_eq_inner() } fn into_total_ord_inner<'a>(&'a self) -> Box { - (&self.0).into_total_ord_inner() + self.0.physical().into_total_ord_inner() } fn vec_hash( @@ -136,7 +136,7 @@ impl private::PrivateSeries for SeriesWrap { random_state: PlSeedableRandomStateQuality, buf: &mut Vec, ) -> PolarsResult<()> { - self.0.vec_hash(random_state, buf)?; + self.0.physical().vec_hash(random_state, buf)?; Ok(()) } @@ -145,7 +145,7 @@ impl private::PrivateSeries for SeriesWrap { build_hasher: PlSeedableRandomStateQuality, hashes: &mut [u64], ) -> PolarsResult<()> { - self.0.vec_hash_combine(build_hasher, hashes)?; + self.0.physical().vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -187,14 +187,14 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - self.0.group_tuples(multithreaded, sorted) + self.0.physical().group_tuples(multithreaded, sorted) } fn arg_sort_multiple( &self, by: &[Column], options: &SortMultipleOptions, ) -> PolarsResult { - self.0.arg_sort_multiple(by, options) + self.0.physical().arg_sort_multiple(by, options) } } @@ -204,7 +204,7 @@ impl SeriesTrait for SeriesWrap { } fn chunk_lengths(&self) -> ChunkLenIter<'_> { - self.0.chunk_lengths() + self.0.physical().chunk_lengths() } fn name(&self) -> &PlSmallStr { @@ -212,10 +212,10 @@ impl SeriesTrait for SeriesWrap { } fn chunks(&self) -> &Vec { - self.0.chunks() + self.0.physical().chunks() } unsafe fn chunks_mut(&mut self) -> &mut Vec { - self.0.chunks_mut() + self.0.physical_mut().chunks_mut() } fn slice(&self, offset: i64, length: usize) -> Series { @@ -224,24 +224,19 @@ impl SeriesTrait for SeriesWrap { fn split_at(&self, offset: i64) -> (Series, Series) { let (a, b) = self.0.split_at(offset); - let a = a - .into_decimal_unchecked(self.0.precision(), self.0.scale()) - .into_series(); - let b = b - .into_decimal_unchecked(self.0.precision(), self.0.scale()) - .into_series(); - (a, b) + (a.into_series(), b.into_series()) } fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); let mut other = other.to_physical_repr().into_owned(); self.0 + .physical_mut() .append_owned(std::mem::take(other._get_inner_mut().as_mut())) } fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); - self.0.append_owned(std::mem::take( + self.0.physical_mut().append_owned(std::mem::take( &mut other ._get_inner_mut() .as_any_mut() @@ -258,13 +253,16 @@ impl SeriesTrait for SeriesWrap { // ref SeriesTrait // ref ChunkedArray let other = other.to_physical_repr(); - self.0.extend(other.as_ref().as_ref().as_ref())?; + self.0 + .physical_mut() + .extend(other.as_ref().as_ref().as_ref())?; Ok(()) } fn filter(&self, filter: &BooleanChunked) -> PolarsResult { Ok(self .0 + .physical() .filter(filter)? .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series()) @@ -273,6 +271,7 @@ impl SeriesTrait for SeriesWrap { fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self .0 + .physical() .take(indices)? .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series()) @@ -280,6 +279,7 @@ impl SeriesTrait for SeriesWrap { unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { self.0 + .physical() .take_unchecked(indices) .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series() @@ -288,6 +288,7 @@ impl SeriesTrait for SeriesWrap { fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { Ok(self .0 + .physical() .take(indices)? .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series()) @@ -295,6 +296,7 @@ impl SeriesTrait for SeriesWrap { unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { self.0 + .physical() .take_unchecked(indices) .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series() @@ -305,13 +307,14 @@ impl SeriesTrait for SeriesWrap { } fn rechunk(&self) -> Series { - let ca = self.0.rechunk().into_owned(); + let ca = self.0.physical().rechunk().into_owned(); ca.into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series() } fn new_from_index(&self, index: usize, length: usize) -> Series { self.0 + .physical() .new_from_index(index, length) .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series() @@ -329,13 +332,14 @@ impl SeriesTrait for SeriesWrap { fn sort_with(&self, options: SortOptions) -> PolarsResult { Ok(self .0 + .physical() .sort_with(options) .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { - self.0.arg_sort(options) + self.0.physical().arg_sort(options) } fn null_count(&self) -> usize { @@ -353,12 +357,12 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { - self.0.n_unique() + self.0.physical().n_unique() } #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { - self.0.arg_unique() + self.0.physical().arg_unique() } fn is_null(&self) -> BooleanChunked { @@ -421,35 +425,45 @@ impl SeriesTrait for SeriesWrap { } fn _sum_as_f64(&self) -> f64 { - self.0._sum_as_f64() / self.scale_factor() as f64 + self.0.physical()._sum_as_f64() / self.scale_factor() as f64 } fn mean(&self) -> Option { - self.0.mean().map(|v| v / self.scale_factor() as f64) + self.0 + .physical() + .mean() + .map(|v| v / self.scale_factor() as f64) } fn median(&self) -> Option { - self.0.median().map(|v| v / self.scale_factor() as f64) + self.0 + .physical() + .median() + .map(|v| v / self.scale_factor() as f64) } fn median_reduce(&self) -> PolarsResult { - Ok(self.apply_scale(self.0.median_reduce())) + Ok(self.apply_scale(self.0.physical().median_reduce())) } fn std(&self, ddof: u8) -> Option { - self.0.std(ddof).map(|v| v / self.scale_factor() as f64) + self.0 + .physical() + .std(ddof) + .map(|v| v / self.scale_factor() as f64) } fn std_reduce(&self, ddof: u8) -> PolarsResult { - Ok(self.apply_scale(self.0.std_reduce(ddof))) + Ok(self.apply_scale(self.0.physical().std_reduce(ddof))) } fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { self.0 + .physical() .quantile_reduce(quantile, method) .map(|v| self.apply_scale(v)) } fn find_validity_mismatch(&self, other: &Series, idxs: &mut Vec) { - self.0.find_validity_mismatch(other, idxs) + self.0.physical().find_validity_mismatch(other, idxs) } fn as_any(&self) -> &dyn Any { diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 3d0f1811f2ea..47279bac62dd 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -1,5 +1,3 @@ -use std::ops::DerefMut; - use polars_compute::rolling::QuantileMethod; use super::*; @@ -16,13 +14,13 @@ unsafe impl IntoSeries for DurationChunked { impl private::PrivateSeriesNumeric for SeriesWrap { fn bit_repr(&self) -> Option { - Some(self.0.to_bit_repr()) + Some(self.0.physical().to_bit_repr()) } } impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { - self.0.compute_len() + self.0.physical_mut().compute_len() } fn _field(&self) -> Cow<'_, Field> { Cow::Owned(self.0.field()) @@ -32,20 +30,22 @@ impl private::PrivateSeries for SeriesWrap { } fn _set_flags(&mut self, flags: StatisticsFlags) { - self.0.deref_mut().set_flags(flags) + self.0.physical_mut().set_flags(flags) } + fn _get_flags(&self) -> StatisticsFlags { - self.0.deref().get_flags() + self.0.physical().get_flags() } unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { - self.0.equal_element(idx_self, idx_other, other) + self.0.physical().equal_element(idx_self, idx_other, other) } #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { let other = other.to_physical_repr().into_owned(); self.0 + .physical() .zip_with(mask, other.as_ref().as_ref()) .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) } @@ -62,7 +62,7 @@ impl private::PrivateSeries for SeriesWrap { random_state: PlSeedableRandomStateQuality, buf: &mut Vec, ) -> PolarsResult<()> { - self.0.vec_hash(random_state, buf)?; + self.0.physical().vec_hash(random_state, buf)?; Ok(()) } @@ -71,13 +71,14 @@ impl private::PrivateSeries for SeriesWrap { build_hasher: PlSeedableRandomStateQuality, hashes: &mut [u64], ) -> PolarsResult<()> { - self.0.vec_hash_combine(build_hasher, hashes)?; + self.0.physical().vec_hash_combine(build_hasher, hashes)?; Ok(()) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsType) -> Series { self.0 + .physical() .agg_min(groups) .into_duration(self.0.time_unit()) .into_series() @@ -86,6 +87,7 @@ impl private::PrivateSeries for SeriesWrap { #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsType) -> Series { self.0 + .physical() .agg_max(groups) .into_duration(self.0.time_unit()) .into_series() @@ -94,6 +96,7 @@ impl private::PrivateSeries for SeriesWrap { #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { self.0 + .physical() .agg_sum(groups) .into_duration(self.0.time_unit()) .into_series() @@ -102,6 +105,7 @@ impl private::PrivateSeries for SeriesWrap { #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series { self.0 + .physical() .agg_std(groups, ddof) // cast f64 back to physical type .cast(&DataType::Int64) @@ -113,6 +117,7 @@ impl private::PrivateSeries for SeriesWrap { #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series { self.0 + .physical() .agg_var(groups, ddof) // cast f64 back to physical type .cast(&DataType::Int64) @@ -125,6 +130,7 @@ impl private::PrivateSeries for SeriesWrap { unsafe fn agg_list(&self, groups: &GroupsType) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 + .physical() .agg_list(groups) .cast(&DataType::List(Box::new(self.dtype().clone()))) .unwrap() @@ -256,7 +262,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - self.0.group_tuples(multithreaded, sorted) + self.0.physical().group_tuples(multithreaded, sorted) } fn arg_sort_multiple( @@ -264,7 +270,7 @@ impl private::PrivateSeries for SeriesWrap { by: &[Column], options: &SortMultipleOptions, ) -> PolarsResult { - self.0.deref().arg_sort_multiple(by, options) + self.0.physical().arg_sort_multiple(by, options) } } @@ -274,66 +280,62 @@ impl SeriesTrait for SeriesWrap { } fn chunk_lengths(&self) -> ChunkLenIter<'_> { - self.0.chunk_lengths() + self.0.physical().chunk_lengths() } fn name(&self) -> &PlSmallStr { self.0.name() } fn chunks(&self) -> &Vec { - self.0.chunks() + self.0.physical().chunks() } unsafe fn chunks_mut(&mut self) -> &mut Vec { - self.0.chunks_mut() + self.0.physical_mut().chunks_mut() } fn shrink_to_fit(&mut self) { - self.0.shrink_to_fit() + self.0.physical_mut().shrink_to_fit() } fn slice(&self, offset: i64, length: usize) -> Series { - self.0 - .slice(offset, length) - .into_duration(self.0.time_unit()) - .into_series() + self.0.slice(offset, length).into_series() } fn split_at(&self, offset: i64) -> (Series, Series) { let (a, b) = self.0.split_at(offset); - let a = a.into_duration(self.0.time_unit()).into_series(); - let b = b.into_duration(self.0.time_unit()).into_series(); - (a, b) + (a.into_series(), b.into_series()) } fn _sum_as_f64(&self) -> f64 { - self.0._sum_as_f64() + self.0.physical()._sum_as_f64() } fn mean(&self) -> Option { - self.0.mean() + self.0.physical().mean() } fn median(&self) -> Option { - self.0.median() + self.0.physical().median() } fn std(&self, ddof: u8) -> Option { - self.0.std(ddof) + self.0.physical().std(ddof) } fn var(&self, ddof: u8) -> Option { - self.0.var(ddof) + self.0.physical().var(ddof) } fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); let mut other = other.to_physical_repr().into_owned(); self.0 + .physical_mut() .append_owned(std::mem::take(other._get_inner_mut().as_mut())) } fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); - self.0.append_owned(std::mem::take( + self.0.physical_mut().append_owned(std::mem::take( &mut other ._get_inner_mut() .as_any_mut() @@ -346,12 +348,15 @@ impl SeriesTrait for SeriesWrap { fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); let other = other.to_physical_repr(); - self.0.extend(other.as_ref().as_ref().as_ref())?; + self.0 + .physical_mut() + .extend(other.as_ref().as_ref().as_ref())?; Ok(()) } fn filter(&self, filter: &BooleanChunked) -> PolarsResult { self.0 + .physical() .filter(filter) .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) } @@ -359,6 +364,7 @@ impl SeriesTrait for SeriesWrap { fn take(&self, indices: &IdxCa) -> PolarsResult { Ok(self .0 + .physical() .take(indices)? .into_duration(self.0.time_unit()) .into_series()) @@ -366,6 +372,7 @@ impl SeriesTrait for SeriesWrap { unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { self.0 + .physical() .take_unchecked(indices) .into_duration(self.0.time_unit()) .into_series() @@ -374,6 +381,7 @@ impl SeriesTrait for SeriesWrap { fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { Ok(self .0 + .physical() .take(indices)? .into_duration(self.0.time_unit()) .into_series()) @@ -381,6 +389,7 @@ impl SeriesTrait for SeriesWrap { unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { self.0 + .physical() .take_unchecked(indices) .into_duration(self.0.time_unit()) .into_series() @@ -392,6 +401,7 @@ impl SeriesTrait for SeriesWrap { fn rechunk(&self) -> Series { self.0 + .physical() .rechunk() .into_owned() .into_duration(self.0.time_unit()) @@ -400,6 +410,7 @@ impl SeriesTrait for SeriesWrap { fn new_from_index(&self, index: usize, length: usize) -> Series { self.0 + .physical() .new_from_index(index, length) .into_duration(self.0.time_unit()) .into_series() @@ -417,13 +428,14 @@ impl SeriesTrait for SeriesWrap { fn sort_with(&self, options: SortOptions) -> PolarsResult { Ok(self .0 + .physical() .sort_with(options) .into_duration(self.0.time_unit()) .into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { - self.0.arg_sort(options) + self.0.physical().arg_sort(options) } fn null_count(&self) -> usize { @@ -437,18 +449,19 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { self.0 + .physical() .unique() .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) } #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { - self.0.n_unique() + self.0.physical().n_unique() } #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { - self.0.arg_unique() + self.0.physical().arg_unique() } fn is_null(&self) -> BooleanChunked { @@ -461,40 +474,42 @@ impl SeriesTrait for SeriesWrap { fn reverse(&self) -> Series { self.0 + .physical() .reverse() .into_duration(self.0.time_unit()) .into_series() } fn as_single_ptr(&mut self) -> PolarsResult { - self.0.as_single_ptr() + self.0.physical_mut().as_single_ptr() } fn shift(&self, periods: i64) -> Series { self.0 + .physical() .shift(periods) .into_duration(self.0.time_unit()) .into_series() } fn sum_reduce(&self) -> PolarsResult { - let sc = self.0.sum_reduce(); + let sc = self.0.physical().sum_reduce(); let v = sc.value().as_duration(self.0.time_unit()); Ok(Scalar::new(self.dtype().clone(), v)) } fn max_reduce(&self) -> PolarsResult { - let sc = self.0.max_reduce(); + let sc = self.0.physical().max_reduce(); let v = sc.value().as_duration(self.0.time_unit()); Ok(Scalar::new(self.dtype().clone(), v)) } fn min_reduce(&self) -> PolarsResult { - let sc = self.0.min_reduce(); + let sc = self.0.physical().min_reduce(); let v = sc.value().as_duration(self.0.time_unit()); Ok(Scalar::new(self.dtype().clone(), v)) } fn std_reduce(&self, ddof: u8) -> PolarsResult { - let sc = self.0.std_reduce(ddof); + let sc = self.0.physical().std_reduce(ddof); let to = self.dtype().to_physical(); let v = sc.value().cast(&to); Ok(Scalar::new( @@ -509,6 +524,7 @@ impl SeriesTrait for SeriesWrap { let sc = self .0 .cast_time_unit(TimeUnit::Milliseconds) + .physical() .var_reduce(ddof); let to = self.dtype().to_physical(); let v = sc.value().cast(&to); @@ -527,7 +543,7 @@ impl SeriesTrait for SeriesWrap { )) } fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { - let v = self.0.quantile_reduce(quantile, method)?; + let v = self.0.physical().quantile_reduce(quantile, method)?; let to = self.dtype().to_physical(); let v = v.value().cast(&to); Ok(Scalar::new( @@ -541,7 +557,7 @@ impl SeriesTrait for SeriesWrap { } fn find_validity_mismatch(&self, other: &Series, idxs: &mut Vec) { - self.0.find_validity_mismatch(other, idxs) + self.0.physical().find_validity_mismatch(other, idxs) } fn as_any(&self) -> &dyn Any { diff --git a/crates/polars-core/src/series/implementations/time.rs b/crates/polars-core/src/series/implementations/time.rs index 23dec06f1022..300de56cd594 100644 --- a/crates/polars-core/src/series/implementations/time.rs +++ b/crates/polars-core/src/series/implementations/time.rs @@ -20,7 +20,7 @@ unsafe impl IntoSeries for TimeChunked { impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { - self.0.compute_len() + self.0.physical_mut().compute_len() } fn _field(&self) -> Cow<'_, Field> { @@ -32,17 +32,18 @@ impl private::PrivateSeries for SeriesWrap { } fn _get_flags(&self) -> StatisticsFlags { - self.0.get_flags() + self.0.physical().get_flags() } fn _set_flags(&mut self, flags: StatisticsFlags) { - self.0.set_flags(flags) + self.0.physical_mut().set_flags(flags) } #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { let other = other.to_physical_repr().into_owned(); self.0 + .physical() .zip_with(mask, other.as_ref().as_ref()) .map(|ca| ca.into_time().into_series()) } @@ -59,7 +60,7 @@ impl private::PrivateSeries for SeriesWrap { random_state: PlSeedableRandomStateQuality, buf: &mut Vec, ) -> PolarsResult<()> { - self.0.vec_hash(random_state, buf)?; + self.0.physical().vec_hash(random_state, buf)?; Ok(()) } @@ -68,24 +69,25 @@ impl private::PrivateSeries for SeriesWrap { build_hasher: PlSeedableRandomStateQuality, hashes: &mut [u64], ) -> PolarsResult<()> { - self.0.vec_hash_combine(build_hasher, hashes)?; + self.0.physical().vec_hash_combine(build_hasher, hashes)?; Ok(()) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsType) -> Series { - self.0.agg_min(groups).into_time().into_series() + self.0.physical().agg_min(groups).into_time().into_series() } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsType) -> Series { - self.0.agg_max(groups).into_time().into_series() + self.0.physical().agg_max(groups).into_time().into_series() } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsType) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 + .physical() .agg_list(groups) .cast(&DataType::List(Box::new(self.dtype().clone()))) .unwrap() @@ -120,7 +122,7 @@ impl private::PrivateSeries for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - self.0.group_tuples(multithreaded, sorted) + self.0.physical().group_tuples(multithreaded, sorted) } fn arg_sort_multiple( @@ -128,7 +130,7 @@ impl private::PrivateSeries for SeriesWrap { by: &[Column], options: &SortMultipleOptions, ) -> PolarsResult { - self.0.deref().arg_sort_multiple(by, options) + self.0.physical().arg_sort_multiple(by, options) } } @@ -138,52 +140,55 @@ impl SeriesTrait for SeriesWrap { } fn chunk_lengths(&self) -> ChunkLenIter<'_> { - self.0.chunk_lengths() + self.0.physical().chunk_lengths() } fn name(&self) -> &PlSmallStr { self.0.name() } fn chunks(&self) -> &Vec { - self.0.chunks() + self.0.physical().chunks() } + unsafe fn chunks_mut(&mut self) -> &mut Vec { - self.0.chunks_mut() + self.0.physical_mut().chunks_mut() } fn shrink_to_fit(&mut self) { - self.0.shrink_to_fit() + self.0.physical_mut().shrink_to_fit() } fn slice(&self, offset: i64, length: usize) -> Series { - self.0.slice(offset, length).into_time().into_series() + self.0.slice(offset, length).into_series() } fn split_at(&self, offset: i64) -> (Series, Series) { let (a, b) = self.0.split_at(offset); - (a.into_time().into_series(), b.into_time().into_series()) + (a.into_series(), b.into_series()) } fn _sum_as_f64(&self) -> f64 { - self.0._sum_as_f64() + self.0.physical()._sum_as_f64() } fn mean(&self) -> Option { - self.0.mean() + self.0.physical().mean() } fn median(&self) -> Option { - self.0.median() + self.0.physical().median() } fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); let mut other = other.to_physical_repr().into_owned(); self.0 + .physical_mut() .append_owned(std::mem::take(other._get_inner_mut().as_mut())) } + fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); - self.0.append_owned(std::mem::take( + self.0.physical_mut().append_owned(std::mem::take( &mut other ._get_inner_mut() .as_any_mut() @@ -200,28 +205,41 @@ impl SeriesTrait for SeriesWrap { // ref SeriesTrait // ref ChunkedArray let other = other.to_physical_repr(); - self.0.extend(other.as_ref().as_ref().as_ref())?; + self.0 + .physical_mut() + .extend(other.as_ref().as_ref().as_ref())?; Ok(()) } fn filter(&self, filter: &BooleanChunked) -> PolarsResult { - self.0.filter(filter).map(|ca| ca.into_time().into_series()) + self.0 + .physical() + .filter(filter) + .map(|ca| ca.into_time().into_series()) } fn take(&self, indices: &IdxCa) -> PolarsResult { - Ok(self.0.take(indices)?.into_time().into_series()) + Ok(self.0.physical().take(indices)?.into_time().into_series()) } unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { - self.0.take_unchecked(indices).into_time().into_series() + self.0 + .physical() + .take_unchecked(indices) + .into_time() + .into_series() } fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { - Ok(self.0.take(indices)?.into_time().into_series()) + Ok(self.0.physical().take(indices)?.into_time().into_series()) } unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { - self.0.take_unchecked(indices).into_time().into_series() + self.0 + .physical() + .take_unchecked(indices) + .into_time() + .into_series() } fn len(&self) -> usize { @@ -229,11 +247,17 @@ impl SeriesTrait for SeriesWrap { } fn rechunk(&self) -> Series { - self.0.rechunk().into_owned().into_time().into_series() + self.0 + .physical() + .rechunk() + .into_owned() + .into_time() + .into_series() } fn new_from_index(&self, index: usize, length: usize) -> Series { self.0 + .physical() .new_from_index(index, length) .into_time() .into_series() @@ -259,11 +283,16 @@ impl SeriesTrait for SeriesWrap { } fn sort_with(&self, options: SortOptions) -> PolarsResult { - Ok(self.0.sort_with(options).into_time().into_series()) + Ok(self + .0 + .physical() + .sort_with(options) + .into_time() + .into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { - self.0.arg_sort(options) + self.0.physical().arg_sort(options) } fn null_count(&self) -> usize { @@ -276,17 +305,20 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { - self.0.unique().map(|ca| ca.into_time().into_series()) + self.0 + .physical() + .unique() + .map(|ca| ca.into_time().into_series()) } #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { - self.0.n_unique() + self.0.physical().n_unique() } #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { - self.0.arg_unique() + self.0.physical().arg_unique() } fn is_null(&self) -> BooleanChunked { @@ -298,25 +330,25 @@ impl SeriesTrait for SeriesWrap { } fn reverse(&self) -> Series { - self.0.reverse().into_time().into_series() + self.0.physical().reverse().into_time().into_series() } fn as_single_ptr(&mut self) -> PolarsResult { - self.0.as_single_ptr() + self.0.physical_mut().as_single_ptr() } fn shift(&self, periods: i64) -> Series { - self.0.shift(periods).into_time().into_series() + self.0.physical().shift(periods).into_time().into_series() } fn max_reduce(&self) -> PolarsResult { - let sc = self.0.max_reduce(); + let sc = self.0.physical().max_reduce(); let av = sc.value().cast(self.dtype()).into_static(); Ok(Scalar::new(self.dtype().clone(), av)) } fn min_reduce(&self) -> PolarsResult { - let sc = self.0.min_reduce(); + let sc = self.0.physical().min_reduce(); let av = sc.value().cast(self.dtype()).into_static(); Ok(Scalar::new(self.dtype().clone(), av)) } @@ -333,7 +365,7 @@ impl SeriesTrait for SeriesWrap { } fn find_validity_mismatch(&self, other: &Series, idxs: &mut Vec) { - self.0.find_validity_mismatch(other, idxs) + self.0.physical().find_validity_mismatch(other, idxs) } fn as_any(&self) -> &dyn Any { @@ -355,6 +387,6 @@ impl SeriesTrait for SeriesWrap { impl private::PrivateSeriesNumeric for SeriesWrap { fn bit_repr(&self) -> Option { - Some(self.0.to_bit_repr()) + Some(self.0.physical().to_bit_repr()) } } diff --git a/crates/polars-core/src/series/into.rs b/crates/polars-core/src/series/into.rs index c42b54e349f5..5780187e8676 100644 --- a/crates/polars-core/src/series/into.rs +++ b/crates/polars-core/src/series/into.rs @@ -18,7 +18,7 @@ impl Series { /// Convert a chunk in the Series to the correct Arrow type. /// This conversion is needed because polars doesn't use a - /// 1 on 1 mapping for logical/ categoricals, etc. + /// 1 on 1 mapping for logical/categoricals, etc. pub fn to_arrow(&self, chunk_idx: usize, compat_level: CompatLevel) -> ArrayRef { match self.dtype() { // make sure that we recursively apply all logical types. @@ -114,23 +114,16 @@ impl Series { Box::new(arr) }, #[cfg(feature = "dtype-categorical")] - dt @ (DataType::Categorical(_, ordering) | DataType::Enum(_, ordering)) => { - let ca = self.categorical().unwrap(); - let arr = ca.physical().chunks()[chunk_idx].clone(); - // SAFETY: categoricals are always u32's. - let cats = unsafe { UInt32Chunked::from_chunks(PlSmallStr::EMPTY, vec![arr]) }; - - // SAFETY: we only take a single chunk and change nothing about the index/rev_map mapping. - let new = unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - cats, - ca.get_rev_map().clone(), - matches!(dt, DataType::Enum(_, _)), - *ordering, - ) - }; - - new.to_arrow(compat_level, false) + dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => { + with_match_categorical_physical_type!(dt.cat_physical().unwrap(), |$C| { + let ca = self.cat::<$C>().unwrap(); + let arr = ca.physical().chunks()[chunk_idx].clone(); + unsafe { + let new_phys = ChunkedArray::from_chunks(PlSmallStr::EMPTY, vec![arr]); + let new = CategoricalChunked::<$C>::from_cats_and_dtype_unchecked(new_phys, dt.clone()); + new.to_arrow(compat_level).boxed() + } + }) }, #[cfg(feature = "dtype-date")] DataType::Date => cast( @@ -157,7 +150,7 @@ impl Series { ) .unwrap(), #[cfg(feature = "dtype-decimal")] - DataType::Decimal(_, _) => self.decimal().unwrap().chunks()[chunk_idx] + DataType::Decimal(_, _) => self.decimal().unwrap().physical().chunks()[chunk_idx] .as_any() .downcast_ref::>() .unwrap() diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 0490e8fa92a8..520fe772331c 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -392,10 +392,6 @@ impl Series { true }, dt if dt.is_primitive() && dt == slf.dtype() => true, - #[cfg(feature = "dtype-categorical")] - D::Enum(None, _) => { - polars_bail!(InvalidOperation: "cannot cast / initialize Enum without categories present"); - }, _ => false, }; @@ -514,36 +510,29 @@ impl Series { }, #[cfg(feature = "dtype-categorical")] - (D::UInt32, D::Categorical(revmap, ordering)) => match revmap { - Some(revmap) => Ok(unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - self.u32().unwrap().clone(), - revmap.clone(), - false, - *ordering, - ) - } - .into_series()), - // In the streaming engine this is `None` and the global string cache is turned on - // for the duration of the query. - None => Ok(unsafe { - CategoricalChunked::from_global_indices_unchecked( - self.u32().unwrap().clone(), - *ordering, + (phys, D::Categorical(cats, _)) if &cats.physical().dtype() == phys => { + with_match_categorical_physical_type!(cats.physical(), |$C| { + type CA = ChunkedArray<<$C as PolarsCategoricalType>::PolarsPhysical>; + let ca = self.as_ref().as_any().downcast_ref::().unwrap(); + Ok(CategoricalChunked::<$C>::from_cats_and_dtype_unchecked( + ca.clone(), + dtype.clone(), ) - .into_series() - }), + .into_series()) + }) }, #[cfg(feature = "dtype-categorical")] - (D::UInt32, D::Enum(revmap, ordering)) => Ok(unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - self.u32().unwrap().clone(), - revmap.as_ref().unwrap().clone(), - true, - *ordering, - ) - } - .into_series()), + (phys, D::Enum(fcats, _)) if &fcats.physical().dtype() == phys => { + with_match_categorical_physical_type!(fcats.physical(), |$C| { + type CA = ChunkedArray<<$C as PolarsCategoricalType>::PolarsPhysical>; + let ca = self.as_ref().as_any().downcast_ref::().unwrap(); + Ok(CategoricalChunked::<$C>::from_cats_and_dtype_unchecked( + ca.clone(), + dtype.clone(), + ) + .into_series()) + }) + }, (D::Int32, D::Date) => feature_gated!("dtype-time", Ok(self.clone().into_date())), (D::Int64, D::Datetime(tu, tz)) => feature_gated!( @@ -711,7 +700,7 @@ impl Series { /// * Duration -> Int64 /// * Decimal -> Int128 /// * Time -> Int64 - /// * Categorical -> UInt32 + /// * Categorical -> U8/U16/U32 /// * List(inner) -> List(physical of inner) /// * Array(inner) -> Array(physical of inner) /// * Struct -> Struct with physical repr of each struct column @@ -729,9 +718,11 @@ impl Series { #[cfg(feature = "dtype-time")] Time => Cow::Owned(self.time().unwrap().phys.clone().into_series()), #[cfg(feature = "dtype-categorical")] - Categorical(_, _) | Enum(_, _) => { - let ca = self.categorical().unwrap(); - Cow::Owned(ca.physical().clone().into_series()) + dt @ (Categorical(_, _) | Enum(_, _)) => { + with_match_categorical_physical_type!(dt.cat_physical().unwrap(), |$C| { + let ca = self.cat::<$C>().unwrap(); + Cow::Owned(ca.physical().clone().into_series()) + }) }, #[cfg(feature = "dtype-decimal")] Decimal(_, _) => Cow::Owned(self.decimal().unwrap().phys.clone().into_series()), @@ -847,7 +838,7 @@ impl Series { DataType::Time => self .time() .unwrap() - .as_ref() + .physical() .clone() .into_time() .into_series(), @@ -866,7 +857,7 @@ impl Series { DataType::Date => self .date() .unwrap() - .as_ref() + .physical() .clone() .into_date() .into_series(), @@ -892,7 +883,7 @@ impl Series { DataType::Datetime(_, _) => self .datetime() .unwrap() - .as_ref() + .physical() .clone() .into_datetime(timeunit, tz) .into_series(), @@ -917,7 +908,7 @@ impl Series { DataType::Duration(_) => self .duration() .unwrap() - .as_ref() + .physical() .clone() .into_duration(timeunit) .into_series(), @@ -991,13 +982,7 @@ impl Series { pub fn estimated_size(&self) -> usize { let mut size = 0; match self.dtype() { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(Some(rv), _) | DataType::Enum(Some(rv), _) => match &**rv { - RevMapping::Local(arr, _) => size += estimated_bytes_size(arr), - RevMapping::Global(map, arr, _) => { - size += map.capacity() * size_of::() * 2 + estimated_bytes_size(arr); - }, - }, + // TODO @ cat-rework: include mapping size here? #[cfg(feature = "object")] DataType::Object(_) => { let ArrowDataType::FixedSizeBinary(size) = self.chunks()[0].dtype() else { diff --git a/crates/polars-core/src/series/ops/downcast.rs b/crates/polars-core/src/series/ops/downcast.rs index d060833e0b29..e8f068b94987 100644 --- a/crates/polars-core/src/series/ops/downcast.rs +++ b/crates/polars-core/src/series/ops/downcast.rs @@ -9,9 +9,9 @@ macro_rules! unpack_chunked_err { } macro_rules! try_unpack_chunked { - ($series:expr, $expected:pat => $ca:ty) => { + ($series:expr, $expected:pat $(if $guard: expr)? => $ca:ty) => { match $series.dtype() { - $expected => { + $expected $(if $guard)? => { // Check downcast in debug compiles #[cfg(debug_assertions)] { @@ -159,10 +159,25 @@ impl Series { try_unpack_chunked!(self, DataType::Array(_, _) => ArrayChunked) } - /// Unpack to [`ChunkedArray`] of dtype [`DataType::Categorical`] #[cfg(feature = "dtype-categorical")] - pub fn try_categorical(&self) -> Option<&CategoricalChunked> { - try_unpack_chunked!(self, DataType::Categorical(_, _) | DataType::Enum(_, _) => CategoricalChunked) + pub fn try_cat(&self) -> Option<&CategoricalChunked> { + try_unpack_chunked!(self, dt @ DataType::Enum(_, _) | dt @ DataType::Categorical(_, _) if dt.cat_physical().unwrap() == T::physical() => CategoricalChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Categorical`] or [`DataType::Enum`] with a physical type of UInt8. + #[cfg(feature = "dtype-categorical")] + pub fn try_cat8(&self) -> Option<&Categorical8Chunked> { + self.try_cat::() + } + + #[cfg(feature = "dtype-categorical")] + pub fn try_cat16(&self) -> Option<&Categorical16Chunked> { + self.try_cat::() + } + + #[cfg(feature = "dtype-categorical")] + pub fn try_cat32(&self) -> Option<&Categorical32Chunked> { + self.try_cat::() } /// Unpack to [`ChunkedArray`] of dtype [`DataType::Struct`] @@ -335,13 +350,34 @@ impl Series { .ok_or_else(|| unpack_chunked_err!(self => "Array")) } - /// Unpack to [`ChunkedArray`] of dtype [`DataType::Categorical`] + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Categorical`] or [`DataType::Enum`]. #[cfg(feature = "dtype-categorical")] - pub fn categorical(&self) -> PolarsResult<&CategoricalChunked> { - self.try_categorical() + pub fn cat(&self) -> PolarsResult<&CategoricalChunked> { + self.try_cat::() .ok_or_else(|| unpack_chunked_err!(self => "Enum | Categorical")) } + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Categorical`] or [`DataType::Enum`] with a physical type of UInt8. + #[cfg(feature = "dtype-categorical")] + pub fn cat8(&self) -> PolarsResult<&CategoricalChunked> { + self.try_cat8() + .ok_or_else(|| unpack_chunked_err!(self => "Enum8 | Categorical8")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Categorical`] or [`DataType::Enum`] with a physical type of UInt16. + #[cfg(feature = "dtype-categorical")] + pub fn cat16(&self) -> PolarsResult<&CategoricalChunked> { + self.try_cat16() + .ok_or_else(|| unpack_chunked_err!(self => "Enum16 | Categorical16")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Categorical`] or [`DataType::Enum`] with a physical type of UInt32. + #[cfg(feature = "dtype-categorical")] + pub fn cat32(&self) -> PolarsResult<&CategoricalChunked> { + self.try_cat32() + .ok_or_else(|| unpack_chunked_err!(self => "Enum32 | Categorical32")) + } + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Struct`] #[cfg(feature = "dtype-struct")] pub fn struct_(&self) -> PolarsResult<&StructChunked> { diff --git a/crates/polars-core/src/series/ops/null.rs b/crates/polars-core/src/series/ops/null.rs index 7f17e2120e4d..fed384336bdd 100644 --- a/crates/polars-core/src/series/ops/null.rs +++ b/crates/polars-core/src/series/ops/null.rs @@ -18,18 +18,15 @@ impl Series { ArrayChunked::full_null_with_dtype(name, size, inner_dtype, *width).into_series() }, #[cfg(feature = "dtype-categorical")] - dt @ (DataType::Categorical(rev_map, ord) | DataType::Enum(rev_map, ord)) => { - let mut ca = CategoricalChunked::full_null( - name, - matches!(dt, DataType::Enum(_, _)), - size, - *ord, - ); - // ensure we keep the rev-map of a cleared series - if let Some(rev_map) = rev_map { - unsafe { ca.set_rev_map(rev_map.clone(), false) } - } - ca.into_series() + dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => { + with_match_categorical_physical_type!(dt.cat_physical().unwrap(), |$C| { + CategoricalChunked::<$C>::full_null_with_dtype( + name, + size, + dtype.clone() + ) + .into_series() + }) }, #[cfg(feature = "dtype-date")] DataType::Date => Int32Chunked::full_null(name, size) diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index 9aa683f303de..2ced9e6b2ee6 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -571,6 +571,18 @@ macro_rules! with_match_physical_integer_polars_type {( } })} +#[macro_export] +macro_rules! with_match_categorical_physical_type {( + $dtype:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + match $dtype { + CategoricalPhysical::U8 => __with_ty__! { Categorical8Type }, + CategoricalPhysical::U16 => __with_ty__! { Categorical16Type }, + CategoricalPhysical::U32 => __with_ty__! { Categorical32Type }, + } +})} + /// Apply a macro on the Downcasted ChunkedArrays of DataTypes that are logical numerics. /// So no logical. #[macro_export] diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs index 0e1d2859aa1c..c87b146874bd 100644 --- a/crates/polars-core/src/utils/supertype.rs +++ b/crates/polars-core/src/utils/supertype.rs @@ -413,10 +413,7 @@ pub fn get_supertype_with_options( UnknownKind::Str if dt.is_string() | dt.is_enum() => Some(dt.clone()), // Materialize str #[cfg(feature = "dtype-categorical")] - UnknownKind::Str if dt.is_categorical() => { - let Categorical(_, ord) = dt else { unreachable!()}; - Some(Categorical(None, *ord)) - }, + UnknownKind::Str if dt.is_categorical() => Some(dt.clone()), // Keep unknown dynam if dt.is_null() => Some(Unknown(*dynam)), // Find integers sizes @@ -612,42 +609,5 @@ pub fn merge_dtypes_many + Clone, D: AsRef>( st = try_get_supertype(d.as_ref(), &st)?; } - match st { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(Some(_), ordering) => { - // This merges the global rev maps with linear complexity. - // If we do a binary reduce, it would be quadratic. - let mut iter = into_iter.into_iter(); - let first_dt = iter.next().unwrap(); - let first_dt = first_dt.as_ref(); - let DataType::Categorical(Some(rm), _) = first_dt else { - unreachable!() - }; - polars_ensure!(matches!(rm.as_ref(), RevMapping::Global(_, _, _)), ComputeError: "global string cache must be set to merge categorical columns"); - - let mut merger = GlobalRevMapMerger::new(rm.clone()); - - for d in iter { - if let DataType::Categorical(Some(rm), _) = d.as_ref() { - merger.merge_map(rm)? - } - } - let rev_map = merger.finish(); - - Ok(DataType::Categorical(Some(rev_map), ordering)) - }, - // This would be quadratic if we do this with the binary `merge_dtypes`. - DataType::List(inner) if inner.contains_categoricals() => { - polars_bail!(ComputeError: "merging nested categoricals not yet supported") - }, - #[cfg(feature = "dtype-array")] - DataType::Array(inner, _) if inner.contains_categoricals() => { - polars_bail!(ComputeError: "merging nested categoricals not yet supported") - }, - #[cfg(feature = "dtype-struct")] - DataType::Struct(fields) if fields.iter().any(|f| f.dtype().contains_categoricals()) => { - polars_bail!(ComputeError: "merging nested categoricals not yet supported") - }, - _ => Ok(st), - } + Ok(st) } diff --git a/crates/polars-dtype/Cargo.toml b/crates/polars-dtype/Cargo.toml new file mode 100644 index 000000000000..318a6cd8fa71 --- /dev/null +++ b/crates/polars-dtype/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "polars-dtype" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +description = "Low-level datatype definitions of the Polars project." + +[dependencies] +arrow = { workspace = true } +boxcar = { workspace = true } +hashbrown = { workspace = true } +polars-error = { workspace = true } +polars-utils = { workspace = true } +schemars = { workspace = true, optional = true } +serde = { workspace = true, optional = true } + +[features] +default = [] +serde = ["dep:serde", "arrow/serde", "polars-utils/serde"] +dsl-schema = ["dep:schemars", "arrow/dsl-schema", "polars-utils/dsl-schema"] + +[lints] +workspace = true diff --git a/crates/polars-dtype/LICENSE b/crates/polars-dtype/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-dtype/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-dtype/README.md b/crates/polars-dtype/README.md new file mode 100644 index 000000000000..e2119224985b --- /dev/null +++ b/crates/polars-dtype/README.md @@ -0,0 +1,9 @@ +# polars-dtype + +Low-level datatype definitions of the Polars project. + +`polars-dtype` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) +library. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-dtype/src/categorical/catsize.rs b/crates/polars-dtype/src/categorical/catsize.rs new file mode 100644 index 000000000000..8ac8e5a309b6 --- /dev/null +++ b/crates/polars-dtype/src/categorical/catsize.rs @@ -0,0 +1,52 @@ +pub type CatSize = u32; + +pub trait CatNative { + fn as_cat(&self) -> CatSize; + fn from_cat(cat: CatSize) -> Self; +} + +impl CatNative for u8 { + fn as_cat(&self) -> CatSize { + *self as CatSize + } + + fn from_cat(cat: CatSize) -> Self { + #[cfg(debug_assertions)] + { + cat.try_into().unwrap() + } + + #[cfg(not(debug_assertions))] + { + cat as Self + } + } +} + +impl CatNative for u16 { + fn as_cat(&self) -> CatSize { + *self as CatSize + } + + fn from_cat(cat: CatSize) -> Self { + #[cfg(debug_assertions)] + { + cat.try_into().unwrap() + } + + #[cfg(not(debug_assertions))] + { + cat as Self + } + } +} + +impl CatNative for u32 { + fn as_cat(&self) -> CatSize { + *self + } + + fn from_cat(cat: CatSize) -> Self { + cat + } +} diff --git a/crates/polars-core/src/datatypes/categories/mapping.rs b/crates/polars-dtype/src/categorical/mapping.rs similarity index 65% rename from crates/polars-core/src/datatypes/categories/mapping.rs rename to crates/polars-dtype/src/categorical/mapping.rs index 394a55366f47..5b113ef964e1 100644 --- a/crates/polars-core/src/datatypes/categories/mapping.rs +++ b/crates/polars-dtype/src/categorical/mapping.rs @@ -1,12 +1,18 @@ +use std::fmt; use std::hash::BuildHasher; use std::sync::atomic::{AtomicUsize, Ordering}; +use arrow::array::builder::StaticArrayBuilder; +use arrow::array::{Array, MutableUtf8Array, Utf8Array, Utf8ViewArrayBuilder}; +use arrow::datatypes::ArrowDataType; use polars_error::{PolarsResult, polars_bail}; use polars_utils::aliases::PlSeedableRandomStateQuality; use polars_utils::parma::raw::RawTable; +use super::CatSize; + pub struct CategoricalMapping { - str_to_cat: RawTable, + str_to_cat: RawTable, cat_to_str: boxcar::Vec<&'static str>, max_categories: usize, upper_bound: AtomicUsize, @@ -33,29 +39,34 @@ impl CategoricalMapping { &self.hasher } + pub fn set_max_categories(&mut self, max_categories: usize) { + assert!(max_categories >= self.num_cats_upper_bound()); + self.max_categories = max_categories + } + /// Try to convert a string to a categorical id, but don't insert it if it is missing. #[inline(always)] - pub fn get_cat(&self, s: &str) -> Option { + pub fn get_cat(&self, s: &str) -> Option { let hash = self.hasher.hash_one(s); self.get_cat_with_hash(s, hash) } /// Same as get_cat, but with the hash pre-computed. #[inline(always)] - pub fn get_cat_with_hash(&self, s: &str, hash: u64) -> Option { + pub fn get_cat_with_hash(&self, s: &str, hash: u64) -> Option { self.str_to_cat.get(hash, |k| k == s).copied() } /// Convert a string to a categorical id. #[inline(always)] - pub fn insert_cat(&self, s: &str) -> PolarsResult { + pub fn insert_cat(&self, s: &str) -> PolarsResult { let hash = self.hasher.hash_one(s); self.insert_cat_with_hash(s, hash) } /// Same as to_cat, but with the hash pre-computed. #[inline(always)] - pub fn insert_cat_with_hash(&self, s: &str, hash: u64) -> PolarsResult { + pub fn insert_cat_with_hash(&self, s: &str, hash: u64) -> PolarsResult { self.str_to_cat .try_get_or_insert_with( hash, @@ -63,14 +74,14 @@ impl CategoricalMapping { |k| k == s, |k| { let old_upper_bound = self.upper_bound.fetch_add(1, Ordering::Relaxed); - if old_upper_bound + 1 >= self.max_categories { + if old_upper_bound + 1 > self.max_categories { self.upper_bound.fetch_sub(1, Ordering::Relaxed); polars_bail!(ComputeError: "attempted to insert more categories than the maximum allowed"); } let idx = self .cat_to_str .push(unsafe { core::mem::transmute::<&str, &'static str>(k) }); - Ok(idx as u32) + Ok(idx as CatSize) }, ) .copied() @@ -79,7 +90,7 @@ impl CategoricalMapping { /// Try to convert a categorical id to its corresponding string, returning /// None if the string is not in the data structure. #[inline(always)] - pub fn cat_to_str(&self, cat: u32) -> Option<&str> { + pub fn cat_to_str(&self, cat: CatSize) -> Option<&str> { self.cat_to_str.get(cat as usize).copied() } @@ -89,7 +100,7 @@ impl CategoricalMapping { /// The categorical id must have been returned from `to_cat`, and you must /// have synchronized with the call which inserted it. #[inline(always)] - pub unsafe fn cat_to_str_unchecked(&self, cat: u32) -> &str { + pub unsafe fn cat_to_str_unchecked(&self, cat: CatSize) -> &str { unsafe { self.cat_to_str.get_unchecked(cat as usize) } } @@ -115,4 +126,35 @@ impl CategoricalMapping { pub fn is_empty(&mut self) -> bool { self.len() == 0 } + + pub fn to_arrow(&self, as_views: bool) -> Box { + let n = self.num_cats_upper_bound(); + if as_views { + let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8View); + builder.reserve(n); + for i in 0..n { + let s = self.cat_to_str(i as CatSize).unwrap_or_default(); + builder.push_value_ignore_validity(s); + } + builder.freeze().boxed() + } else { + let mut builder = MutableUtf8Array::new(); + builder.reserve(n, 0); + for i in 0..n { + let s = self.cat_to_str(i as CatSize).unwrap_or_default(); + builder.push(Some(s)); + } + let arr: Utf8Array = builder.into(); + arr.boxed() + } + } +} + +impl fmt::Debug for CategoricalMapping { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CategoricalMapping") + .field("max_categories", &self.max_categories) + .field("upper_bound", &self.upper_bound.load(Ordering::Relaxed)) + .finish() + } } diff --git a/crates/polars-dtype/src/categorical/mod.rs b/crates/polars-dtype/src/categorical/mod.rs new file mode 100644 index 000000000000..e836568b34be --- /dev/null +++ b/crates/polars-dtype/src/categorical/mod.rs @@ -0,0 +1,341 @@ +use std::fmt; +use std::hash::{BuildHasher, Hasher}; +use std::str::FromStr; +use std::sync::{Arc, LazyLock, Mutex, Weak}; + +use arrow::array::builder::StaticArrayBuilder; +use arrow::array::{Utf8ViewArray, Utf8ViewArrayBuilder}; +use arrow::datatypes::ArrowDataType; +use hashbrown::HashTable; +use hashbrown::hash_table::Entry; +use polars_error::{PolarsResult, polars_bail, polars_ensure}; +use polars_utils::aliases::*; +use polars_utils::pl_str::PlSmallStr; + +mod catsize; +mod mapping; + +pub use catsize::{CatNative, CatSize}; +pub use mapping::CategoricalMapping; + +/// The physical datatype backing a categorical / enum. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))] +pub enum CategoricalPhysical { + U8, + U16, + U32, +} + +impl CategoricalPhysical { + pub fn max_categories(&self) -> usize { + // We might use T::MAX as an indicator, so the maximum number of categories is T::MAX + // (giving T::MAX - 1 as the largest category). + match self { + Self::U8 => u8::MAX as usize, + Self::U16 => u16::MAX as usize, + Self::U32 => u32::MAX as usize, + } + } + + pub fn smallest_physical(num_cats: usize) -> PolarsResult { + if num_cats < u8::MAX as usize { + Ok(Self::U8) + } else if num_cats < u16::MAX as usize { + Ok(Self::U16) + } else if num_cats < u32::MAX as usize { + Ok(Self::U32) + } else { + polars_bail!(ComputeError: "attempted to insert more categories than the maximum allowed") + } + } + + pub fn as_str(&self) -> &'static str { + match self { + Self::U8 => "u8", + Self::U16 => "u16", + Self::U32 => "u32", + } + } +} + +impl FromStr for CategoricalPhysical { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "u8" => Ok(Self::U8), + "u16" => Ok(Self::U16), + "u32" => Ok(Self::U32), + _ => Err(()), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct CategoricalId { + name: PlSmallStr, + namespace: PlSmallStr, + physical: CategoricalPhysical, +} + +impl CategoricalId { + fn global() -> Self { + Self { + name: PlSmallStr::from_static("__POLARS_GLOBAL_CATEGORIES"), + namespace: PlSmallStr::from_static(""), + physical: CategoricalPhysical::U32, + } + } +} + +// Used to maintain a 1:1 mapping between Categories' ID and the Categories objects themselves. +// This is important for serialization. +static CATEGORIES_REGISTRY: LazyLock>>> = + LazyLock::new(|| Mutex::new(PlHashMap::new())); + +// Used to make FrozenCategories unique based on their content. This allows comparison of datatypes +// in constant time by comparing pointers. +#[expect(clippy::type_complexity)] +static FROZEN_CATEGORIES_REGISTRY: LazyLock)>>> = + LazyLock::new(|| Mutex::new(HashTable::new())); + +static FROZEN_CATEGORIES_HASHER: LazyLock = + LazyLock::new(PlSeedableRandomStateQuality::random); + +static GLOBAL_CATEGORIES: LazyLock> = LazyLock::new(|| { + let categories = Arc::new(Categories { + id: CategoricalId::global(), + mapping: Mutex::new(Weak::new()), + }); + CATEGORIES_REGISTRY + .lock() + .unwrap() + .insert(CategoricalId::global(), Arc::downgrade(&categories)); + categories +}); + +/// A (named) object which is used to indicate which categorical data types have the same mapping. +pub struct Categories { + id: CategoricalId, + mapping: Mutex>, +} + +impl Categories { + /// Creates a new Categories object with the given name, namespace and physical type if none exists, otherwise + /// get a reference to an existing object with the same name, namespace and physical type. + pub fn new( + name: PlSmallStr, + namespace: PlSmallStr, + physical: CategoricalPhysical, + ) -> Arc { + let id = CategoricalId { + name, + namespace, + physical, + }; + let mut registry = CATEGORIES_REGISTRY.lock().unwrap(); + if let Some(cats_ref) = registry.get(&id) { + if let Some(cats) = cats_ref.upgrade() { + return cats; + } + } + let mapping = Mutex::new(Weak::new()); + let slf = Arc::new(Self { + id: id.clone(), + mapping, + }); + registry.insert(id, Arc::downgrade(&slf)); + slf + } + + /// Returns the global Categories. + pub fn global() -> Arc { + GLOBAL_CATEGORIES.clone() + } + + /// The name of this Categories object. + pub fn name(&self) -> &PlSmallStr { + &self.id.name + } + + /// The namespace of this Categories object. + pub fn namespace(&self) -> &PlSmallStr { + &self.id.namespace + } + + /// The physical dtype of the category ids. + pub fn physical(&self) -> CategoricalPhysical { + self.id.physical + } + + /// The mapping for this Categories object. If no mapping currently exists + /// it creates a new empty mapping. + pub fn mapping(&self) -> Arc { + let mut guard = self.mapping.lock().unwrap(); + if let Some(arc) = guard.upgrade() { + return arc; + } + let arc = Arc::new(CategoricalMapping::new(self.id.physical.max_categories())); + *guard = Arc::downgrade(&arc); + arc + } + + pub fn freeze(&self) -> Arc { + let mapping = self.mapping(); + let n = mapping.num_cats_upper_bound(); + FrozenCategories::new((0..n).flat_map(|i| mapping.cat_to_str(i as CatSize))).unwrap() + } +} + +impl fmt::Debug for Categories { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Categories") + .field("name", &self.id.name) + .field("namespace", &self.id.namespace) + .field("physical", &self.id.physical) + .finish() + } +} + +impl Drop for Categories { + fn drop(&mut self) { + CATEGORIES_REGISTRY.lock().unwrap().remove(&self.id); + } +} + +/// An ordered collection of unique strings with an associated pre-computed +/// mapping to go from string <-> index. +/// +/// FrozenCategories are globally unique to facilitate constant-time comparison. +pub struct FrozenCategories { + physical: CategoricalPhysical, + combined_hash: u64, + categories: Utf8ViewArray, + mapping: Arc, +} + +impl FrozenCategories { + /// Creates a new FrozenCategories object (or returns a reference to an existing one + /// in case these are already known). Returns an error if the categories are not unique. + /// It is guaranteed that the nth string ends up with category n (0-indexed). + pub fn new<'a, I: IntoIterator>(strings: I) -> PolarsResult> { + let strings = strings.into_iter(); + let hasher = *FROZEN_CATEGORIES_HASHER; + let mut mapping = CategoricalMapping::with_hasher(usize::MAX, hasher); + let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8); + builder.reserve(strings.size_hint().0); + + let mut combined_hasher = PlFixedStateQuality::default().build_hasher(); + for s in strings { + combined_hasher.write(s.as_bytes()); + mapping.insert_cat(s)?; + builder.push_value_ignore_validity(s); + polars_ensure!(mapping.len() == builder.len(), ComputeError: "FrozenCategories must contain unique strings; found duplicate '{s}'"); + } + + let combined_hash = combined_hasher.finish(); + let categories = builder.freeze(); + mapping.set_max_categories(categories.len()); // Don't allow any further inserts. + + let physical = CategoricalPhysical::smallest_physical(categories.len())?; + let mut registry = FROZEN_CATEGORIES_REGISTRY.lock().unwrap(); + let mut last_compared = None; // We have to store the strong reference to avoid a race condition. + match registry.entry( + combined_hash, + |(hash, weak)| { + *hash == combined_hash && { + if let Some(frozen_cats) = weak.upgrade() { + let cmp = frozen_cats.categories == categories; + last_compared = Some(frozen_cats); + cmp + } else { + false + } + } + }, + |(hash, _weak)| *hash, + ) { + Entry::Occupied(_) => Ok(last_compared.unwrap()), + Entry::Vacant(v) => { + let slf = Arc::new(Self { + physical, + combined_hash, + categories, + mapping: Arc::new(mapping), + }); + v.insert((combined_hash, Arc::downgrade(&slf))); + Ok(slf) + }, + } + } + + /// The categories contained in this FrozenCategories object. + pub fn categories(&self) -> &Utf8ViewArray { + &self.categories + } + + /// The physical dtype of the category ids. + pub fn physical(&self) -> CategoricalPhysical { + self.physical + } + + /// The mapping for this FrozenCategories object. + pub fn mapping(&self) -> &Arc { + &self.mapping + } +} + +impl fmt::Debug for FrozenCategories { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FrozenCategories") + .field("physical", &self.physical) + .field("categories", &self.categories) + .finish() + } +} + +impl Drop for FrozenCategories { + fn drop(&mut self) { + let mut registry = FROZEN_CATEGORIES_REGISTRY.lock().unwrap(); + while let Ok(entry) = + registry.find_entry(self.combined_hash, |(_, weak)| weak.strong_count() == 0) + { + entry.remove(); + } + } +} + +pub fn ensure_same_categories(left: &Arc, right: &Arc) -> PolarsResult<()> { + if Arc::ptr_eq(left, right) { + return Ok(()); + } + + if left.name() != right.name() { + polars_bail!(SchemaMismatch: "Categories name mismatch, left: '{}', right: '{}'. + +Operations mixing different Categories are often not supported, you may have to cast.", left.name(), right.name()) + } else if left.namespace() != right.namespace() { + polars_bail!(SchemaMismatch: "Categories have same name ('{}'), but have a mismatch in namespace, left: {}, right: {}. + +Operations mixing different Categories are often not supported, you may have to cast.", left.name(), left.namespace(), right.namespace()) + } else { + polars_bail!(SchemaMismatch: "Categories have same name and namespace ('{}', {}), but have a mismatch in dtype, left: {}, right: {}. + +Operations mixing different Categories are often not supported, you may have to cast.", left.name(), left.namespace(), left.physical().as_str(), right.physical().as_str()) + } +} + +pub fn ensure_same_frozen_categories( + left: &Arc, + right: &Arc, +) -> PolarsResult<()> { + if Arc::ptr_eq(left, right) { + return Ok(()); + } + + polars_bail!(SchemaMismatch: r#"Enum mismatch. + +Operations mixing different Enums are often not supported, you may have to cast."#) +} diff --git a/crates/polars-dtype/src/lib.rs b/crates/polars-dtype/src/lib.rs new file mode 100644 index 000000000000..d9a51ea0f2fe --- /dev/null +++ b/crates/polars-dtype/src/lib.rs @@ -0,0 +1,4 @@ +// Other data Polars type definitions will be moved into this crate later, for +// now it only contains the categorical mappings. + +pub mod categorical; diff --git a/crates/polars-expr/src/expressions/aggregation.rs b/crates/polars-expr/src/expressions/aggregation.rs index 2c8bcc6ef641..f9502bc9c4b6 100644 --- a/crates/polars-expr/src/expressions/aggregation.rs +++ b/crates/polars-expr/src/expressions/aggregation.rs @@ -768,17 +768,11 @@ where #[cfg(not(debug_assertions))] let thread_boundary = 100_000; - // Temporary until categorical min/max multithreading implementation is corrected. - #[cfg(feature = "dtype-categorical")] - let is_categorical = matches!(s.dtype(), &DataType::Categorical(_, _)); - #[cfg(not(feature = "dtype-categorical"))] - let is_categorical = false; // threading overhead/ splitting work stealing is costly.. if !allow_threading || s.len() < thread_boundary || POOL.current_thread_has_pending_tasks().unwrap_or(false) - || is_categorical { return f(s); } diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index 80dded7729f6..ded0836fb576 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -476,12 +476,6 @@ impl PhysicalExpr for WindowExpr { } let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns)); - // If the aggregation creates categoricals and `MapStrategy` is `Join`, - // the string cache was needed. So we hold it for that case. - // Worst case is that a categorical is created with indexes from the string - // cache which is fine, as the physical representation is undefined. - #[cfg(feature = "dtype-categorical")] - let _sc = polars_core::StringCacheHolder::hold(); let mut ac = self.run_aggregation(df, state, &gb)?; use MapStrategy::*; diff --git a/crates/polars-expr/src/groups/mod.rs b/crates/polars-expr/src/groups/mod.rs index d4aa713420ea..46736aa966be 100644 --- a/crates/polars-expr/src/groups/mod.rs +++ b/crates/polars-expr/src/groups/mod.rs @@ -2,6 +2,8 @@ use std::any::Any; use arrow::bitmap::BitmapBuilder; use polars_core::prelude::*; +#[cfg(feature = "dtype-categorical")] +use polars_core::with_match_categorical_physical_type; use polars_core::with_match_physical_numeric_polars_type; use polars_utils::IdxSize; use polars_utils::hashing::HashPartitioner; @@ -85,7 +87,11 @@ pub fn new_hash_grouper(key_schema: Arc) -> Box { Box::new(single_key::SingleKeyHashGrouper::::new()) }, #[cfg(feature = "dtype-categorical")] - DataType::Enum(_, _) => Box::new(single_key::SingleKeyHashGrouper::::new()), + DataType::Enum(fcats, _) => { + with_match_categorical_physical_type!(fcats.physical(), |$C| { + Box::new(single_key::SingleKeyHashGrouper::<<$C as PolarsCategoricalType>::PolarsPhysical>::new()) + }) + }, DataType::String | DataType::Binary => Box::new(binview::BinviewHashGrouper::new()), diff --git a/crates/polars-expr/src/groups/row_encoded.rs b/crates/polars-expr/src/groups/row_encoded.rs index a214bcde4bb2..d0e4b4952177 100644 --- a/crates/polars-expr/src/groups/row_encoded.rs +++ b/crates/polars-expr/src/groups/row_encoded.rs @@ -42,7 +42,7 @@ impl RowEncodedHashGrouper { .collect::>(); let ctxts = key_schema .iter() - .map(|(_, dt)| get_row_encoding_context(dt, false)) + .map(|(_, dt)| get_row_encoding_context(dt)) .collect::>(); let fields = vec![RowEncodingOptions::new_unsorted(); key_dtypes.len()]; let key_columns = diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index 1b52357f5021..808b0e7e8723 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -4,10 +4,9 @@ use std::hash::BuildHasher; use arrow::array::{Array, BinaryArray, BinaryViewArray, PrimitiveArray, StaticArray, UInt64Array}; use arrow::bitmap::Bitmap; use arrow::compute::utils::combine_validities_and_many; -use polars_core::error::polars_err; use polars_core::frame::DataFrame; use polars_core::prelude::row_encode::_get_rows_encoded_unordered; -use polars_core::prelude::{ChunkedArray, DataType, PlRandomState, PolarsDataType}; +use polars_core::prelude::{ChunkedArray, DataType, PlRandomState, PolarsDataType, *}; use polars_core::series::Series; use polars_utils::IdxSize; use polars_utils::cardinality_sketch::CardinalitySketch; @@ -66,18 +65,24 @@ macro_rules! downcast_single_key_ca { DataType::Float64 => { let $ca = $self.f64().unwrap(); $($body)* }, #[cfg(feature = "dtype-date")] - DataType::Date => { let $ca = $self.date().unwrap(); $($body)* }, + DataType::Date => { let $ca = $self.date().unwrap().physical(); $($body)* }, #[cfg(feature = "dtype-time")] - DataType::Time => { let $ca = $self.time().unwrap(); $($body)* }, + DataType::Time => { let $ca = $self.time().unwrap().physical(); $($body)* }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(..) => { let $ca = $self.datetime().unwrap(); $($body)* }, + DataType::Datetime(..) => { let $ca = $self.datetime().unwrap().physical(); $($body)* }, #[cfg(feature = "dtype-duration")] - DataType::Duration(..) => { let $ca = $self.duration().unwrap(); $($body)* }, + DataType::Duration(..) => { let $ca = $self.duration().unwrap().physical(); $($body)* }, #[cfg(feature = "dtype-decimal")] - DataType::Decimal(..) => { let $ca = $self.decimal().unwrap(); $($body)* }, + DataType::Decimal(..) => { let $ca = $self.decimal().unwrap().physical(); $($body)* }, #[cfg(feature = "dtype-categorical")] - DataType::Enum(..) => { let $ca = $self.categorical().unwrap().physical(); $($body)* }, + DataType::Enum(fcats, _) => { + match fcats.physical() { + CategoricalPhysical::U8 => { let $ca = $self.cat8().unwrap().physical(); $($body)* }, + CategoricalPhysical::U16 => { let $ca = $self.cat16().unwrap().physical(); $($body)* }, + CategoricalPhysical::U32 => { let $ca = $self.cat32().unwrap().physical(); $($body)* }, + } + }, _ => unreachable!(), } @@ -108,16 +113,6 @@ impl HashKeys { || first_col_variant == HashKeysVariant::RowEncoded; if use_row_encoding { let keys = df.get_columns(); - #[cfg(feature = "dtype-categorical")] - for key in keys { - if let DataType::Categorical(Some(rev_map), _) = key.dtype() { - assert!( - rev_map.is_active_global(), - "{}", - polars_err!(string_cache_mismatch) - ); - } - } let mut keys_encoded = _get_rows_encoded_unordered(keys).unwrap().into_array(); if !null_is_valid { diff --git a/crates/polars-expr/src/hot_groups/mod.rs b/crates/polars-expr/src/hot_groups/mod.rs index 256772cbd5f9..55757379e770 100644 --- a/crates/polars-expr/src/hot_groups/mod.rs +++ b/crates/polars-expr/src/hot_groups/mod.rs @@ -82,7 +82,11 @@ pub fn new_hash_hot_grouper(key_schema: Arc, num_groups: usize) -> Box Box::new(SK::::new(dt, ng)), #[cfg(feature = "dtype-categorical")] - DataType::Enum(_, _) => Box::new(SK::::new(dt, ng)), + DataType::Enum(ref fcats, _) => { + with_match_categorical_physical_type!(fcats.physical(), |$C| { + Box::new(SK::<<$C as PolarsCategoricalType>::PolarsPhysical>::new(dt.clone(), ng)) + }) + }, DataType::String | DataType::Binary => { Box::new(binview::BinviewHashHotGrouper::new(ng)) diff --git a/crates/polars-expr/src/idx_table/mod.rs b/crates/polars-expr/src/idx_table/mod.rs index f29a89ba1ca0..54eba807f581 100644 --- a/crates/polars-expr/src/idx_table/mod.rs +++ b/crates/polars-expr/src/idx_table/mod.rs @@ -104,7 +104,11 @@ pub fn new_idx_table(key_schema: Arc) -> Box { #[cfg(feature = "dtype-decimal")] DataType::Decimal(_, _) => Box::new(SKIT::::new()), #[cfg(feature = "dtype-categorical")] - DataType::Enum(_, _) => Box::new(SKIT::::new()), + DataType::Enum(fcats, _) => { + with_match_categorical_physical_type!(fcats.physical(), |$C| { + Box::new(SKIT::<<$C as PolarsCategoricalType>::PolarsPhysical>::new()) + }) + }, DataType::String | DataType::Binary => Box::new(binview::BinviewKeyIdxTable::new()), diff --git a/crates/polars-io/src/csv/read/buffer.rs b/crates/polars-io/src/csv/read/buffer.rs index 4db5e4e8f395..6ccecb5de60d 100644 --- a/crates/polars-io/src/csv/read/buffer.rs +++ b/crates/polars-io/src/csv/read/buffer.rs @@ -1,4 +1,6 @@ use arrow::array::MutableBinaryViewArray; +#[cfg(feature = "dtype-categorical")] +use polars_core::chunked_array::builder::CategoricalChunkedBuilder; use polars_core::prelude::*; use polars_error::to_compute_err; #[cfg(any(feature = "dtype-datetime", feature = "dtype-date"))] @@ -256,43 +258,23 @@ impl ParsedBuffer for Utf8Field { } } -#[cfg(not(feature = "dtype-categorical"))] -pub struct CategoricalField { - phantom: std::marker::PhantomData, -} - #[cfg(feature = "dtype-categorical")] -pub struct CategoricalField { +pub struct CategoricalField { escape_scratch: Vec, quote_char: u8, - builder: CategoricalChunkedBuilder, - is_enum: bool, + builder: CategoricalChunkedBuilder, } #[cfg(feature = "dtype-categorical")] -impl CategoricalField { - fn new( - name: PlSmallStr, - capacity: usize, - quote_char: Option, - ordering: CategoricalOrdering, - ) -> Self { - let builder = CategoricalChunkedBuilder::new(name, capacity, ordering); +impl CategoricalField { + fn new(name: PlSmallStr, capacity: usize, quote_char: Option, dtype: DataType) -> Self { + let mut builder = CategoricalChunkedBuilder::new(name, dtype); + builder.reserve(capacity); Self { escape_scratch: vec![], quote_char: quote_char.unwrap_or(b'"'), builder, - is_enum: false, - } - } - - fn new_enum(quote_char: Option, builder: CategoricalChunkedBuilder) -> Self { - Self { - escape_scratch: vec![], - quote_char: quote_char.unwrap_or(b'"'), - builder, - is_enum: true, } } @@ -328,20 +310,12 @@ impl CategoricalField { // SAFETY: // just did utf8 check let key = unsafe { std::str::from_utf8_unchecked(&self.escape_scratch) }; - if self.is_enum { - self.builder.try_append_value(key)?; - } else { - self.builder.append_value(key); - } + self.builder.append_str(key)?; } else { // SAFETY: // just did utf8 check let key = unsafe { std::str::from_utf8_unchecked(bytes) }; - if self.is_enum { - self.builder.try_append_value(key)? - } else { - self.builder.append_value(key) - } + self.builder.append_str(key)?; } } else if ignore_errors { self.builder.append_null() @@ -574,21 +548,33 @@ pub fn init_buffers( #[cfg(feature = "dtype-date")] &DataType::Date => Buffer::Date(DatetimeField::new(name, capacity)), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, ordering) => Buffer::Categorical(CategoricalField::new( - name, capacity, quote_char, *ordering, - )), - #[cfg(feature = "dtype-categorical")] - DataType::Enum(rev_map, _) => { - let Some(rev_map) = rev_map else { - polars_bail!(ComputeError: "enum categories must be set") - }; - let cats = rev_map.get_categories(); - let mut builder = - CategoricalChunkedBuilder::new(name, capacity, Default::default()); - for cat in cats.values_iter() { - builder.register_value(cat); + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + match dtype.cat_physical().unwrap() { + CategoricalPhysical::U8 => { + Buffer::Categorical8(CategoricalField::::new( + name, + capacity, + quote_char, + dtype.clone(), + )) + }, + CategoricalPhysical::U16 => { + Buffer::Categorical16(CategoricalField::::new( + name, + capacity, + quote_char, + dtype.clone(), + )) + }, + CategoricalPhysical::U32 => { + Buffer::Categorical32(CategoricalField::::new( + name, + capacity, + quote_char, + dtype.clone(), + )) + }, } - Buffer::Categorical(CategoricalField::new_enum(quote_char, builder)) }, dt => polars_bail!( ComputeError: "unsupported data type when reading CSV: {} when reading CSV", dt, @@ -628,8 +614,12 @@ pub enum Buffer { }, #[cfg(feature = "dtype-date")] Date(DatetimeField), - #[allow(dead_code)] - Categorical(CategoricalField), + #[cfg(feature = "dtype-categorical")] + Categorical8(CategoricalField), + #[cfg(feature = "dtype-categorical")] + Categorical16(CategoricalField), + #[cfg(feature = "dtype-categorical")] + Categorical32(CategoricalField), DecimalFloat32(PrimitiveChunkedBuilder, Vec), DecimalFloat64(PrimitiveChunkedBuilder, Vec), } @@ -680,32 +670,12 @@ impl Buffer { StringChunked::with_chunk(v.name.clone(), unsafe { arr.to_utf8view_unchecked() }) .into_series() }, - #[allow(unused_variables)] - Buffer::Categorical(buf) => { - #[cfg(feature = "dtype-categorical")] - { - let ca = buf.builder.finish(); - - if buf.is_enum { - let DataType::Categorical(Some(rev_map), _) = ca.dtype() else { - unreachable!() - }; - let idx = ca.physical().clone(); - let dtype = DataType::Enum(Some(rev_map.clone()), Default::default()); - - unsafe { - CategoricalChunked::from_cats_and_dtype_unchecked(idx, dtype) - .into_series() - } - } else { - ca.into_series() - } - } - #[cfg(not(feature = "dtype-categorical"))] - { - panic!("activate 'dtype-categorical' feature") - } - }, + #[cfg(feature = "dtype-categorical")] + Buffer::Categorical8(buf) => buf.builder.finish().into_series(), + #[cfg(feature = "dtype-categorical")] + Buffer::Categorical16(buf) => buf.builder.finish().into_series(), + #[cfg(feature = "dtype-categorical")] + Buffer::Categorical32(buf) => buf.builder.finish().into_series(), }; Ok(s) } @@ -742,17 +712,12 @@ impl Buffer { Buffer::Datetime { buf, .. } => buf.builder.append_null(), #[cfg(feature = "dtype-date")] Buffer::Date(v) => v.builder.append_null(), - #[allow(unused_variables)] - Buffer::Categorical(cat_builder) => { - #[cfg(feature = "dtype-categorical")] - { - cat_builder.builder.append_null() - } - #[cfg(not(feature = "dtype-categorical"))] - { - panic!("activate 'dtype-categorical' feature") - } - }, + #[cfg(feature = "dtype-categorical")] + Buffer::Categorical8(buf) => buf.builder.append_null(), + #[cfg(feature = "dtype-categorical")] + Buffer::Categorical16(buf) => buf.builder.append_null(), + #[cfg(feature = "dtype-categorical")] + Buffer::Categorical32(buf) => buf.builder.append_null(), }; } @@ -780,17 +745,12 @@ impl Buffer { Buffer::Datetime { time_unit, .. } => DataType::Datetime(*time_unit, None), #[cfg(feature = "dtype-date")] Buffer::Date(_) => DataType::Date, - Buffer::Categorical(_) => { - #[cfg(feature = "dtype-categorical")] - { - DataType::Categorical(None, Default::default()) - } - - #[cfg(not(feature = "dtype-categorical"))] - { - panic!("activate 'dtype-categorical' feature") - } - }, + #[cfg(feature = "dtype-categorical")] + Buffer::Categorical8(buf) => buf.builder.dtype().clone(), + #[cfg(feature = "dtype-categorical")] + Buffer::Categorical16(buf) => buf.builder.dtype().clone(), + #[cfg(feature = "dtype-categorical")] + Buffer::Categorical32(buf) => buf.builder.dtype().clone(), } } @@ -955,17 +915,17 @@ impl Buffer { missing_is_null, None, ), - #[allow(unused_variables)] - Categorical(buf) => { - #[cfg(feature = "dtype-categorical")] - { - buf.parse_bytes(bytes, ignore_errors, needs_escaping, missing_is_null, None) - } - - #[cfg(not(feature = "dtype-categorical"))] - { - panic!("activate 'dtype-categorical' feature") - } + #[cfg(feature = "dtype-categorical")] + Categorical8(buf) => { + buf.parse_bytes(bytes, ignore_errors, needs_escaping, missing_is_null, None) + }, + #[cfg(feature = "dtype-categorical")] + Categorical16(buf) => { + buf.parse_bytes(bytes, ignore_errors, needs_escaping, missing_is_null, None) + }, + #[cfg(feature = "dtype-categorical")] + Categorical32(buf) => { + buf.parse_bytes(bytes, ignore_errors, needs_escaping, missing_is_null, None) }, } } diff --git a/crates/polars-io/src/csv/read/read_impl.rs b/crates/polars-io/src/csv/read/read_impl.rs index a71631d8106f..d22af5a6abb5 100644 --- a/crates/polars-io/src/csv/read/read_impl.rs +++ b/crates/polars-io/src/csv/read/read_impl.rs @@ -121,8 +121,6 @@ pub(crate) struct CoreReader<'a> { predicate: Option>, to_cast: Vec, row_index: Option, - #[cfg_attr(not(feature = "dtype-categorical"), allow(unused))] - has_categorical: bool, } impl fmt::Debug for CoreReader<'_> { @@ -217,7 +215,7 @@ impl<'a> CoreReader<'a> { } } - let has_categorical = prepare_csv_schema(&mut schema, &mut to_cast)?; + prepare_csv_schema(&mut schema, &mut to_cast)?; // Create a null value for every column let null_values = parse_options @@ -253,7 +251,6 @@ impl<'a> CoreReader<'a> { predicate, to_cast, row_index, - has_categorical, }) } @@ -518,15 +515,7 @@ impl<'a> CoreReader<'a> { /// Read the csv into a DataFrame. The predicate can come from a lazy physical plan. pub fn finish(mut self) -> PolarsResult { - #[cfg(feature = "dtype-categorical")] - let mut _cat_lock = if self.has_categorical { - Some(polars_core::StringCacheHolder::hold()) - } else { - None - }; - let reader_bytes = self.reader_bytes.take().unwrap(); - let mut df = self.parse_csv(&reader_bytes)?; // if multi-threaded the n_rows was probabilistically determined. diff --git a/crates/polars-io/src/csv/read/read_impl/batched.rs b/crates/polars-io/src/csv/read/read_impl/batched.rs index fa6021efa94d..85dd926d4e46 100644 --- a/crates/polars-io/src/csv/read/read_impl/batched.rs +++ b/crates/polars-io/src/csv/read/read_impl/batched.rs @@ -160,17 +160,6 @@ impl<'a> CoreReader<'a> { let projection = self.get_projection()?; - // RAII structure that will ensure we maintain a global stringcache - #[cfg(feature = "dtype-categorical")] - let _cat_lock = if self.has_categorical { - Some(polars_core::StringCacheHolder::hold()) - } else { - None - }; - - #[cfg(not(feature = "dtype-categorical"))] - let _cat_lock = None; - Ok(BatchedCsvReader { reader_bytes, parse_options: self.parse_options, @@ -186,7 +175,6 @@ impl<'a> CoreReader<'a> { remaining: self.n_rows.unwrap_or(usize::MAX), schema: self.schema, rows_read: 0, - _cat_lock, }) } } @@ -206,10 +194,6 @@ pub struct BatchedCsvReader<'a> { remaining: usize, schema: SchemaRef, rows_read: IdxSize, - #[cfg(feature = "dtype-categorical")] - _cat_lock: Option, - #[cfg(not(feature = "dtype-categorical"))] - _cat_lock: Option, } impl BatchedCsvReader<'_> { diff --git a/crates/polars-io/src/csv/read/reader.rs b/crates/polars-io/src/csv/read/reader.rs index cc6551e59cf6..592d335e3b4b 100644 --- a/crates/polars-io/src/csv/read/reader.rs +++ b/crates/polars-io/src/csv/read/reader.rs @@ -197,17 +197,12 @@ impl CsvReader { /// Splits datatypes that cannot be natively read into a `fields_to_cast` for /// post-read casting. -/// -/// # Returns -/// `has_categorical` pub fn prepare_csv_schema( schema: &mut SchemaRef, fields_to_cast: &mut Vec, -) -> PolarsResult { +) -> PolarsResult<()> { // This branch we check if there are dtypes we cannot parse. - // We only support a few dtypes in the parser and later cast to the required dtype - let mut _has_categorical = false; - + // We only support a few dtypes in the parser and later cast to the required dtype. let mut changed = false; let new_schema = schema @@ -223,11 +218,6 @@ pub fn prepare_csv_schema( fld.coerce(String); PolarsResult::Ok(fld) }, - #[cfg(feature = "dtype-categorical")] - Categorical(_, _) => { - _has_categorical = true; - PolarsResult::Ok(fld) - }, #[cfg(feature = "dtype-decimal")] Decimal(precision, scale) => match (precision, scale) { (_, Some(_)) => { @@ -255,5 +245,5 @@ pub fn prepare_csv_schema( *schema = Arc::new(new_schema); } - Ok(_has_categorical) + Ok(()) } diff --git a/crates/polars-io/src/csv/write/write_impl/serializer.rs b/crates/polars-io/src/csv/write/write_impl/serializer.rs index a0fcd447e1b1..aa38b8e3df51 100644 --- a/crates/polars-io/src/csv/write/write_impl/serializer.rs +++ b/crates/polars-io/src/csv/write/write_impl/serializer.rs @@ -886,22 +886,23 @@ pub(super) fn serializer_for<'a>( array, ), #[cfg(feature = "dtype-categorical")] - DataType::Categorical(rev_map, _) | DataType::Enum(rev_map, _) => { - let rev_map = rev_map.as_deref().unwrap(); - string_serializer( - |iter| { - let &idx: &u32 = Iterator::next(iter).expect(TOO_MANY_MSG)?; - Some(rev_map.get(idx)) - }, - options, - |arr| { - arr.as_any() - .downcast_ref::>() - .expect(ARRAY_MISMATCH_MSG) - .iter() - }, - array, - ) + DataType::Categorical(_, mapping) | DataType::Enum(_, mapping) => { + polars_core::with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| { + string_serializer( + |iter| { + let &idx: &<$C as PolarsCategoricalType>::Native = Iterator::next(iter).expect(TOO_MANY_MSG)?; + Some(unsafe { mapping.cat_to_str_unchecked(idx.as_cat()) }) + }, + options, + |arr| { + arr.as_any() + .downcast_ref::::Native>>() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }, + array, + ) + }) }, #[cfg(feature = "dtype-decimal")] DataType::Decimal(_, scale) => { diff --git a/crates/polars-io/src/parquet/read/read_impl.rs b/crates/polars-io/src/parquet/read/read_impl.rs index 6c01e8484266..547a29c718ce 100644 --- a/crates/polars-io/src/parquet/read/read_impl.rs +++ b/crates/polars-io/src/parquet/read/read_impl.rs @@ -427,22 +427,6 @@ pub fn read_parquet( .unwrap_or_else(|| read::read_metadata(&mut reader).map(Arc::new))?; let n_row_groups = file_metadata.row_groups.len(); - // if there are multiple row groups and categorical data - // we need a string cache - // we keep it alive until the end of the function - let _sc = if n_row_groups > 1 { - #[cfg(feature = "dtype-categorical")] - { - Some(polars_core::StringCacheHolder::hold()) - } - #[cfg(not(feature = "dtype-categorical"))] - { - Some(0u8) - } - } else { - None - }; - let materialized_projection = projection .map(Cow::Borrowed) .unwrap_or_else(|| Cow::Owned((0usize..reader_schema.len()).collect::>())); diff --git a/crates/polars-io/src/predicates.rs b/crates/polars-io/src/predicates.rs index 0fb56dfff86c..e609e8c70597 100644 --- a/crates/polars-io/src/predicates.rs +++ b/crates/polars-io/src/predicates.rs @@ -150,10 +150,9 @@ fn cast_to_parquet_scalar(scalar: Scalar) -> Option { // @TODO: Cast to string #[cfg(feature = "dtype-categorical")] - A::Categorical(_, _, _) - | A::CategoricalOwned(_, _, _) - | A::Enum(_, _, _) - | A::EnumOwned(_, _, _) => return None, + A::Categorical(_, _) | A::CategoricalOwned(_, _) | A::Enum(_, _) | A::EnumOwned(_, _) => { + return None; + }, A::String(v) => P::String(v.into()), A::StringOwned(v) => P::String(v.as_str().into()), diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index f3d2aa6a3fc5..8ba95f0e31b6 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -25,8 +25,6 @@ pub use ndjson::*; pub use parquet::*; use polars_compute::rolling::QuantileMethod; use polars_core::POOL; -#[cfg(all(feature = "new_streaming", feature = "dtype-categorical"))] -use polars_core::StringCacheHolder; use polars_core::error::feature_gated; use polars_core::prelude::*; use polars_expr::{ExpressionConversionState, create_physical_expr}; @@ -737,15 +735,11 @@ impl LazyFrame { match engine { Engine::Auto | Engine::Streaming => feature_gated!("new_streaming", { - #[cfg(feature = "dtype-categorical")] - let string_cache_hold = StringCacheHolder::hold(); let result = polars_stream::run_query( alp_plan.lp_top, &mut alp_plan.lp_arena, &mut alp_plan.expr_arena, ); - #[cfg(feature = "dtype-categorical")] - drop(string_cache_hold); result.map(|v| v.unwrap_single()) }), _ if matches!(payload, SinkType::Partition { .. }) => Err(polars_err!( @@ -822,15 +816,11 @@ impl LazyFrame { if engine == Engine::Streaming { feature_gated!("new_streaming", { - #[cfg(feature = "dtype-categorical")] - let string_cache_hold = StringCacheHolder::hold(); let result = polars_stream::run_query( alp_plan.lp_top, &mut alp_plan.lp_arena, &mut alp_plan.expr_arena, ); - #[cfg(feature = "dtype-categorical")] - drop(string_cache_hold); return result.map(|v| v.unwrap_multiple()); }); } @@ -1141,8 +1131,6 @@ impl LazyFrame { Err(e) => return Some(Err(e)), }; - #[cfg(feature = "dtype-categorical")] - let _hold = StringCacheHolder::hold(); let f = || { polars_stream::run_query( alp_plan.lp_top, diff --git a/crates/polars-lazy/src/tests/predicate_queries.rs b/crates/polars-lazy/src/tests/predicate_queries.rs index 523eee9c5237..d91dbce753f2 100644 --- a/crates/polars-lazy/src/tests/predicate_queries.rs +++ b/crates/polars-lazy/src/tests/predicate_queries.rs @@ -42,7 +42,7 @@ fn test_issue_2472() -> PolarsResult<()> { ]?; let base = df .lazy() - .with_column(col("group").cast(DataType::Categorical(None, Default::default()))); + .with_column(col("group").cast(DataType::from_categories(Categories::global()))); let extract = col("group") .cast(DataType::String) diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index 29ab99398faa..cd63b1b9bb33 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -1403,8 +1403,8 @@ fn test_categorical_addition() -> PolarsResult<()> { let out = df .lazy() .select([ - col("fruits").cast(DataType::Categorical(None, Default::default())), - col("cars").cast(DataType::Categorical(None, Default::default())), + col("fruits").cast(DataType::from_categories(Categories::global())), + col("cars").cast(DataType::from_categories(Categories::global())), ]) .select([(col("fruits") + lit(" ") + col("cars")).alias("foo")]) .collect()?; diff --git a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs index 8dbde323d291..7afdeb270fba 100644 --- a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs +++ b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs @@ -208,8 +208,8 @@ fn can_run_partitioned( let (unique_estimate, sampled_method) = match (keys.len(), keys[0].dtype()) { #[cfg(feature = "dtype-categorical")] - (1, DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _)) => { - (rev_map.len(), "known") + (1, DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) => { + (mapping.num_cats_upper_bound(), "known") }, _ => { // sqrt(N) is a good sample size as it remains low on large numbers diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index 25a85525f208..a842e0409061 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -395,26 +395,8 @@ fn create_physical_plan_impl( Ok(Box::new(executors::SliceExec { input, offset, len })) }, Filter { input, predicate } => { - let mut streamable = - is_elementwise_rec_no_cat_cast(expr_arena.get(predicate.node()), expr_arena); + let streamable = is_elementwise_rec(predicate.node(), expr_arena); let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); - if streamable { - // This can cause problems with string caches - streamable = !input_schema - .iter_values() - .any(|dt| dt.contains_categoricals()) - || { - #[cfg(feature = "dtype-categorical")] - { - polars_core::using_string_cache() - } - - #[cfg(not(feature = "dtype-categorical"))] - { - false - } - } - } let input = recurse!(input, state)?; let mut state = ExpressionConversionState::new(true); let predicate = create_physical_expr( @@ -559,7 +541,7 @@ fn create_physical_plan_impl( &mut state, )?; - let allow_vertical_parallelism = options.should_broadcast && expr.iter().all(|e| is_elementwise_rec_no_cat_cast(expr_arena.get(e.node()), expr_arena)) + let allow_vertical_parallelism = options.should_broadcast && expr.iter().all(|e| is_elementwise_rec(e.node(), expr_arena)) // If all columns are literal we would get a 1 row per thread. && !phys_expr.iter().all(|p| { p.is_literal() @@ -828,7 +810,7 @@ fn create_physical_plan_impl( let allow_vertical_parallelism = options.should_broadcast && exprs .iter() - .all(|e| is_elementwise_rec_no_cat_cast(expr_arena.get(e.node()), expr_arena)); + .all(|e| is_elementwise_rec(e.node(), expr_arena)); let mut state = ExpressionConversionState::new(POOL.current_num_threads() > exprs.len()); diff --git a/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs b/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs index 1cf5eb38b5f1..60ca3ae06f9e 100644 --- a/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs +++ b/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs @@ -32,7 +32,8 @@ pub fn replace_time_zone( .phys .clone() .into_datetime(datetime.time_unit(), time_zone.cloned()); - out.set_sorted_flag(datetime.is_sorted_flag()); + out.physical_mut() + .set_sorted_flag(datetime.physical().is_sorted_flag()); return Ok(out); } let timestamp_to_datetime: fn(i64) -> NaiveDateTime = match datetime.time_unit() { @@ -78,7 +79,8 @@ pub fn replace_time_zone( // - `from_tz` is guaranteed to not observe daylight savings time; // - user is just passing 'raise' to 'ambiguous'. // Both conditions above need to be satisfied. - out.set_sorted_flag(datetime.is_sorted_flag()); + out.physical_mut() + .set_sorted_flag(datetime.physical().is_sorted_flag()); } Ok(out) } @@ -140,8 +142,10 @@ pub fn impl_replace_time_zone( }); ChunkedArray::try_from_chunk_iter(datetime.phys.name().clone(), iter) }, - _ => try_binary_elementwise(datetime, ambiguous, |timestamp_opt, ambiguous_opt| { - match (timestamp_opt, ambiguous_opt) { + _ => try_binary_elementwise( + datetime.physical(), + ambiguous, + |timestamp_opt, ambiguous_opt| match (timestamp_opt, ambiguous_opt) { (Some(timestamp), Some(ambiguous)) => { let ndt = timestamp_to_datetime(timestamp); Ok(convert_to_naive_local( @@ -154,7 +158,7 @@ pub fn impl_replace_time_zone( .map(datetime_to_timestamp)) }, _ => Ok(None), - } - }), + }, + ), } } diff --git a/crates/polars-ops/src/chunked_array/gather/chunked.rs b/crates/polars-ops/src/chunked_array/gather/chunked.rs index 8b1de159148f..3c4c45935993 100644 --- a/crates/polars-ops/src/chunked_array/gather/chunked.rs +++ b/crates/polars-ops/src/chunked_array/gather/chunked.rs @@ -10,7 +10,7 @@ use polars_core::prelude::gather::_update_gather_sorted_flag; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::utils::Container; -use polars_core::with_match_physical_numeric_polars_type; +use polars_core::{with_match_categorical_physical_type, with_match_physical_numeric_polars_type}; use crate::frame::IntoDf; @@ -219,18 +219,15 @@ impl TakeChunked for Series { .into_series() }, #[cfg(feature = "dtype-categorical")] - Categorical(revmap, ord) | Enum(revmap, ord) => { - let ca = self.categorical().unwrap(); - let t = ca - .physical() - .take_chunked_unchecked(by, sorted, avoid_sharing); - CategoricalChunked::from_cats_and_rev_map_unchecked( - t, - revmap.as_ref().unwrap().clone(), - matches!(self.dtype(), Enum(..)), - *ord, - ) - .into_series() + Categorical(_, _) | Enum(_, _) => { + with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| { + let ca = self.cat::<$C>().unwrap(); + CategoricalChunked::<$C>::from_cats_and_dtype_unchecked( + ca.physical().take_chunked_unchecked(by, sorted, avoid_sharing), + self.dtype().clone() + ) + .into_series() + }) }, Null => Series::new_null(self.name().clone(), by.len()), _ => unreachable!(), @@ -322,16 +319,15 @@ impl TakeChunked for Series { .into_series() }, #[cfg(feature = "dtype-categorical")] - Categorical(revmap, ord) | Enum(revmap, ord) => { - let ca = self.categorical().unwrap(); - let ret = ca.physical().take_opt_chunked_unchecked(by, avoid_sharing); - CategoricalChunked::from_cats_and_rev_map_unchecked( - ret, - revmap.as_ref().unwrap().clone(), - matches!(self.dtype(), Enum(..)), - *ord, - ) - .into_series() + Categorical(_, _) | Enum(_, _) => { + with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| { + let ca = self.cat::<$C>().unwrap(); + CategoricalChunked::<$C>::from_cats_and_dtype_unchecked( + ca.physical().take_opt_chunked_unchecked(by, avoid_sharing), + self.dtype().clone() + ) + .into_series() + }) }, Null => Series::new_null(self.name().clone(), by.len()), _ => unreachable!(), diff --git a/crates/polars-ops/src/chunked_array/hist.rs b/crates/polars-ops/src/chunked_array/hist.rs index 127452ff26e4..196c517c9f4b 100644 --- a/crates/polars-ops/src/chunked_array/hist.rs +++ b/crates/polars-ops/src/chunked_array/hist.rs @@ -215,7 +215,7 @@ where } let categories = categories .finish() - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); fields.push(categories); }; diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 7cb7e719dea8..63fd23da57b6 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -696,23 +696,9 @@ pub trait ListNameSpaceImpl: AsList { match s.dtype() { DataType::List(inner_type) => { inner_super_type = try_get_supertype(&inner_super_type, inner_type)?; - #[cfg(feature = "dtype-categorical")] - if matches!( - &inner_super_type, - DataType::Categorical(_, _) | DataType::Enum(_, _) - ) { - inner_super_type = merge_dtypes(&inner_super_type, inner_type)?; - } }, dt => { inner_super_type = try_get_supertype(&inner_super_type, dt)?; - #[cfg(feature = "dtype-categorical")] - if matches!( - &inner_super_type, - DataType::Categorical(_, _) | DataType::Enum(_, _) - ) { - inner_super_type = merge_dtypes(&inner_super_type, dt)?; - } }, } } diff --git a/crates/polars-ops/src/chunked_array/list/sets.rs b/crates/polars-ops/src/chunked_array/list/sets.rs index a90d84972abd..0a0e7f2b89db 100644 --- a/crates/polars-ops/src/chunked_array/list/sets.rs +++ b/crates/polars-ops/src/chunked_array/list/sets.rs @@ -428,14 +428,6 @@ pub fn list_set_operation( a.prune_empty_chunks(); b.prune_empty_chunks(); - // Make categoricals compatible - #[cfg(feature = "dtype-categorical")] - if let (DataType::Categorical(_, _), DataType::Categorical(_, _)) = - (&a.inner_dtype(), &b.inner_dtype()) - { - (a, b) = make_rhs_list_categoricals_compatible(a, b)?; - } - // we use the unsafe variant because we want to keep the nested logical types type. unsafe { arity::try_binary_unchecked_same_type( diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs index 9d2a3a7ab4f2..b8e7311b9b84 100644 --- a/crates/polars-ops/src/frame/join/asof/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -296,8 +296,6 @@ where polars_ensure!(lhs.dtype() == rhs.dtype(), ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{}` and `{}`", lhs.dtype(), rhs.dtype() ); - #[cfg(feature = "dtype-categorical")] - _check_categorical_src(lhs.dtype(), rhs.dtype())?; } // TODO: @scalar-opt. @@ -526,8 +524,6 @@ pub trait AsofJoinBy: IntoDf { .iter_mut() .zip(right_by.get_columns_mut().iter_mut()) { - #[cfg(feature = "dtype-categorical")] - _check_categorical_src(l.dtype(), r.dtype())?; *l = l.to_physical_repr(); *r = r.to_physical_repr(); } diff --git a/crates/polars-ops/src/frame/join/asof/mod.rs b/crates/polars-ops/src/frame/join/asof/mod.rs index a5f765fb283e..dc47c4b5ff58 100644 --- a/crates/polars-ops/src/frame/join/asof/mod.rs +++ b/crates/polars-ops/src/frame/join/asof/mod.rs @@ -10,8 +10,6 @@ use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[cfg(feature = "dtype-categorical")] -use super::_check_categorical_src; use super::{_finish_join, build_tables}; use crate::frame::IntoDf; use crate::series::SeriesMethods; diff --git a/crates/polars-ops/src/frame/join/checks.rs b/crates/polars-ops/src/frame/join/checks.rs deleted file mode 100644 index 0fa179afba7b..000000000000 --- a/crates/polars-ops/src/frame/join/checks.rs +++ /dev/null @@ -1,18 +0,0 @@ -use super::*; - -/// If Categorical types are created without a global string cache or under -/// a different global string cache the mapping will be incorrect. -pub(crate) fn _check_categorical_src(l: &DataType, r: &DataType) -> PolarsResult<()> { - match (l, r) { - (DataType::Categorical(Some(l), _), DataType::Categorical(Some(r), _)) - | (DataType::Enum(Some(l), _), DataType::Enum(Some(r), _)) => { - polars_ensure!(l.same_src(r), string_cache_mismatch); - }, - (DataType::Categorical(_, _), DataType::Enum(_, _)) - | (DataType::Enum(_, _), DataType::Categorical(_, _)) => { - polars_bail!(ComputeError: "enum and categorical are not from the same source") - }, - _ => (), - }; - Ok(()) -} diff --git a/crates/polars-ops/src/frame/join/dispatch_left_right.rs b/crates/polars-ops/src/frame/join/dispatch_left_right.rs index 08289b3e86bc..91b2757e4a13 100644 --- a/crates/polars-ops/src/frame/join/dispatch_left_right.rs +++ b/crates/polars-ops/src/frame/join/dispatch_left_right.rs @@ -42,9 +42,6 @@ pub fn materialize_left_join_from_series( verbose: bool, drop_names: Option>, ) -> PolarsResult<(DataFrame, DataFrame)> { - #[cfg(feature = "dtype-categorical")] - _check_categorical_src(s_left.dtype(), s_right.dtype())?; - let mut s_left = s_left.clone(); // Eagerly limit left if possible. if let Some((offset, len)) = args.slice { diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index 9653e8593609..1ee235ce1cb1 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -133,8 +133,6 @@ pub trait JoinDispatch: IntoDf { nulls_equal: bool, ) -> PolarsResult { let ca_self = self.to_df(); - #[cfg(feature = "dtype-categorical")] - _check_categorical_src(s_left.dtype(), s_right.dtype())?; let idx = s_left.hash_join_semi_anti(s_right, anti, nulls_equal)?; // SAFETY: @@ -149,8 +147,6 @@ pub trait JoinDispatch: IntoDf { args: JoinArgs, ) -> PolarsResult { let df_self = self.to_df(); - #[cfg(feature = "dtype-categorical")] - _check_categorical_src(s_left.dtype(), s_right.dtype())?; // Get the indexes of the joined relations let (mut join_idx_l, mut join_idx_r) = diff --git a/crates/polars-ops/src/frame/join/merge_sorted.rs b/crates/polars-ops/src/frame/join/merge_sorted.rs index 23ec1b5392c1..5f25207589bf 100644 --- a/crates/polars-ops/src/frame/join/merge_sorted.rs +++ b/crates/polars-ops/src/frame/join/merge_sorted.rs @@ -1,31 +1,6 @@ use arrow::legacy::utils::{CustomIterTools, FromTrustedLenIterator}; use polars_core::prelude::*; -use polars_core::with_match_physical_numeric_polars_type; - -fn check_and_union_revmaps( - lhs_revmap: &Option>, - rhs_revmap: &Option>, -) -> PolarsResult>> { - // Ensure we are operating on either identical locals, or compatible globals. - let lhs_revmap = lhs_revmap.as_ref().unwrap(); - let rhs_revmap = rhs_revmap.as_ref().unwrap(); - match (&**lhs_revmap, &**rhs_revmap) { - (RevMapping::Local(_, l_hash), RevMapping::Local(_, r_hash)) => { - // Same local categoricals, we return immediately - polars_ensure!(l_hash == r_hash, ComputeError: "cannot merge-sort incompatible categoricals"); - Ok(None) - }, - // Return revmap that is the union of the two revmaps. - (RevMapping::Global(_, _, l), RevMapping::Global(_, _, r)) => { - polars_ensure!(l == r, ComputeError: "cannot merge-sort incompatible categoricals"); - let mut rev_map_merger = GlobalRevMapMerger::new(lhs_revmap.clone()); - rev_map_merger.merge_map(rhs_revmap)?; - let new_map = rev_map_merger.finish(); - Ok(Some(new_map)) - }, - _ => unreachable!(), - } -} +use polars_core::{with_match_categorical_physical_type, with_match_physical_numeric_polars_type}; pub fn _merge_sorted_dfs( left: &DataFrame, @@ -45,15 +20,6 @@ pub fn _merge_sorted_dfs( ComputeError: "merge-sort datatype mismatch: {} != {}", dtype_lhs, dtype_rhs ); - if dtype_lhs.is_categorical() { - let rev_map_lhs = left_s.categorical().unwrap().get_rev_map(); - let rev_map_rhs = right_s.categorical().unwrap().get_rev_map(); - polars_ensure!( - rev_map_lhs.same_src(rev_map_rhs), - ComputeError: "can only merge-sort categoricals with the same categories" - ); - } - // If one frame is empty, we can return the other immediately. if right_s.is_empty() { return Ok(left.clone()); @@ -76,20 +42,7 @@ pub fn _merge_sorted_dfs( &merge_indicator, )?); - let lhs_dt = lhs.dtype(); - let dtype_out = match (lhs_dt, rhs.dtype()) { - // Global categorical revmaps must be merged for the output. - (DataType::Categorical(lhs_revmap, ord), DataType::Categorical(rhs_revmap, _)) => { - if let Some(new_revmap) = check_and_union_revmaps(lhs_revmap, rhs_revmap)? { - &DataType::Categorical(Some(new_revmap), *ord) - } else { - lhs_dt - } - }, - _ => lhs_dt, - }; - - let mut out = unsafe { out.from_physical_unchecked(dtype_out) }.unwrap(); + let mut out = unsafe { out.from_physical_unchecked(lhs.dtype()) }.unwrap(); out.rename(lhs.name().clone()); Ok(out) }) @@ -182,13 +135,12 @@ where } fn series_to_merge_indicator(lhs: &Series, rhs: &Series) -> PolarsResult> { - if lhs.dtype().is_categorical() { - let lhs_ca = lhs.categorical().unwrap(); - if lhs_ca.uses_lexical_ordering() { - let rhs_ca = rhs.categorical().unwrap(); - let out = get_merge_indicator(lhs_ca.iter_str(), rhs_ca.iter_str()); - return Ok(out); - } + if let Ok(cat_phys) = lhs.dtype().cat_physical() { + with_match_categorical_physical_type!(cat_phys, |$C| { + let lhs = lhs.cat::<$C>().unwrap(); + let rhs = rhs.cat::<$C>().unwrap(); + return Ok(get_merge_indicator(lhs.iter_str(), rhs.iter_str())); + }) } let lhs_s = lhs.to_physical_repr().into_owned(); diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index c64501760ce4..0a5120e1474c 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -1,8 +1,6 @@ mod args; #[cfg(feature = "asof_join")] mod asof; -#[cfg(feature = "dtype-categorical")] -mod checks; mod cross_join; mod dispatch_left_right; mod general; @@ -20,8 +18,6 @@ pub use args::*; use arrow::trusted_len::TrustedLen; #[cfg(feature = "asof_join")] pub use asof::{AsOfOptions, AsofJoin, AsofJoinBy, AsofStrategy}; -#[cfg(feature = "dtype-categorical")] -pub(crate) use checks::*; pub use cross_join::CrossJoin; #[cfg(feature = "chunked_ids")] use either::Either; @@ -223,19 +219,6 @@ pub trait DataFrameJoinOps: IntoDf { ); }; - #[cfg(feature = "dtype-categorical")] - for (l, r) in selected_left.iter_mut().zip(selected_right.iter_mut()) { - match _check_categorical_src(l.dtype(), r.dtype()) { - Ok(_) => {}, - Err(_) => { - let (ca_left, ca_right) = - make_rhs_categoricals_compatible(l.categorical()?, r.categorical()?)?; - *l = ca_left.into_series().with_name(l.name().clone()); - *r = ca_right.into_series().with_name(r.name().clone()); - }, - } - } - #[cfg(feature = "iejoin")] if let JoinType::IEJoin = args.how { let Some(JoinTypeOptions::IEJoin(options)) = options else { @@ -559,8 +542,6 @@ trait DataFrameJoinOpsPrivate: IntoDf { drop_names: Option>, ) -> PolarsResult { let left_df = self.to_df(); - #[cfg(feature = "dtype-categorical")] - _check_categorical_src(s_left.dtype(), s_right.dtype())?; let ((join_tuples_left, join_tuples_right), sorted) = _sort_or_hash_inner(s_left, s_right, verbose, args.validation, args.nulls_equal)?; diff --git a/crates/polars-ops/src/frame/pivot/mod.rs b/crates/polars-ops/src/frame/pivot/mod.rs index 709b491bdcbd..a0397c462689 100644 --- a/crates/polars-ops/src/frame/pivot/mod.rs +++ b/crates/polars-ops/src/frame/pivot/mod.rs @@ -29,22 +29,6 @@ pub enum PivotAgg { fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series { // restore logical type match (logical_type, s.dtype()) { - #[cfg(feature = "dtype-categorical")] - (dt @ DataType::Categorical(Some(rev_map), ordering), _) - | (dt @ DataType::Enum(Some(rev_map), ordering), _) => { - let cats = s.u32().unwrap().clone(); - // SAFETY: - // the rev-map comes from these categoricals - unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - cats, - rev_map.clone(), - matches!(dt, DataType::Enum(_, _)), - *ordering, - ) - .into_series() - } - }, (DataType::Float32, DataType::UInt32) => { let ca = s.u32().unwrap(); ca._reinterpret_float().into_series() diff --git a/crates/polars-ops/src/series/ops/abs.rs b/crates/polars-ops/src/series/ops/abs.rs index e93e3c13c60d..0046b8031fcc 100644 --- a/crates/polars-ops/src/series/ops/abs.rs +++ b/crates/polars-ops/src/series/ops/abs.rs @@ -20,7 +20,7 @@ pub fn abs(s: &Series) -> PolarsResult { let precision = ca.precision(); let scale = ca.scale(); - let out = ca.as_ref().wrapping_abs(); + let out = ca.physical().wrapping_abs(); out.into_decimal_unchecked(precision, scale).into_series() }, #[cfg(feature = "dtype-duration")] diff --git a/crates/polars-ops/src/series/ops/arg_min_max.rs b/crates/polars-ops/src/series/ops/arg_min_max.rs index 2fc6ed2d78f3..dc728df69702 100644 --- a/crates/polars-ops/src/series/ops/arg_min_max.rs +++ b/crates/polars-ops/src/series/ops/arg_min_max.rs @@ -4,6 +4,7 @@ use polars_core::chunked_array::ops::float_sorted_arg_max::{ float_arg_max_sorted_ascending, float_arg_max_sorted_descending, }; use polars_core::series::IsSorted; +use polars_core::with_match_categorical_physical_type; use super::*; @@ -42,25 +43,25 @@ macro_rules! with_match_physical_numeric_polars_type {( impl ArgAgg for Series { fn arg_min(&self) -> Option { use DataType::*; - let s = self.to_physical_repr(); + let phys_s = self.to_physical_repr(); match self.dtype() { #[cfg(feature = "dtype-categorical")] - Categorical(_, _) => { - let ca = self.categorical().unwrap(); - if ca.null_count() == ca.len() { - return None; - } - if ca.uses_lexical_ordering() { + Categorical(cats, _) => { + with_match_categorical_physical_type!(cats.physical(), |$C| { + let ca = self.cat::<$C>().unwrap(); + if ca.null_count() == ca.len() { + return None; + } ca.iter_str() .enumerate() .flat_map(|(idx, val)| val.map(|val| (idx, val))) .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) .map(|tpl| tpl.0) - } else { - let ca = s.u32().unwrap(); - arg_min_numeric_dispatch(ca) - } + }) }, + #[cfg(feature = "dtype-categorical")] + Enum(_, _) => phys_s.arg_min(), + Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_min(), String => { let ca = self.str().unwrap(); arg_min_str(ca) @@ -69,17 +70,9 @@ impl ArgAgg for Series { let ca = self.bool().unwrap(); arg_min_bool(ca) }, - Date => { - let ca = s.i32().unwrap(); - arg_min_numeric_dispatch(ca) - }, - Datetime(_, _) | Duration(_) | Time => { - let ca = s.i64().unwrap(); - arg_min_numeric_dispatch(ca) - }, dt if dt.is_primitive_numeric() => { - with_match_physical_numeric_polars_type!(s.dtype(), |$T| { - let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref(); arg_min_numeric_dispatch(ca) }) }, @@ -89,24 +82,25 @@ impl ArgAgg for Series { fn arg_max(&self) -> Option { use DataType::*; - let s = self.to_physical_repr(); + let phys_s = self.to_physical_repr(); match self.dtype() { #[cfg(feature = "dtype-categorical")] - Categorical(_, _) => { - let ca = self.categorical().unwrap(); - if ca.null_count() == ca.len() { - return None; - } - if ca.uses_lexical_ordering() { + Categorical(cats, _) => { + with_match_categorical_physical_type!(cats.physical(), |$C| { + let ca = self.cat::<$C>().unwrap(); + if ca.null_count() == ca.len() { + return None; + } ca.iter_str() .enumerate() + .flat_map(|(idx, val)| val.map(|val| (idx, val))) .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc }) .map(|tpl| tpl.0) - } else { - let ca_phys = s.u32().unwrap(); - arg_max_numeric_dispatch(ca_phys) - } + }) }, + #[cfg(feature = "dtype-categorical")] + Enum(_, _) => phys_s.arg_max(), + Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_max(), String => { let ca = self.str().unwrap(); arg_max_str(ca) @@ -115,17 +109,9 @@ impl ArgAgg for Series { let ca = self.bool().unwrap(); arg_max_bool(ca) }, - Date => { - let ca = s.i32().unwrap(); - arg_max_numeric_dispatch(ca) - }, - Datetime(_, _) | Duration(_) | Time => { - let ca = s.i64().unwrap(); - arg_max_numeric_dispatch(ca) - }, dt if dt.is_primitive_numeric() => { - with_match_physical_numeric_polars_type!(s.dtype(), |$T| { - let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref(); arg_max_numeric_dispatch(ca) }) }, diff --git a/crates/polars-ops/src/series/ops/business.rs b/crates/polars-ops/src/series/ops/business.rs index 448f061e4803..c1bb2590ce71 100644 --- a/crates/polars-ops/src/series/ops/business.rs +++ b/crates/polars-ops/src/series/ops/business.rs @@ -46,8 +46,8 @@ pub fn business_day_count( let out = match (start_dates.len(), end_dates.len()) { (_, 1) => { - if let Some(end_date) = end_dates.get(0) { - start_dates.apply_values(|start_date| { + if let Some(end_date) = end_dates.physical().get(0) { + start_dates.physical().apply_values(|start_date| { business_day_count_impl( start_date, end_date, @@ -61,8 +61,8 @@ pub fn business_day_count( } }, (1, _) => { - if let Some(start_date) = start_dates.get(0) { - end_dates.apply_values(|end_date| { + if let Some(start_date) = start_dates.physical().get(0) { + end_dates.physical().apply_values(|end_date| { business_day_count_impl( start_date, end_date, @@ -82,15 +82,19 @@ pub fn business_day_count( start_dates.len(), end_dates.len() ); - binary_elementwise_values(start_dates, end_dates, |start_date, end_date| { - business_day_count_impl( - start_date, - end_date, - &week_mask, - n_business_days_in_week_mask, - &holidays, - ) - }) + binary_elementwise_values( + start_dates.physical(), + end_dates.physical(), + |start_date, end_date| { + business_day_count_impl( + start_date, + end_date, + &week_mask, + n_business_days_in_week_mask, + &holidays, + ) + }, + ) }, }; Ok(out.into_series()) @@ -217,24 +221,26 @@ pub fn add_business_days( let out: Int32Chunked = match (start_dates.len(), n.len()) { (_, 1) => { if let Some(n) = n.get(0) { - start_dates.try_apply_nonnull_values_generic(|start_date| { - let (start_date, day_of_week) = - roll_start_date(start_date, roll, &week_mask, &holidays)?; - Ok::(add_business_days_impl( - start_date, - day_of_week, - n, - &week_mask, - n_business_days_in_week_mask, - &holidays, - )) - })? + start_dates + .physical() + .try_apply_nonnull_values_generic(|start_date| { + let (start_date, day_of_week) = + roll_start_date(start_date, roll, &week_mask, &holidays)?; + Ok::(add_business_days_impl( + start_date, + day_of_week, + n, + &week_mask, + n_business_days_in_week_mask, + &holidays, + )) + })? } else { Int32Chunked::full_null(start_dates.name().clone(), start_dates.len()) } }, (1, _) => { - if let Some(start_date) = start_dates.get(0) { + if let Some(start_date) = start_dates.physical().get(0) { let (start_date, day_of_week) = roll_start_date(start_date, roll, &week_mask, &holidays)?; n.apply_values(|n| { @@ -258,7 +264,7 @@ pub fn add_business_days( start_dates.len(), n.len() ); - try_binary_elementwise(start_dates, n, |opt_start_date, opt_n| { + try_binary_elementwise(start_dates.physical(), n, |opt_start_date, opt_n| { match (opt_start_date, opt_n) { (Some(start_date), Some(n)) => { let (start_date, day_of_week) = @@ -368,11 +374,17 @@ pub fn is_business_day( // Sort now so we can use `binary_search` in the hot for-loop. let holidays = normalise_holidays(holidays, &week_mask); let dates = dates.date()?; - let out: BooleanChunked = dates.apply_nonnull_values_generic(DataType::Boolean, |date| { - let day_of_week = get_day_of_week(date); - // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 - unsafe { (*week_mask.get_unchecked(day_of_week)) && holidays.binary_search(&date).is_err() } - }); + let out: BooleanChunked = + dates + .physical() + .apply_nonnull_values_generic(DataType::Boolean, |date| { + let day_of_week = get_day_of_week(date); + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + unsafe { + (*week_mask.get_unchecked(day_of_week)) + && holidays.binary_search(&date).is_err() + } + }); Ok(out.into_series()) } diff --git a/crates/polars-ops/src/series/ops/cum_agg.rs b/crates/polars-ops/src/series/ops/cum_agg.rs index 761a42872206..1749a040d94b 100644 --- a/crates/polars-ops/src/series/ops/cum_agg.rs +++ b/crates/polars-ops/src/series/ops/cum_agg.rs @@ -230,7 +230,7 @@ pub fn cum_sum(s: &Series, reverse: bool) -> PolarsResult { Float64 => cum_sum_numeric(s.f64()?, reverse).into_series(), #[cfg(feature = "dtype-decimal")] Decimal(precision, scale) => { - let ca = s.decimal().unwrap().as_ref(); + let ca = s.decimal().unwrap().physical(); cum_sum_numeric(ca, reverse) .into_decimal_unchecked(*precision, scale.unwrap()) .into_series() @@ -252,7 +252,7 @@ pub fn cum_min(s: &Series, reverse: bool) -> PolarsResult { DataType::Boolean => Ok(cum_min_bool(s.bool()?, reverse).into_series()), #[cfg(feature = "dtype-decimal")] DataType::Decimal(precision, scale) => { - let ca = s.decimal().unwrap().as_ref(); + let ca = s.decimal().unwrap().physical(); let out = cum_min_numeric(ca, reverse) .into_decimal_unchecked(*precision, scale.unwrap()) .into_series(); @@ -280,7 +280,7 @@ pub fn cum_max(s: &Series, reverse: bool) -> PolarsResult { DataType::Boolean => Ok(cum_max_bool(s.bool()?, reverse).into_series()), #[cfg(feature = "dtype-decimal")] DataType::Decimal(precision, scale) => { - let ca = s.decimal().unwrap().as_ref(); + let ca = s.decimal().unwrap().physical(); let out = cum_max_numeric(ca, reverse) .into_decimal_unchecked(*precision, scale.unwrap()) .into_series(); diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index 7016fd08c1f1..354470e7ef80 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -1,4 +1,5 @@ use polars_compute::rolling::QuantileMethod; +use polars_core::chunked_array::builder::CategoricalChunkedBuilder; use polars_core::prelude::*; use polars_utils::format_pl_smallstr; @@ -11,12 +12,6 @@ fn map_cats( ) -> PolarsResult { let out_name = PlSmallStr::from_static("category"); - // Create new categorical and pre-register labels for consistent categorical indexes. - let mut bld = CategoricalChunkedBuilder::new(out_name.clone(), s.len(), Default::default()); - for label in labels { - bld.register_value(label); - } - let s2 = s.cast(&DataType::Float64)?; // It would be nice to parallelize this let s_iter = s2.f64()?.into_iter(); @@ -27,25 +22,23 @@ fn map_cats( PartialOrd::gt }; - // Ensure fast unique is only set if all labels were seen. - let mut label_has_value = vec![false; 1 + sorted_breaks.len()]; - if include_breaks { // This is to replicate the behavior of the old buggy version that only worked on series and // returned a dataframe. That included a column of the right endpoint of the interval. So we // return a struct series instead which can be turned into a dataframe later. let right_ends = [sorted_breaks, &[f64::INFINITY]].concat(); + let mut bld = CategoricalChunkedBuilder::::new( + out_name.clone(), + DataType::from_categories(Categories::global()), + ); let mut brk_vals = PrimitiveChunkedBuilder::::new( PlSmallStr::from_static("breakpoint"), s.len(), ); s_iter .map(|opt| { - opt.filter(|x| !x.is_nan()).map(|x| { - let pt = sorted_breaks.partition_point(|v| op(&x, v)); - unsafe { *label_has_value.get_unchecked_mut(pt) = true }; - pt - }) + opt.filter(|x| !x.is_nan()) + .map(|x| sorted_breaks.partition_point(|v| op(&x, v))) }) .for_each(|idx| match idx { None => { @@ -53,28 +46,24 @@ fn map_cats( brk_vals.append_null(); }, Some(idx) => unsafe { - bld.append_value(labels.get_unchecked(idx)); + bld.append_str(labels.get_unchecked(idx)).unwrap(); brk_vals.append_value(*right_ends.get_unchecked(idx)); }, }); - let outvals = [brk_vals.finish().into_series(), unsafe { - bld.finish() - ._with_fast_unique(label_has_value.iter().all(bool::clone)) - .into_series() - }]; + let outvals = [brk_vals.finish().into_series(), bld.finish().into_series()]; Ok(StructChunked::from_series(out_name, outvals[0].len(), outvals.iter())?.into_series()) } else { - Ok(unsafe { - bld.drain_iter_and_finish(s_iter.map(|opt| { + Ok(CategoricalChunked::::from_str_iter( + out_name, + DataType::from_categories(Categories::global()), + s_iter.map(|opt| { opt.filter(|x| !x.is_nan()).map(|x| { let pt = sorted_breaks.partition_point(|v| op(&x, v)); - *label_has_value.get_unchecked_mut(pt) = true; - labels.get_unchecked(pt).as_str() + unsafe { labels.get_unchecked(pt).as_str() } }) - })) - ._with_fast_unique(label_has_value.iter().all(bool::clone)) - } + }), + )? .into_series()) } } @@ -137,7 +126,7 @@ pub fn qcut( return Ok(Series::full_null( s.name().clone(), s.len(), - &DataType::Categorical(None, Default::default()), + &DataType::from_categories(Categories::global()), )); } @@ -182,13 +171,11 @@ mod test { let include_breaks = false; let out = map_cats(&s, labels, breaks, left_closed, include_breaks).unwrap(); - let out = out.categorical().unwrap(); - assert!(out._can_fast_unique()); + out.cat32().unwrap(); let include_breaks = true; let out = map_cats(&s, labels, breaks, left_closed, include_breaks).unwrap(); let out = out.struct_().unwrap().fields_as_series()[1].clone(); - let out = out.categorical().unwrap(); - assert!(out._can_fast_unique()); + out.cat32().unwrap(); } } diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs index 999dea62c06c..722e7e6ea0b6 100644 --- a/crates/polars-ops/src/series/ops/is_in.rs +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -4,7 +4,7 @@ use arrow::array::BooleanArray; use arrow::bitmap::BitmapBuilder; use polars_core::prelude::arity::{unary_elementwise, unary_elementwise_values}; use polars_core::prelude::*; -use polars_core::with_match_physical_numeric_polars_type; +use polars_core::{with_match_categorical_physical_type, with_match_physical_numeric_polars_type}; use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use self::row_encode::_get_rows_encoded_ca_unordered; @@ -407,144 +407,46 @@ fn is_in_boolean( } #[cfg(feature = "dtype-categorical")] -fn is_in_cat_and_enum( - ca_in: &CategoricalChunked, +fn is_in_cat_and_enum( + ca_in: &CategoricalChunked, other: &Series, nulls_equal: bool, -) -> PolarsResult { - use std::borrow::Cow; - - use arrow::array::{Array, FixedSizeListArray, IntoBoxedArray, ListArray}; - - let mut needs_remap = false; +) -> PolarsResult +where + T::Native: ToTotalOrd, +{ let to_categories = match (ca_in.dtype(), other.dtype().inner_dtype().unwrap()) { - (DataType::Enum(revmap, ordering), DataType::String) => { - let categories = revmap.as_deref().unwrap().get_categories(); - (&|s: Series| { - let ca = s.str()?; - let ca = CategoricalChunked::from_string_to_enum(ca, categories, *ordering)?; - let ca = ca.into_physical(); - Ok(ca.into_series()) - }) as _ - }, - (DataType::Categorical(revmap, ordering), DataType::String) => { + (DataType::Enum(_, mapping) | DataType::Categorical(_, mapping), DataType::String) => { (&|s: Series| { - let categories = revmap.as_deref().unwrap().get_categories(); let ca = s.str()?; - let ca = - if ca_in.get_rev_map().is_local() { - assert!(categories.len() < u32::MAX as usize); - let cats = PlIndexSet::from_iter(categories.values_iter()); - UInt32Chunked::from_iter(ca.iter().map(|v| { - v.map(|v| cats.get_index_of(v).map_or(u32::MAX, |n| n as u32)) - })) - } else { - let cat = ca.cast(&DataType::Categorical(None, *ordering))?; - cat.categorical()?.physical().clone() - }; + let ca: ChunkedArray = ca + .iter() + .flat_map(|opt_s| { + if let Some(s) = opt_s { + Some(mapping.get_cat(s).map(T::Native::from_cat)) + } else { + Some(None) + } + }) + .collect_ca(PlSmallStr::EMPTY); Ok(ca.into_series()) }) as _ }, - (DataType::Categorical(revmap, _), DataType::Categorical(other_revmap, _)) => { - let (Some(revmap), Some(other_revmap)) = (revmap, other_revmap) else { - polars_bail!(ComputeError: "expected revmap to be set at this point"); - }; - needs_remap = !revmap.same_src(other_revmap); - (&|s: Series| { - let ca = s.categorical()?; - let ca = ca.physical().clone(); - Ok(ca.into_series()) - }) as _ + (DataType::Categorical(lcats, _), DataType::Categorical(rcats, _)) => { + ensure_same_categories(lcats, rcats)?; + (&|s: Series| Ok(s.cat::()?.physical().clone().into_series())) as _ }, - (DataType::Enum(revmap, _), DataType::Enum(other_revmap, _)) => { - let (Some(revmap), Some(other_revmap)) = (revmap, other_revmap) else { - polars_bail!(ComputeError: "expected revmap to be set at this point"); - }; - polars_ensure!( - revmap.same_src(other_revmap), - opq = is_in, - ca_in.dtype(), - other.dtype() - ); - (&|s: Series| { - let ca = s.categorical()?; - let ca = ca.physical().clone(); - Ok(ca.into_series()) - }) as _ + (DataType::Enum(lfcats, _), DataType::Enum(rfcats, _)) => { + ensure_same_frozen_categories(lfcats, rfcats)?; + (&|s: Series| Ok(s.cat::()?.physical().clone().into_series())) as _ }, _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), }; - let mut ca_in = Cow::Borrowed(ca_in); let other = match other.dtype() { - DataType::List(_) => { - let mut other = Cow::Borrowed(other.list()?); - if needs_remap { - let other_rechunked = other.rechunk(); - let other_arr = other_rechunked.downcast_as_array(); - - let other_inner = other.get_inner(); - let other_offsets = other_arr.offsets().clone(); - let other_inner = other_inner.categorical()?; - let (ca_in_remapped, other_inner) = - make_rhs_categoricals_compatible(&ca_in, other_inner)?; - - let other_inner_phys = other_inner.physical().rechunk(); - let other_inner_phys = other_inner_phys.downcast_as_array(); - - let other_phys = ListArray::try_new( - other_arr.dtype().clone(), - other_offsets, - other_inner_phys.clone().into_boxed(), - other_arr.validity().cloned(), - )?; - - other = Cow::Owned(unsafe { - ListChunked::from_chunks_and_dtype( - other.name().clone(), - vec![other_phys.into_boxed()], - DataType::List(Box::new(other_inner.dtype().clone())), - ) - }); - ca_in = Cow::Owned(ca_in_remapped); - } - let other = other.apply_to_inner(to_categories)?; - other.into_series() - }, + DataType::List(_) => other.list()?.apply_to_inner(to_categories)?.into_series(), #[cfg(feature = "dtype-array")] - DataType::Array(_, _) => { - let mut other = Cow::Borrowed(other.array()?); - if needs_remap { - let other_rechunked = other.rechunk(); - let other_arr = other_rechunked.downcast_as_array(); - - let other_inner = other.get_inner(); - let other_inner = other_inner.categorical()?; - let (ca_in_remapped, other_inner) = - make_rhs_categoricals_compatible(&ca_in, other_inner)?; - - let other_inner_phys = other_inner.physical().rechunk(); - let other_inner_phys = other_inner_phys.downcast_as_array(); - - let other_phys = FixedSizeListArray::try_new( - other_arr.dtype().clone(), - other.len(), - other_inner_phys.clone().into_boxed(), - other_arr.validity().cloned(), - )?; - - other = Cow::Owned(unsafe { - ArrayChunked::from_chunks_and_dtype( - other.name().clone(), - vec![other_phys.into_boxed()], - DataType::Array(Box::new(other_inner.dtype().clone()), other.width()), - ) - }); - ca_in = Cow::Owned(ca_in_remapped); - } - let other = other.apply_to_inner(to_categories)?; - other.into_series() - }, + DataType::Array(_, _) => other.array()?.apply_to_inner(to_categories)?.into_series(), _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), }; @@ -726,9 +628,10 @@ pub fn is_in(s: &Series, other: &Series, nulls_equal: bool) -> PolarsResult { - let ca = s.categorical().unwrap(); - is_in_cat_and_enum(ca, other, nulls_equal) + dt @ DataType::Categorical(_, _) | dt @ DataType::Enum(_, _) => { + with_match_categorical_physical_type!(dt.cat_physical().unwrap(), |$C| { + is_in_cat_and_enum(s.cat::<$C>().unwrap(), other, nulls_equal) + }) }, DataType::String => { let ca = s.str().unwrap(); diff --git a/crates/polars-ops/src/series/ops/negate.rs b/crates/polars-ops/src/series/ops/negate.rs index 7af246d2810a..d47698e09573 100644 --- a/crates/polars-ops/src/series/ops/negate.rs +++ b/crates/polars-ops/src/series/ops/negate.rs @@ -17,7 +17,7 @@ pub fn negate(s: &Series) -> PolarsResult { let precision = ca.precision(); let scale = ca.scale(); - let out = ca.as_ref().wrapping_neg(); + let out = ca.physical().wrapping_neg(); out.into_decimal_unchecked(precision, scale).into_series() }, #[cfg(feature = "dtype-duration")] diff --git a/crates/polars-ops/src/series/ops/replace.rs b/crates/polars-ops/src/series/ops/replace.rs index 35abca02d19e..00fba55faac3 100644 --- a/crates/polars-ops/src/series/ops/replace.rs +++ b/crates/polars-ops/src/series/ops/replace.rs @@ -48,7 +48,7 @@ pub fn replace(s: &Series, old: &ListChunked, new: &ListChunked) -> PolarsResult validate_old(&old)?; let dtype = s.dtype(); - let old = cast_old_to_series_dtype(&old, dtype)?; + let old = old.strict_cast(dtype)?; let new = new.strict_cast(dtype)?; if new.len() == 1 { @@ -107,7 +107,7 @@ pub fn replace_or_default( return Ok(out); } - let old = cast_old_to_series_dtype(&old, s.dtype())?; + let old = old.strict_cast(s.dtype())?; let new = new.cast(&return_dtype)?; if new.len() == 1 { @@ -148,7 +148,7 @@ pub fn replace_strict( } validate_old(&old)?; - let old = cast_old_to_series_dtype(&old, s.dtype())?; + let old = old.strict_cast(s.dtype())?; let new = match return_dtype { Some(dtype) => new.strict_cast(&dtype)?, None => new.clone(), @@ -170,18 +170,6 @@ fn validate_old(old: &Series) -> PolarsResult<()> { Ok(()) } -/// Cast `old` input while enabling String to Categorical casts. -fn cast_old_to_series_dtype(old: &Series, dtype: &DataType) -> PolarsResult { - match (old.dtype(), dtype) { - #[cfg(feature = "dtype-categorical")] - (DataType::String, DataType::Categorical(_, ord)) => { - let empty_categorical_dtype = DataType::Categorical(None, *ord); - old.strict_cast(&empty_categorical_dtype) - }, - _ => old.strict_cast(dtype), - } -} - // Fast path for replacing by a single value fn replace_by_single( s: &Series, diff --git a/crates/polars-ops/src/series/ops/round.rs b/crates/polars-ops/src/series/ops/round.rs index d0493eb7d0ef..23521433fa72 100644 --- a/crates/polars-ops/src/series/ops/round.rs +++ b/crates/polars-ops/src/series/ops/round.rs @@ -177,7 +177,7 @@ pub trait RoundSeries: SeriesSealed { let threshold = multiplier / 2; let res = match mode { - RoundMode::HalfToEven => ca.apply_values(|v| { + RoundMode::HalfToEven => ca.physical().apply_values(|v| { let rem_big = v % (2 * multiplier); let is_v_floor_even = rem_big.abs() < multiplier; let rem = if is_v_floor_even { @@ -196,7 +196,7 @@ pub trait RoundSeries: SeriesSealed { }; v - rem + round_offset }), - RoundMode::HalfAwayFromZero => ca.apply_values(|v| { + RoundMode::HalfAwayFromZero => ca.physical().apply_values(|v| { let rem = v % multiplier; let round_offset = if rem.abs() >= threshold { if v < 0 { -multiplier } else { multiplier } @@ -225,6 +225,7 @@ pub trait RoundSeries: SeriesSealed { let scale = ca.scale() as u32; let s = ca + .physical() .apply_values(|v| { if v == 0 { return 0; @@ -307,6 +308,7 @@ pub trait RoundSeries: SeriesSealed { let multiplier = 10i128.pow(decimal_delta); let ca = ca + .physical() .apply_values(|v| { let rem = v % multiplier; let round_offset = if v < 0 { multiplier + rem } else { rem }; @@ -346,6 +348,7 @@ pub trait RoundSeries: SeriesSealed { let multiplier = 10i128.pow(decimal_delta); let ca = ca + .physical() .apply_values(|v| { let rem = v % multiplier; let round_offset = if v < 0 { -rem } else { multiplier - rem }; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs b/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs index 9c3cf6b77c18..20501fb028b2 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs @@ -1,4 +1,6 @@ -use arrow::array::{DictionaryArray, MutableBinaryViewArray, PrimitiveArray}; +use std::marker::PhantomData; + +use arrow::array::{DictionaryArray, DictionaryKey, MutableBinaryViewArray, PrimitiveArray}; use arrow::bitmap::{Bitmap, BitmapBuilder}; use arrow::datatypes::ArrowDataType; use arrow::types::{AlignedBytes, NativeType}; @@ -10,14 +12,17 @@ use crate::parquet::encoding::Encoding; use crate::parquet::encoding::hybrid_rle::HybridRleDecoder; use crate::parquet::error::ParquetResult; use crate::parquet::page::{DataPage, DictPage}; +use crate::read::deserialize::dictionary_encoded::IndexMapping; -impl<'a> StateTranslation<'a, CategoricalDecoder> for HybridRleDecoder<'a> { +impl<'a, T: DictionaryKey + IndexMapping> + StateTranslation<'a, CategoricalDecoder> for HybridRleDecoder<'a> +{ type PlainDecoder = HybridRleDecoder<'a>; fn new( - _decoder: &CategoricalDecoder, + _decoder: &CategoricalDecoder, page: &'a DataPage, - _dict: Option<&'a ::Dict>, + _dict: Option<&'a as Decoder>::Dict>, page_validity: Option<&Bitmap>, ) -> ParquetResult { if !matches!( @@ -39,29 +44,33 @@ impl<'a> StateTranslation<'a, CategoricalDecoder> for HybridRleDecoder<'a> { /// These are marked as special in the Arrow Field Metadata and they have the properly that for a /// given row group all the values are in the dictionary page and all data pages are dictionary /// encoded. This makes the job of decoding them extremely simple and fast. -pub struct CategoricalDecoder { +pub struct CategoricalDecoder { dict_size: usize, decoder: BinViewDecoder, + key_type: PhantomData, } -impl CategoricalDecoder { +impl CategoricalDecoder { pub fn new() -> Self { Self { dict_size: usize::MAX, decoder: BinViewDecoder::new_string(), + key_type: PhantomData, } } } -impl utils::Decoder for CategoricalDecoder { +impl> utils::Decoder + for CategoricalDecoder +{ type Translation<'a> = HybridRleDecoder<'a>; type Dict = ::Dict; - type DecodedState = (Vec, BitmapBuilder); - type Output = DictionaryArray; + type DecodedState = (Vec, BitmapBuilder); + type Output = DictionaryArray; fn with_capacity(&self, capacity: usize) -> Self::DecodedState { ( - Vec::::with_capacity(capacity), + Vec::::with_capacity(capacity), BitmapBuilder::with_capacity(capacity), ) } @@ -88,7 +97,7 @@ impl utils::Decoder for CategoricalDecoder { ) -> ParquetResult<()> { let additional = additional .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap(); decoded.0.extend(additional.keys().values().iter().copied()); match additional.validity() { @@ -105,10 +114,10 @@ impl utils::Decoder for CategoricalDecoder { dtype: ArrowDataType, dict: Option, (values, validity): Self::DecodedState, - ) -> ParquetResult> { + ) -> ParquetResult> { let validity = freeze_validity(validity); let dict = dict.unwrap(); - let keys = PrimitiveArray::new(ArrowDataType::UInt32, values.into(), validity); + let keys = PrimitiveArray::new(T::PRIMITIVE.into(), values.into(), validity); let mut view_dict = MutableBinaryViewArray::with_capacity(dict.len()); let (views, buffers, _, _, _) = dict.into_inner(); @@ -136,13 +145,13 @@ impl utils::Decoder for CategoricalDecoder { ) -> ParquetResult<()> { super::dictionary_encoded::decode_dict_dispatch( state.translation, - self.dict_size, + T::try_from(self.dict_size).ok().unwrap(), state.dict_mask, state.is_optional, state.page_validity.as_ref(), filter, &mut decoded.1, - <::AlignedBytes as AlignedBytes>::cast_vec_ref_mut(&mut decoded.0), + <::AlignedBytes as AlignedBytes>::cast_vec_ref_mut(&mut decoded.0), pred_true_mask, ) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs index 6e18b36292c3..b4afaa426c29 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs @@ -1,6 +1,8 @@ use arrow::bitmap::bitmask::BitMask; use arrow::bitmap::{Bitmap, BitmapBuilder}; -use arrow::types::{AlignedBytes, Bytes4Alignment4, NativeType}; +use arrow::types::{ + AlignedBytes, Bytes1Alignment1, Bytes2Alignment2, Bytes4Alignment4, NativeType, +}; use polars_compute::filter::filter_boolean_kernel; use super::ParquetError; @@ -16,7 +18,7 @@ mod required_masked_dense; /// A mapping from a `u32` to a value. This is used in to map dictionary encoding to a value. pub trait IndexMapping { - type Output: Copy; + type Output: Copy + AlignedBytes; fn is_empty(&self) -> bool { self.len() == 0 @@ -29,13 +31,14 @@ pub trait IndexMapping { } // Base mapping used for everything except the CategoricalDecoder. -impl IndexMapping for &[T] { +impl IndexMapping for &[T] { type Output = T; #[inline(always)] fn len(&self) -> usize { <[T]>::len(self) } + #[inline(always)] unsafe fn get_unchecked(&self, idx: u32) -> Self::Output { *unsafe { <[T]>::get_unchecked(self, idx as usize) } @@ -43,13 +46,42 @@ impl IndexMapping for &[T] { } // Unit mapping used in the CategoricalDecoder. -impl IndexMapping for usize { +impl IndexMapping for u8 { + type Output = Bytes1Alignment1; + + #[inline(always)] + fn len(&self) -> usize { + *self as usize + } + + #[inline(always)] + unsafe fn get_unchecked(&self, idx: u32) -> Self::Output { + bytemuck::must_cast(idx as u8) + } +} + +impl IndexMapping for u16 { + type Output = Bytes2Alignment2; + + #[inline(always)] + fn len(&self) -> usize { + *self as usize + } + + #[inline(always)] + unsafe fn get_unchecked(&self, idx: u32) -> Self::Output { + bytemuck::must_cast(idx as u16) + } +} + +impl IndexMapping for u32 { type Output = Bytes4Alignment4; #[inline(always)] fn len(&self) -> usize { - *self + *self as usize } + #[inline(always)] unsafe fn get_unchecked(&self, idx: u32) -> Self::Output { bytemuck::must_cast(idx) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs index e7415a0eeb9e..5ea4e6e12c96 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs @@ -1,5 +1,8 @@ use arrow::array::StructArray; -use arrow::datatypes::{DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, IntegerType}; +use arrow::datatypes::{ + DTYPE_CATEGORICAL_LEGACY, DTYPE_CATEGORICAL_NEW, DTYPE_ENUM_VALUES_LEGACY, + DTYPE_ENUM_VALUES_NEW, IntegerType, +}; use polars_compute::cast::CastOptionsImpl; use self::categorical::CategoricalDecoder; @@ -142,7 +145,10 @@ pub fn columns_to_iter_recursive( init.push(InitNested::Primitive(field.is_nullable)); if field.metadata.as_ref().is_none_or(|md| { - !md.contains_key(DTYPE_ENUM_VALUES) && !md.contains_key(DTYPE_CATEGORICAL) + !md.contains_key(DTYPE_ENUM_VALUES_LEGACY) + && !md.contains_key(DTYPE_ENUM_VALUES_NEW) + && !md.contains_key(DTYPE_CATEGORICAL_NEW) + && !md.contains_key(DTYPE_CATEGORICAL_LEGACY) }) { let (nested, arr, ptm) = PageDecoder::new( &field.name, @@ -162,16 +168,33 @@ pub fn columns_to_iter_recursive( Ok((nested, arr, ptm)) } else { - assert!(matches!(key_type, IntegerType::UInt32)); - - let (nested, arr, ptm) = PageDecoder::new( - &field.name, - columns.pop().unwrap(), - field.dtype().clone(), - CategoricalDecoder::new(), - Some(init), - )? - .collect_boxed(filter)?; + let (nested, arr, ptm) = match key_type { + IntegerType::UInt8 => PageDecoder::new( + &field.name, + columns.pop().unwrap(), + field.dtype().clone(), + CategoricalDecoder::::new(), + Some(init), + )? + .collect_boxed(filter)?, + IntegerType::UInt16 => PageDecoder::new( + &field.name, + columns.pop().unwrap(), + field.dtype().clone(), + CategoricalDecoder::::new(), + Some(init), + )? + .collect_boxed(filter)?, + IntegerType::UInt32 => PageDecoder::new( + &field.name, + columns.pop().unwrap(), + field.dtype().clone(), + CategoricalDecoder::::new(), + Some(init), + )? + .collect_boxed(filter)?, + _ => unimplemented!(), + }; Ok((nested.unwrap(), arr, ptm)) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs index 5d07ef19772c..cf418a8016e6 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs @@ -1,7 +1,8 @@ use arrow::array::{Array, FixedSizeBinaryArray, PrimitiveArray}; use arrow::bitmap::Bitmap; use arrow::datatypes::{ - ArrowDataType, DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, Field, IntegerType, IntervalUnit, TimeUnit, + ArrowDataType, DTYPE_CATEGORICAL_LEGACY, DTYPE_CATEGORICAL_NEW, DTYPE_ENUM_VALUES_LEGACY, + DTYPE_ENUM_VALUES_NEW, Field, IntegerType, IntervalUnit, TimeUnit, }; use arrow::types::{NativeType, days_ms, i256}; use ethnum::I256; @@ -418,17 +419,41 @@ pub fn page_iter_to_array( assert_eq!(value_type.as_ref(), &ArrowDataType::Utf8View); if field.metadata.is_some_and(|md| { - md.contains_key(DTYPE_ENUM_VALUES) || md.contains_key(DTYPE_CATEGORICAL) - }) && matches!(key_type, IntegerType::UInt32) - { - PageDecoder::new( - &field.name, - pages, - dtype, - CategoricalDecoder::new(), - init_nested, - )? - .collect_boxed(filter)? + md.contains_key(DTYPE_ENUM_VALUES_LEGACY) + || md.contains_key(DTYPE_ENUM_VALUES_NEW) + || md.contains_key(DTYPE_CATEGORICAL_NEW) + || md.contains_key(DTYPE_CATEGORICAL_LEGACY) + }) && matches!( + key_type, + IntegerType::UInt8 | IntegerType::UInt16 | IntegerType::UInt32 + ) { + match key_type { + IntegerType::UInt8 => PageDecoder::new( + &field.name, + pages, + dtype, + CategoricalDecoder::::new(), + init_nested, + )? + .collect_boxed(filter)?, + IntegerType::UInt16 => PageDecoder::new( + &field.name, + pages, + dtype, + CategoricalDecoder::::new(), + init_nested, + )? + .collect_boxed(filter)?, + IntegerType::UInt32 => PageDecoder::new( + &field.name, + pages, + dtype, + CategoricalDecoder::::new(), + init_nested, + )? + .collect_boxed(filter)?, + _ => unreachable!(), + } } else { let (nested, array, ptm) = PageDecoder::new( &field.name, diff --git a/crates/polars-parquet/src/arrow/read/schema/metadata.rs b/crates/polars-parquet/src/arrow/read/schema/metadata.rs index 5234f88adea5..3e2d997c6aed 100644 --- a/crates/polars-parquet/src/arrow/read/schema/metadata.rs +++ b/crates/polars-parquet/src/arrow/read/schema/metadata.rs @@ -1,5 +1,6 @@ use arrow::datatypes::{ - ArrowDataType, ArrowSchema, DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, Field, IntegerType, Metadata, + ArrowDataType, ArrowSchema, DTYPE_CATEGORICAL_LEGACY, DTYPE_CATEGORICAL_NEW, + DTYPE_ENUM_VALUES_LEGACY, DTYPE_ENUM_VALUES_NEW, Field, IntegerType, Metadata, }; use arrow::io::ipc::read::deserialize_schema; use base64::Engine as _; @@ -30,10 +31,16 @@ fn convert_field(field: &mut Field) { // generic dictionary type. field.dtype = match std::mem::take(&mut field.dtype) { ArrowDataType::Dictionary(key_type, value_type, sorted) => { - let is_pl_enum_or_categorical = field.metadata.as_ref().is_some_and(|md| { - md.contains_key(DTYPE_ENUM_VALUES) || md.contains_key(DTYPE_CATEGORICAL) - }) && matches!(key_type, IntegerType::UInt32) - && matches!(value_type.as_ref(), ArrowDataType::Utf8View); + let is_pl_enum_or_categorical = + field.metadata.as_ref().is_some_and(|md| { + md.contains_key(DTYPE_ENUM_VALUES_LEGACY) + || md.contains_key(DTYPE_ENUM_VALUES_NEW) + || md.contains_key(DTYPE_CATEGORICAL_NEW) + || md.contains_key(DTYPE_CATEGORICAL_LEGACY) + }) && matches!( + key_type, + IntegerType::UInt8 | IntegerType::UInt16 | IntegerType::UInt32 + ) && matches!(value_type.as_ref(), ArrowDataType::Utf8View); let is_int_to_str = matches!( value_type.as_ref(), ArrowDataType::Utf8View | ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 diff --git a/crates/polars-plan/dsl-schema.sha256 b/crates/polars-plan/dsl-schema.sha256 index 19718ec2f8c1..f7d5ae66ec49 100644 --- a/crates/polars-plan/dsl-schema.sha256 +++ b/crates/polars-plan/dsl-schema.sha256 @@ -1 +1 @@ -ce306f74eb7b3006fa89fa887c2b0787f6775f0dcfd192e3b1d9f5dec90c2303 \ No newline at end of file +5122d21aa9a266943a47f43b6ab7c064555ec7f359eee4b0c78bc199839f94df \ No newline at end of file diff --git a/crates/polars-plan/src/dsl/plan.rs b/crates/polars-plan/src/dsl/plan.rs index 1b19e13285fe..bc564927dc50 100644 --- a/crates/polars-plan/src/dsl/plan.rs +++ b/crates/polars-plan/src/dsl/plan.rs @@ -48,7 +48,7 @@ use super::*; // - changing a name, type, or meaning of a field or an enum variant // - changing a default value of a field or a default enum variant // - restricting the range of allowed values a field can have -pub static DSL_VERSION: (u16, u16) = (16, 0); +pub static DSL_VERSION: (u16, u16) = (17, 0); static DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION"; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/crates/polars-plan/src/plans/aexpr/function_expr/cat.rs b/crates/polars-plan/src/plans/aexpr/function_expr/cat.rs index bf1000477965..c51389ec78b7 100644 --- a/crates/polars-plan/src/plans/aexpr/function_expr/cat.rs +++ b/crates/polars-plan/src/plans/aexpr/function_expr/cat.rs @@ -93,29 +93,19 @@ impl From for IRFunctionExpr { } fn get_categories(s: &Column) -> PolarsResult { - // categorical check - let ca = s.categorical()?; - let rev_map = ca.get_rev_map(); - let arr = rev_map.get_categories().clone().boxed(); - Series::try_from((ca.name().clone(), arr)).map(Column::from) + let mapping = s.dtype().cat_mapping()?; + let ca = unsafe { StringChunked::from_chunks(s.name().clone(), vec![mapping.to_arrow(true)]) }; + Ok(Column::from(ca.into_series())) } // Determine mapping between categories and underlying physical. For local, this is just 0..n. // For global, this is the global indexes. -fn _get_cat_phys_map(ca: &CategoricalChunked) -> (StringChunked, Series) { - let (categories, phys) = match &**ca.get_rev_map() { - RevMapping::Local(c, _) => (c, ca.physical().cast(&IDX_DTYPE).unwrap()), - RevMapping::Global(physical_map, c, _) => { - // Map physical to its local representation for use with take() later. - let phys = ca - .physical() - .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap())); - let out = phys.cast(&IDX_DTYPE).unwrap(); - (c, out) - }, - }; - let categories = StringChunked::with_chunk(ca.name().clone(), categories.clone()); - (categories, phys) +fn _get_cat_phys_map(col: &Column) -> (StringChunked, Series) { + let mapping = col.dtype().cat_mapping().unwrap(); + let cats = + unsafe { StringChunked::from_chunks(col.name().clone(), vec![mapping.to_arrow(true)]) }; + let phys = col.to_physical_repr().as_materialized_series().clone(); + (cats, phys) } /// Fast path: apply a string function to the categories of a categorical column and broadcast the @@ -123,27 +113,11 @@ fn _get_cat_phys_map(ca: &CategoricalChunked) -> (StringChunked, Series) { // fn apply_to_cats(ca: &CategoricalChunked, mut op: F) -> PolarsResult fn apply_to_cats(c: &Column, mut op: F) -> PolarsResult where - F: FnMut(&StringChunked) -> ChunkedArray, - T: PolarsPhysicalType, -{ - let ca = c.categorical()?; - let (categories, phys) = _get_cat_phys_map(ca); - let result = op(&categories); - // SAFETY: physical idx array is valid. - let out = unsafe { result.take_unchecked(phys.idx().unwrap()) }; - Ok(out.into_column()) -} - -/// Fast path: apply a binary function to the categories of a categorical column and broadcast the -/// result back to the array. -fn apply_to_cats_binary(c: &Column, mut op: F) -> PolarsResult -where - F: FnMut(&BinaryChunked) -> ChunkedArray, + F: FnMut(StringChunked) -> ChunkedArray, T: PolarsPhysicalType, { - let ca = c.categorical()?; - let (categories, phys) = _get_cat_phys_map(ca); - let result = op(&categories.as_binary()); + let (categories, phys) = _get_cat_phys_map(c); + let result = op(categories); // SAFETY: physical idx array is valid. let out = unsafe { result.take_unchecked(phys.idx().unwrap()) }; Ok(out.into_column()) @@ -161,19 +135,18 @@ fn len_chars(c: &Column) -> PolarsResult { #[cfg(feature = "strings")] fn starts_with(c: &Column, prefix: &str) -> PolarsResult { - apply_to_cats_binary(c, |s| s.starts_with(prefix.as_bytes())) + apply_to_cats(c, |s| s.as_binary().starts_with(prefix.as_bytes())) } #[cfg(feature = "strings")] fn ends_with(c: &Column, suffix: &str) -> PolarsResult { - apply_to_cats_binary(c, |s| s.ends_with(suffix.as_bytes())) + apply_to_cats(c, |s| s.as_binary().ends_with(suffix.as_bytes())) } #[cfg(feature = "strings")] fn slice(c: &Column, offset: i64, length: Option) -> PolarsResult { let length = length.unwrap_or(usize::MAX) as u64; - let ca = c.categorical()?; - let (categories, phys) = _get_cat_phys_map(ca); + let (categories, phys) = _get_cat_phys_map(c); let result = unsafe { categories.apply_views(|view, val| { diff --git a/crates/polars-plan/src/plans/aexpr/function_expr/fill_null.rs b/crates/polars-plan/src/plans/aexpr/function_expr/fill_null.rs index eaa65c2f753e..74a73622bc06 100644 --- a/crates/polars-plan/src/plans/aexpr/function_expr/fill_null.rs +++ b/crates/polars-plan/src/plans/aexpr/function_expr/fill_null.rs @@ -19,34 +19,12 @@ pub(super) fn fill_null(s: &[Column]) -> PolarsResult { series.zip_with_same_type(&mask, &fill_value) } - match series.dtype() { - #[cfg(feature = "dtype-categorical")] - // for Categoricals we first need to check if the category already exist - DataType::Categorical(Some(rev_map), _) => { - if rev_map.is_local() && fill_value.len() == 1 && fill_value.null_count() == 0 { - let fill_av = fill_value.get(0).unwrap(); - let fill_str = fill_av.get_str().unwrap(); - - if let Some(idx) = rev_map.find(fill_str) { - let cats = series.to_physical_repr(); - let mask = cats.is_not_null(); - let out = cats - .zip_with_same_type(&mask, &Column::new(PlSmallStr::EMPTY, &[idx])) - .unwrap(); - unsafe { return out.from_physical_unchecked(series.dtype()) } - } - } - let fill_value = if fill_value.dtype().is_string() { - fill_value - .cast(&DataType::Categorical(None, Default::default())) - .unwrap() - } else { - fill_value - }; - default(series, fill_value) - }, - _ => default(series, fill_value), - } + let fill_value = if series.dtype().is_categorical() && fill_value.dtype().is_string() { + fill_value.cast(series.dtype()).unwrap() + } else { + fill_value + }; + default(series, fill_value) }, (1, other_len) => { if s[0].has_nulls() { diff --git a/crates/polars-plan/src/plans/aexpr/function_expr/peaks.rs b/crates/polars-plan/src/plans/aexpr/function_expr/peaks.rs index 702a9dc3c86d..cd5fc7ace6f7 100644 --- a/crates/polars-plan/src/plans/aexpr/function_expr/peaks.rs +++ b/crates/polars-plan/src/plans/aexpr/function_expr/peaks.rs @@ -9,7 +9,7 @@ pub(super) fn peak_min(s: &Column) -> PolarsResult { let s = match s.dtype() { DataType::Boolean => polars_bail!(opq = peak_min, DataType::Boolean), #[cfg(feature = "dtype-decimal")] - DataType::Decimal(_, _) => pmin(s.decimal()?).into_column(), + DataType::Decimal(_, _) => pmin(s.decimal()?.physical()).into_column(), dt => { with_match_physical_numeric_polars_type!(dt, |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); @@ -26,7 +26,7 @@ pub(super) fn peak_max(s: &Column) -> PolarsResult { let s = match s.dtype() { DataType::Boolean => polars_bail!(opq = peak_max, DataType::Boolean), #[cfg(feature = "dtype-decimal")] - DataType::Decimal(_, _) => pmax(s.decimal()?).into_column(), + DataType::Decimal(_, _) => pmax(s.decimal()?.physical()).into_column(), dt => { with_match_physical_numeric_polars_type!(dt, |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); diff --git a/crates/polars-plan/src/plans/aexpr/function_expr/range/datetime_range.rs b/crates/polars-plan/src/plans/aexpr/function_expr/range/datetime_range.rs index 66d6f5f0136d..5acc8280a770 100644 --- a/crates/polars-plan/src/plans/aexpr/function_expr/range/datetime_range.rs +++ b/crates/polars-plan/src/plans/aexpr/function_expr/range/datetime_range.rs @@ -210,7 +210,7 @@ pub(super) fn datetime_ranges( tu, tz.as_ref(), )?; - builder.append_slice(rng.cont_slice().unwrap()); + builder.append_slice(rng.physical().cont_slice().unwrap()); Ok(()) }; diff --git a/crates/polars-plan/src/plans/aexpr/function_expr/range/time_range.rs b/crates/polars-plan/src/plans/aexpr/function_expr/range/time_range.rs index 640be245c40a..df4bf8546357 100644 --- a/crates/polars-plan/src/plans/aexpr/function_expr/range/time_range.rs +++ b/crates/polars-plan/src/plans/aexpr/function_expr/range/time_range.rs @@ -55,7 +55,7 @@ pub(super) fn time_ranges( let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder| { let rng = time_range_impl(PlSmallStr::EMPTY, start, end, interval, closed)?; - builder.append_slice(rng.cont_slice().unwrap()); + builder.append_slice(rng.physical().cont_slice().unwrap()); Ok(()) }; diff --git a/crates/polars-plan/src/plans/aexpr/function_expr/schema.rs b/crates/polars-plan/src/plans/aexpr/function_expr/schema.rs index c050bc6962f8..093e52038c3f 100644 --- a/crates/polars-plan/src/plans/aexpr/function_expr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/function_expr/schema.rs @@ -164,7 +164,7 @@ impl IRFunctionExpr { if *include_category { fields.push(Field::new( PlSmallStr::from_static("category"), - DataType::Categorical(None, Default::default()), + DataType::from_categories(Categories::global()), )); } fields.push(Field::new(PlSmallStr::from_static("count"), IDX_DTYPE)); @@ -240,7 +240,7 @@ impl IRFunctionExpr { Cut { include_breaks: false, .. - } => mapper.with_dtype(DataType::Categorical(None, Default::default())), + } => mapper.with_dtype(DataType::from_categories(Categories::global())), #[cfg(feature = "cutqcut")] Cut { include_breaks: true, @@ -250,7 +250,7 @@ impl IRFunctionExpr { Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64), Field::new( PlSmallStr::from_static("category"), - DataType::Categorical(None, Default::default()), + DataType::from_categories(Categories::global()), ), ]); mapper.with_dtype(struct_dt) @@ -299,7 +299,7 @@ impl IRFunctionExpr { QCut { include_breaks: false, .. - } => mapper.with_dtype(DataType::Categorical(None, Default::default())), + } => mapper.with_dtype(DataType::from_categories(Categories::global())), #[cfg(feature = "cutqcut")] QCut { include_breaks: true, @@ -309,7 +309,7 @@ impl IRFunctionExpr { Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64), Field::new( PlSmallStr::from_static("category"), - DataType::Categorical(None, Default::default()), + DataType::from_categories(Categories::global()), ), ]); mapper.with_dtype(struct_dt) @@ -654,7 +654,7 @@ pub(crate) fn args_to_supertype>(dtypes: &[D]) -> PolarsResul match (dtypes[0].as_ref(), &st) { #[cfg(feature = "dtype-categorical")] - (DataType::Categorical(_, ord), DataType::String) => st = DataType::Categorical(None, *ord), + (cat @ DataType::Categorical(_, _), DataType::String) => st = cat.clone(), _ => { if let DataType::Unknown(kind) = st { match kind { diff --git a/crates/polars-plan/src/plans/aexpr/properties.rs b/crates/polars-plan/src/plans/aexpr/properties.rs index 113ec7f8e22a..a0dcf690948c 100644 --- a/crates/polars-plan/src/plans/aexpr/properties.rs +++ b/crates/polars-plan/src/plans/aexpr/properties.rs @@ -149,37 +149,6 @@ pub fn is_elementwise_rec(node: Node, expr_arena: &Arena) -> bool { property_rec(node, expr_arena, is_elementwise) } -/// Recursive variant of `is_elementwise` that also forbids casting to categoricals. This function -/// is used to determine if an expression evaluation can be vertically parallelized. -pub fn is_elementwise_rec_no_cat_cast<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena) -> bool { - let mut stack = unitvec![]; - - loop { - if !is_elementwise(&mut stack, ae, expr_arena) { - return false; - } - - #[cfg(feature = "dtype-categorical")] - { - if let AExpr::Cast { - dtype: DataType::Categorical(..), - .. - } = ae - { - return false; - } - } - - let Some(node) = stack.pop() else { - break; - }; - - ae = expr_arena.get(node); - } - - true -} - #[derive(Debug, Clone)] pub enum ExprPushdownGroup { /// Can be pushed. (elementwise, infallible) diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs index c5ea3b94b29a..358e9bc72c05 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs @@ -356,6 +356,7 @@ pub(super) fn to_aexpr_impl( dtype: DataType::Categorical(_, _) | DataType::Enum(_, _), .. } => { + // TODO @ cat-rework: why not? polars_bail!( ComputeError: "casting to categorical not allowed in `list.eval`" ) diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs b/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs index 79b20c3fc2ef..d84f276debb6 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs @@ -77,36 +77,22 @@ See https://github.com/pola-rs/polars/issues/22149 for more information." strict: true, }, #[cfg(feature = "dtype-categorical")] - (DataType::String, DataType::Categorical(Some(rm), ordering)) if rm.is_global() => { - IsInTypeCoercionResult::SelfCast { - dtype: DataType::Categorical(None, *ordering), - strict: false, - } + (DataType::String, DataType::Categorical(_, _)) => IsInTypeCoercionResult::SelfCast { + dtype: type_other_inner.clone(), + strict: false, }, - // @NOTE: Local Categorical coercion has to happen in the kernel, which makes it streaming - // incompatible. #[cfg(feature = "dtype-categorical")] - (DataType::Categorical(Some(rm), ordering), DataType::String) if rm.is_global() => { - IsInTypeCoercionResult::OtherCast { - dtype: match &type_other { - DataType::List(_) => { - DataType::List(Box::new(DataType::Categorical(None, *ordering))) - }, - #[cfg(feature = "dtype-array")] - DataType::Array(_, width) => { - DataType::Array(Box::new(DataType::Categorical(None, *ordering)), *width) - }, - _ => unreachable!(), - }, - strict: false, - } + (DataType::Categorical(_, _), DataType::String) => IsInTypeCoercionResult::OtherCast { + dtype: match &type_other { + DataType::List(_) => DataType::List(Box::new(type_left.clone())), + #[cfg(feature = "dtype-array")] + DataType::Array(_, width) => DataType::Array(Box::new(type_left.clone()), *width), + _ => unreachable!(), + }, + strict: false, }, - #[cfg(feature = "dtype-categorical")] - (DataType::Categorical(_, _), DataType::String) => return Ok(None), - #[cfg(feature = "dtype-categorical")] - (DataType::String, DataType::Categorical(_, _)) => return Ok(None), #[cfg(feature = "dtype-decimal")] (DataType::Decimal(_, _), dt) if dt.is_primitive_numeric() => { IsInTypeCoercionResult::OtherCast { diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index cfb185fd5eef..f2dee665e8c9 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -38,26 +38,20 @@ fn modify_supertype( match (type_left, type_right, left, right) { // if the we compare a categorical to a literal string we want to cast the literal to categorical #[cfg(feature = "dtype-categorical")] - (Categorical(_, ordering), String | Unknown(UnknownKind::Str), _, AExpr::Literal(_)) - | (String | Unknown(UnknownKind::Str), Categorical(_, ordering), AExpr::Literal(_), _) => { - st = Categorical(None, *ordering) - }, - #[cfg(feature = "dtype-categorical")] - (dt @ Enum(_, _), String | Unknown(UnknownKind::Str), _, AExpr::Literal(_)) + (dt @ Categorical(_, _), String | Unknown(UnknownKind::Str), _, AExpr::Literal(_)) + | (String | Unknown(UnknownKind::Str), dt @ Categorical(_, _), AExpr::Literal(_), _) + | (dt @ Enum(_, _), String | Unknown(UnknownKind::Str), _, AExpr::Literal(_)) | (String | Unknown(UnknownKind::Str), dt @ Enum(_, _), AExpr::Literal(_), _) => { st = dt.clone() }, + // when then expression literals can have a different list type. // so we cast the literal to the other hand side. (List(inner), List(other), _, AExpr::Literal(_)) | (List(other), List(inner), AExpr::Literal(_), _) if inner != other => { - st = match &**inner { - #[cfg(feature = "dtype-categorical")] - Categorical(_, ordering) => List(Box::new(Categorical(None, *ordering))), - _ => List(inner.clone()), - }; + st = List(inner.clone()) }, // do nothing _ => {}, @@ -898,11 +892,11 @@ fn try_inline_literal_cast( #[cfg(feature = "dtype-duration")] (AnyValue::Duration(_, _), _) => return Ok(None), #[cfg(feature = "dtype-categorical")] - (AnyValue::Categorical(_, _, _), _) | (_, DataType::Categorical(_, _)) => { + (AnyValue::Categorical(_, _), _) | (_, DataType::Categorical(_, _)) => { return Ok(None); }, #[cfg(feature = "dtype-categorical")] - (AnyValue::Enum(_, _, _), _) | (_, DataType::Enum(_, _)) => return Ok(None), + (AnyValue::Enum(_, _), _) | (_, DataType::Enum(_, _)) => return Ok(None), #[cfg(feature = "dtype-struct")] (_, DataType::Struct(_)) => return Ok(None), (av, _) => { @@ -1037,7 +1031,7 @@ mod test { let df = DataFrame::new(Vec::from([Column::new_empty( PlSmallStr::from_static("fruits"), - &DataType::Categorical(None, Default::default()), + &DataType::from_categories(Categories::global()), )])) .unwrap(); diff --git a/crates/polars-python/src/conversion/any_value.rs b/crates/polars-python/src/conversion/any_value.rs index ba4009893cb1..1a26fa7fb619 100644 --- a/crates/polars-python/src/conversion/any_value.rs +++ b/crates/polars-python/src/conversion/any_value.rs @@ -79,21 +79,11 @@ pub(crate) fn any_value_into_py_object<'py>( AnyValue::Boolean(v) => v.into_bound_py_any(py), AnyValue::String(v) => v.into_bound_py_any(py), AnyValue::StringOwned(v) => v.into_bound_py_any(py), - AnyValue::Categorical(idx, rev, arr) | AnyValue::Enum(idx, rev, arr) => { - let s = if arr.is_null() { - rev.get(idx) - } else { - unsafe { arr.deref_unchecked().value(idx as usize) } - }; - s.into_bound_py_any(py) + AnyValue::Categorical(cat, map) | AnyValue::Enum(cat, map) => unsafe { + map.cat_to_str_unchecked(cat).into_bound_py_any(py) }, - AnyValue::CategoricalOwned(idx, rev, arr) | AnyValue::EnumOwned(idx, rev, arr) => { - let s = if arr.is_null() { - rev.get(idx) - } else { - unsafe { arr.deref_unchecked().value(idx as usize) } - }; - s.into_bound_py_any(py) + AnyValue::CategoricalOwned(cat, map) | AnyValue::EnumOwned(cat, map) => unsafe { + map.cat_to_str_unchecked(cat).into_bound_py_any(py) }, AnyValue::Date(v) => { let date = date32_to_date(v); diff --git a/crates/polars-python/src/conversion/chunked_array.rs b/crates/polars-python/src/conversion/chunked_array.rs index a386e5c6a10d..0f68cf329cc5 100644 --- a/crates/polars-python/src/conversion/chunked_array.rs +++ b/crates/polars-python/src/conversion/chunked_array.rs @@ -66,6 +66,7 @@ impl<'py> IntoPyObject<'py> for &Wrap<&DurationChunked> { let time_unit = self.0.time_unit(); let iter = self .0 + .physical() .iter() .map(|opt_v| opt_v.map(|v| elapsed_offset_to_timedelta(v, time_unit))); PyList::new(py, iter) @@ -80,7 +81,7 @@ impl<'py> IntoPyObject<'py> for &Wrap<&DatetimeChunked> { fn into_pyobject(self, py: Python<'py>) -> Result { let time_zone = self.0.time_zone().as_ref(); let time_unit = self.0.time_unit(); - let iter = self.0.iter().map(|opt_v| { + let iter = self.0.physical().iter().map(|opt_v| { opt_v.map(|v| datetime_to_py_object(py, v, time_unit, time_zone).unwrap()) }); PyList::new(py, iter) @@ -112,7 +113,11 @@ impl<'py> IntoPyObject<'py> for &Wrap<&DateChunked> { type Error = PyErr; fn into_pyobject(self, py: Python<'py>) -> Result { - let iter = self.0.into_iter().map(|opt_v| opt_v.map(date32_to_date)); + let iter = self + .0 + .physical() + .into_iter() + .map(|opt_v| opt_v.map(date32_to_date)); PyList::new(py, iter) } } @@ -136,7 +141,7 @@ pub(crate) fn decimal_to_pyobject_iter<'py, 'a>( let py_scale = (-(ca.scale() as i32)).into_pyobject(py)?; // if we don't know precision, the only safe bet is to set it to 39 let py_precision = ca.precision().unwrap_or(39).into_pyobject(py)?; - Ok(ca.iter().map(move |opt_v| { + Ok(ca.physical().iter().map(move |opt_v| { opt_v.map(|v| { // TODO! use AnyValue so that we have a single impl. const N: usize = 3; diff --git a/crates/polars-python/src/conversion/mod.rs b/crates/polars-python/src/conversion/mod.rs index 37ce8cb81456..8eb8519c42a5 100644 --- a/crates/polars-python/src/conversion/mod.rs +++ b/crates/polars-python/src/conversion/mod.rs @@ -281,18 +281,19 @@ impl<'py> IntoPyObject<'py> for &Wrap { let class = pl.getattr(intern!(py, "Object"))?; class.call0() }, - DataType::Categorical(_, ordering) => { + DataType::Categorical(_, _) => { let class = pl.getattr(intern!(py, "Categorical"))?; - class.call1((Wrap(*ordering),)) - }, - DataType::Enum(rev_map, _) => { - // we should always have an initialized rev_map coming from rust - let categories = rev_map.as_ref().unwrap().get_categories(); + class.call1((Wrap(CategoricalOrdering::Lexical),)) + }, + DataType::Enum(_, mapping) => { + let categories = unsafe { + StringChunked::from_chunks( + PlSmallStr::from_static("category"), + vec![mapping.to_arrow(true)], + ) + }; let class = pl.getattr(intern!(py, "Enum"))?; - let s = - Series::from_arrow(PlSmallStr::from_static("category"), categories.to_boxed()) - .map_err(PyPolarsErr::from)?; - let series = to_series(py, s.into())?; + let series = to_series(py, categories.into_series().into())?; class.call1((series,)) }, DataType::Time => pl.getattr(intern!(py, "Time")), @@ -366,8 +367,8 @@ impl<'py> FromPyObject<'py> for Wrap { "Boolean" => DataType::Boolean, "String" => DataType::String, "Binary" => DataType::Binary, - "Categorical" => DataType::Categorical(None, Default::default()), - "Enum" => DataType::Enum(None, Default::default()), + "Categorical" => DataType::from_categories(Categories::global()), + "Enum" => DataType::from_frozen_categories(FrozenCategories::new([]).unwrap()), "Date" => DataType::Date, "Time" => DataType::Time, "Datetime" => DataType::Datetime(TimeUnit::Microseconds, None), @@ -401,17 +402,16 @@ impl<'py> FromPyObject<'py> for Wrap { "Boolean" => DataType::Boolean, "String" => DataType::String, "Binary" => DataType::Binary, - "Categorical" => { - let ordering = ob.getattr(intern!(py, "ordering")).unwrap(); - let ordering = ordering.extract::>()?.0; - DataType::Categorical(None, ordering) - }, + "Categorical" => DataType::from_categories(Categories::global()), "Enum" => { let categories = ob.getattr(intern!(py, "categories")).unwrap(); let s = get_series(&categories.as_borrowed())?; let ca = s.str().map_err(PyPolarsErr::from)?; let categories = ca.downcast_iter().next().unwrap().clone(); - create_enum_dtype(categories) + assert!(!categories.has_nulls()); + DataType::from_frozen_categories( + FrozenCategories::new(categories.values_iter()).unwrap(), + ) }, "Date" => DataType::Date, "Time" => DataType::Time, @@ -470,17 +470,17 @@ impl<'py> FromPyObject<'py> for Wrap { } } +enum CategoricalOrdering { + Lexical, +} + impl<'py> IntoPyObject<'py> for Wrap { type Target = PyString; type Output = Bound<'py, Self::Target>; type Error = Infallible; fn into_pyobject(self, py: Python<'py>) -> Result { - match self.0 { - CategoricalOrdering::Physical => "physical", - CategoricalOrdering::Lexical => "lexical", - } - .into_pyobject(py) + "lexical".into_pyobject(py) } } @@ -791,8 +791,14 @@ impl<'py> FromPyObject<'py> for Wrap> { impl<'py> FromPyObject<'py> for Wrap { fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { let parsed = match &*ob.extract::()? { - "physical" => CategoricalOrdering::Physical, "lexical" => CategoricalOrdering::Lexical, + "physical" => { + polars_warn!( + Deprecation, + "physical ordering is deprecated, will use lexical ordering instead" + ); + CategoricalOrdering::Lexical + }, v => { return Err(PyValueError::new_err(format!( "categorical `ordering` must be one of {{'physical', 'lexical'}}, got {v}", diff --git a/crates/polars-python/src/functions/io.rs b/crates/polars-python/src/functions/io.rs index 0bedac4b2874..010bc28a0020 100644 --- a/crates/polars-python/src/functions/io.rs +++ b/crates/polars-python/src/functions/io.rs @@ -1,16 +1,13 @@ use std::io::BufReader; -use arrow::array::Utf8ViewArray; #[cfg(any(feature = "ipc", feature = "parquet"))] use polars::prelude::ArrowSchema; -use polars_core::datatypes::create_enum_dtype; use pyo3::prelude::*; use pyo3::types::PyDict; use crate::conversion::Wrap; use crate::error::PyPolarsErr; use crate::file::{EitherRustPythonFile, get_either_file}; -use crate::prelude::ArrowDataType; #[cfg(feature = "ipc")] #[pyfunction] @@ -52,13 +49,7 @@ pub fn read_parquet_metadata(py: Python, py_f: PyObject) -> PyResult) -> PyResult<()> { for field in schema.iter_values() { - let dt = if field.is_enum() { - Wrap(create_enum_dtype(Utf8ViewArray::new_empty( - ArrowDataType::Utf8View, - ))) - } else { - Wrap(polars::prelude::DataType::from_arrow_field(field)) - }; + let dt = Wrap(polars::prelude::DataType::from_arrow_field(field)); dict.set_item(field.name.as_str(), &dt)?; } Ok(()) diff --git a/crates/polars-python/src/functions/string_cache.rs b/crates/polars-python/src/functions/string_cache.rs index fbe8c0432b05..11eb2821f6dc 100644 --- a/crates/polars-python/src/functions/string_cache.rs +++ b/crates/polars-python/src/functions/string_cache.rs @@ -1,32 +1,28 @@ -use polars_core::StringCacheHolder; use pyo3::prelude::*; #[pyfunction] pub fn enable_string_cache() { - polars_core::enable_string_cache() + // The string cache no longer exists. } #[pyfunction] pub fn disable_string_cache() { - polars_core::disable_string_cache() + // The string cache no longer exists. } #[pyfunction] pub fn using_string_cache() -> bool { - polars_core::using_string_cache() + // The string cache no longer exists. + true } #[pyclass] -pub struct PyStringCacheHolder { - _inner: StringCacheHolder, -} +pub struct PyStringCacheHolder; #[pymethods] impl PyStringCacheHolder { #[new] fn new() -> Self { - Self { - _inner: StringCacheHolder::hold(), - } + Self } } diff --git a/crates/polars-python/src/interop/numpy/to_numpy_series.rs b/crates/polars-python/src/interop/numpy/to_numpy_series.rs index b91c0ad5ca00..4f24cd099882 100644 --- a/crates/polars-python/src/interop/numpy/to_numpy_series.rs +++ b/crates/polars-python/src/interop/numpy/to_numpy_series.rs @@ -247,9 +247,11 @@ fn series_to_numpy_with_copy(py: Python<'_>, s: &Series, writable: bool) -> PyOb PyArray1::from_iter(py, values).into_py_any(py).unwrap() }, Categorical(_, _) | Enum(_, _) => { - let ca = s.categorical().unwrap(); - let values = ca.iter_str().map(|s| s.into_py_any(py).unwrap()); - PyArray1::from_iter(py, values).into_py_any(py).unwrap() + with_match_categorical_physical_type!(s.dtype().cat_physical().unwrap(), |$C| { + let ca = s.cat::<$C>().unwrap(); + let values = ca.iter_str().map(|s| s.into_py_any(py).unwrap()); + PyArray1::from_iter(py, values).into_py_any(py).unwrap() + }) }, Decimal(_, _) => { let ca = s.decimal().unwrap(); diff --git a/crates/polars-python/src/series/export.rs b/crates/polars-python/src/series/export.rs index 3ff2a45f233d..eb949308efbd 100644 --- a/crates/polars-python/src/series/export.rs +++ b/crates/polars-python/src/series/export.rs @@ -31,10 +31,11 @@ impl PySeries { DataType::Int128 => PyList::new(py, series.i128().map_err(PyPolarsErr::from)?)?, DataType::Float32 => PyList::new(py, series.f32().map_err(PyPolarsErr::from)?)?, DataType::Float64 => PyList::new(py, series.f64().map_err(PyPolarsErr::from)?)?, - DataType::Categorical(_, _) | DataType::Enum(_, _) => PyList::new( - py, - series.categorical().map_err(PyPolarsErr::from)?.iter_str(), - )?, + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + with_match_categorical_physical_type!(series.dtype().cat_physical().unwrap(), |$C| { + PyList::new(py, series.cat::<$C>().unwrap().iter_str())? + }) + }, #[cfg(feature = "object")] DataType::Object(_) => { let v = PyList::empty(py); diff --git a/crates/polars-python/src/series/general.rs b/crates/polars-python/src/series/general.rs index b124c8f9251a..78775873356e 100644 --- a/crates/polars-python/src/series/general.rs +++ b/crates/polars-python/src/series/general.rs @@ -46,17 +46,15 @@ impl PySeries { } pub fn cat_uses_lexical_ordering(&self) -> PyResult { - let ca = self.series.categorical().map_err(PyPolarsErr::from)?; - Ok(ca.uses_lexical_ordering()) + Ok(true) } pub fn cat_is_local(&self) -> PyResult { - let ca = self.series.categorical().map_err(PyPolarsErr::from)?; - Ok(ca.get_rev_map().is_local()) + Ok(false) } - pub fn cat_to_local(&self, py: Python) -> PyResult { - py.enter_polars_series(|| Ok(self.series.categorical()?.to_local())) + pub fn cat_to_local(&self, _py: Python) -> PyResult { + Ok(self.clone()) } fn estimated_size(&self) -> usize { @@ -485,9 +483,7 @@ impl PySeries { let dicts = dtypes .iter() - .map(|(_, dt)| dt) - .zip(opts.iter()) - .map(|(dtype, opts)| get_row_encoding_context(&dtype.0, opts.is_ordered())) + .map(|(_, dtype)| get_row_encoding_context(&dtype.0)) .collect::>(); // Get the BinaryOffset array. @@ -596,9 +592,30 @@ impl_get!(get_i16, i16, i16); impl_get!(get_i32, i32, i32); impl_get!(get_i64, i64, i64); impl_get!(get_str, str, &str); -impl_get!(get_date, date, i32); -impl_get!(get_datetime, datetime, i64); -impl_get!(get_duration, duration, i64); + +macro_rules! impl_get_phys { + ($name:ident, $series_variant:ident, $type:ty) => { + #[pymethods] + impl PySeries { + fn $name(&self, index: i64) -> Option<$type> { + if let Ok(ca) = self.series.$series_variant() { + let index = if index < 0 { + (ca.len() as i64 + index) as usize + } else { + index as usize + }; + ca.physical().get(index) + } else { + None + } + } + } + }; +} + +impl_get_phys!(get_date, date, i32); +impl_get_phys!(get_datetime, datetime, i64); +impl_get_phys!(get_duration, duration, i64); #[cfg(test)] mod test { diff --git a/crates/polars-python/src/series/scatter.rs b/crates/polars-python/src/series/scatter.rs index 3df2ed812410..7138c3addac1 100644 --- a/crates/polars-python/src/series/scatter.rs +++ b/crates/polars-python/src/series/scatter.rs @@ -31,6 +31,25 @@ impl PySeries { fn scatter(mut s: Series, idx: &Series, values: &Series) -> Result { let logical_dtype = s.dtype().clone(); + let values = if logical_dtype.is_categorical() || logical_dtype.is_enum() { + if matches!( + values.dtype(), + DataType::Categorical(_, _) | DataType::Enum(_, _) | DataType::String + ) { + match values.strict_cast(&logical_dtype) { + Ok(values) => values, + Err(err) => return Err((s, err)), + } + } else { + return Err(( + s, + polars_err!(InvalidOperation: "invalid values dtype '{}' for scattering into dtype '{}'", values.dtype(), logical_dtype), + )); + } + } else { + values.clone() + }; + let idx = match polars_ops::prelude::convert_to_unsigned_index(idx, s.len()) { Ok(idx) => idx, Err(err) => return Err((s, err)), diff --git a/crates/polars-python/src/utils.rs b/crates/polars-python/src/utils.rs index 265beadc8dc8..0cf80394a8fe 100644 --- a/crates/polars-python/src/utils.rs +++ b/crates/polars-python/src/utils.rs @@ -31,8 +31,8 @@ macro_rules! apply_method_all_arrow_series2 { DataType::Int128 => $self.i128().unwrap().$method($($args),*), DataType::Float32 => $self.f32().unwrap().$method($($args),*), DataType::Float64 => $self.f64().unwrap().$method($($args),*), - DataType::Date => $self.date().unwrap().$method($($args),*), - DataType::Datetime(_, _) => $self.datetime().unwrap().$method($($args),*), + DataType::Date => $self.date().unwrap().physical().$method($($args),*), + DataType::Datetime(_, _) => $self.datetime().unwrap().physical().$method($($args),*), DataType::List(_) => $self.list().unwrap().$method($($args),*), DataType::Struct(_) => $self.struct_().unwrap().$method($($args),*), dt => panic!("dtype {:?} not supported", dt) diff --git a/crates/polars-row/Cargo.toml b/crates/polars-row/Cargo.toml index 0b60a8851549..2fa9c4ccb170 100644 --- a/crates/polars-row/Cargo.toml +++ b/crates/polars-row/Cargo.toml @@ -12,6 +12,7 @@ description = "Row encodings for the Polars DataFrame library" bitflags = { workspace = true } bytemuck = { workspace = true } polars-compute = { workspace = true, features = ["cast"] } +polars-dtype = { workspace = true } polars-error = { workspace = true } polars-utils = { workspace = true } diff --git a/crates/polars-row/src/decode.rs b/crates/polars-row/src/decode.rs index dd67b2637420..e2107b0ce078 100644 --- a/crates/polars-row/src/decode.rs +++ b/crates/polars-row/src/decode.rs @@ -3,12 +3,15 @@ use arrow::bitmap::{Bitmap, BitmapBuilder}; use arrow::buffer::Buffer; use arrow::datatypes::ArrowDataType; use arrow::offset::OffsetsBuffer; +use arrow::types::NativeType; +use polars_dtype::categorical::CatNative; use self::encode::fixed_size; -use self::row::{RowEncodingCategoricalContext, RowEncodingOptions}; +use self::row::{NewRowEncodingCategoricalContext, RowEncodingOptions}; use self::variable::utf8::decode_str; use super::*; -use crate::fixed::{boolean, decimal, numeric, packed_u32}; +use crate::fixed::numeric::{FixedLengthEncoding, FromSlice}; +use crate::fixed::{boolean, decimal, numeric}; use crate::variable::{binary, no_order, utf8}; /// Decode `rows` into a arrow format @@ -84,7 +87,7 @@ fn dtype_and_data_to_encoded_item_len( dict: Option<&RowEncodingContext>, ) -> usize { // Fast path: if the size is fixed, we can just divide. - if let Some(size) = fixed_size(dtype, dict) { + if let Some(size) = fixed_size(dtype, opt, dict) { return size; } @@ -171,7 +174,7 @@ fn rows_for_fixed_size_list<'a>( nested_rows.reserve(rows.len() * width); // Fast path: if the size is fixed, we can just divide. - if let Some(size) = fixed_size(dtype, dict) { + if let Some(size) = fixed_size(dtype, opt, dict) { for row in rows.iter_mut() { for i in 0..width { nested_rows.push(&row[(i * size)..][..size]); @@ -192,13 +195,19 @@ fn rows_for_fixed_size_list<'a>( } } -unsafe fn decode_lexical_cat( +unsafe fn decode_cat( rows: &mut [&[u8]], opt: RowEncodingOptions, - _values: &RowEncodingCategoricalContext, -) -> PrimitiveArray { - let mut s = numeric::decode_primitive::(rows, opt); - numeric::decode_primitive::(rows, opt).with_validity(s.take_validity()) + ctx: &NewRowEncodingCategoricalContext, +) -> PrimitiveArray +where + T::Encoded: FromSlice, +{ + if ctx.is_enum || !opt.is_ordered() { + numeric::decode_primitive::(rows, opt) + } else { + variable::utf8::decode_str_as_cat::(rows, opt, &ctx.mapping) + } } unsafe fn decode( @@ -208,6 +217,16 @@ unsafe fn decode( dtype: &ArrowDataType, ) -> ArrayRef { use ArrowDataType as D; + + if let Some(RowEncodingContext::Categorical(ctx)) = dict { + return match dtype { + D::UInt8 => decode_cat::(rows, opt, ctx).to_boxed(), + D::UInt16 => decode_cat::(rows, opt, ctx).to_boxed(), + D::UInt32 => decode_cat::(rows, opt, ctx).to_boxed(), + _ => unreachable!(), + }; + } + match dtype { D::Null => NullArray::new(D::Null, rows.len()).to_boxed(), D::Boolean => boolean::decode_bool(rows, opt).to_boxed(), @@ -329,23 +348,6 @@ unsafe fn decode( }, dt => { - if matches!(dt, D::UInt32) { - if let Some(dict) = dict { - return match dict { - RowEncodingContext::Categorical(ctx) => { - if ctx.is_enum { - packed_u32::decode(rows, opt, ctx.needed_num_bits()).to_boxed() - } else if ctx.lexical_sort_idxs.is_none() { - numeric::decode_primitive::(rows, opt).to_boxed() - } else { - decode_lexical_cat(rows, opt, ctx).to_boxed() - } - }, - _ => unreachable!(), - }; - } - } - if matches!(dt, D::Int128) { if let Some(dict) = dict { return match dict { diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index 683deb6b3474..16d39e5e6fe4 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -3,17 +3,21 @@ use std::mem::MaybeUninit; use arrow::array::{ Array, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray, ListArray, - PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray, + PrimitiveArray, StructArray, UInt8Array, UInt16Array, UInt32Array, Utf8Array, Utf8ViewArray, }; use arrow::bitmap::Bitmap; use arrow::datatypes::ArrowDataType; -use arrow::types::Offset; +use arrow::types::{NativeType, Offset}; +use polars_dtype::categorical::CatNative; -use crate::fixed::{boolean, decimal, numeric, packed_u32}; +use crate::fixed::numeric::FixedLengthEncoding; +use crate::fixed::{boolean, decimal, numeric}; use crate::row::{RowEncodingOptions, RowsEncoded}; use crate::variable::{binary, no_order, utf8}; use crate::widths::RowWidths; -use crate::{ArrayRef, RowEncodingContext, with_match_arrow_primitive_type}; +use crate::{ + ArrayRef, NewRowEncodingCategoricalContext, RowEncodingContext, with_match_arrow_primitive_type, +}; pub fn convert_columns( num_rows: usize, @@ -255,7 +259,7 @@ fn get_encoder( let dtype = array.dtype(); // Fast path: column has a fixed size encoding - if let Some(size) = fixed_size(dtype, dict) { + if let Some(size) = fixed_size(dtype, opt, dict) { row_widths.push_constant(size); let state = match dtype { D::FixedSizeList(_, width) => { @@ -320,6 +324,64 @@ fn get_encoder( }; } + // Non-fixed-size categorical path. + if let Some(RowEncodingContext::Categorical(ctx)) = dict { + match dtype { + D::UInt8 => { + assert!(opt.is_ordered() && !ctx.is_enum); + let dc_array = array.as_any().downcast_ref::().unwrap(); + return striter_num_column_bytes( + array, + dc_array.values_iter().map(|cat| { + ctx.mapping + .cat_to_str(cat.as_cat()) + .map(|s| s.len()) + .unwrap_or(0) + }), + dc_array.validity(), + opt, + row_widths, + ); + }, + D::UInt16 => { + assert!(opt.is_ordered() && !ctx.is_enum); + let dc_array = array.as_any().downcast_ref::().unwrap(); + return striter_num_column_bytes( + array, + dc_array.values_iter().map(|cat| { + ctx.mapping + .cat_to_str(cat.as_cat()) + .map(|s| s.len()) + .unwrap_or(0) + }), + dc_array.validity(), + opt, + row_widths, + ); + }, + D::UInt32 => { + assert!(opt.is_ordered() && !ctx.is_enum); + let dc_array = array.as_any().downcast_ref::().unwrap(); + return striter_num_column_bytes( + array, + dc_array.values_iter().map(|cat| { + ctx.mapping + .cat_to_str(cat.as_cat()) + .map(|s| s.len()) + .unwrap_or(0) + }), + dc_array.validity(), + opt, + row_widths, + ); + }, + _ => { + // Fall through to below, should be nested type containing categorical. + debug_assert!(dtype.is_nested()) + }, + } + } + match dtype { D::FixedSizeList(_, width) => { let array = array.as_any().downcast_ref::().unwrap(); @@ -522,6 +584,26 @@ unsafe fn encode_bins<'a>( } } +unsafe fn encode_cat_array( + buffer: &mut [MaybeUninit], + keys: &PrimitiveArray, + opt: RowEncodingOptions, + ctx: &NewRowEncodingCategoricalContext, + offsets: &mut [usize], +) { + if ctx.is_enum || !opt.is_ordered() { + numeric::encode(buffer, keys, opt, offsets); + } else { + utf8::encode_str( + buffer, + keys.iter() + .map(|k| k.map(|&cat| ctx.mapping.cat_to_str_unchecked(cat.as_cat()))), + opt, + offsets, + ); + } +} + unsafe fn encode_flat_array( buffer: &mut [MaybeUninit], array: &dyn Array, @@ -531,6 +613,31 @@ unsafe fn encode_flat_array( ) { use ArrowDataType as D; + if let Some(RowEncodingContext::Categorical(ctx)) = dict { + match array.dtype() { + D::UInt8 => { + let keys = array.as_any().downcast_ref::>().unwrap(); + encode_cat_array(buffer, keys, opt, ctx, offsets); + }, + D::UInt16 => { + let keys = array + .as_any() + .downcast_ref::>() + .unwrap(); + encode_cat_array(buffer, keys, opt, ctx, offsets); + }, + D::UInt32 => { + let keys = array + .as_any() + .downcast_ref::>() + .unwrap(); + encode_cat_array(buffer, keys, opt, ctx, offsets); + }, + _ => unreachable!(), + }; + return; + } + match array.dtype() { D::Null => {}, D::Boolean => { @@ -539,44 +646,6 @@ unsafe fn encode_flat_array( }, dt if dt.is_numeric() => { - if matches!(dt, D::UInt32) { - if let Some(dict) = dict { - let keys = array - .as_any() - .downcast_ref::>() - .unwrap(); - - match dict { - RowEncodingContext::Categorical(ctx) => { - if ctx.is_enum { - packed_u32::encode( - buffer, - keys, - opt, - offsets, - ctx.needed_num_bits(), - ); - } else { - if let Some(lexical_sort_idxs) = &ctx.lexical_sort_idxs { - numeric::encode_iter( - buffer, - keys.iter() - .map(|k| k.map(|&k| lexical_sort_idxs[k as usize])), - opt, - offsets, - ); - } - - numeric::encode(buffer, keys, opt, offsets); - } - }, - - _ => unreachable!(), - } - return; - } - } - if matches!(dt, D::Int128) { if let Some(RowEncodingContext::Decimal(precision)) = dict { decimal::encode( @@ -855,30 +924,28 @@ unsafe fn encode_validity( } } -pub fn fixed_size(dtype: &ArrowDataType, dict: Option<&RowEncodingContext>) -> Option { +pub fn fixed_size( + dtype: &ArrowDataType, + opt: RowEncodingOptions, + dict: Option<&RowEncodingContext>, +) -> Option { use ArrowDataType as D; use numeric::FixedLengthEncoding; + + if let Some(RowEncodingContext::Categorical(ctx)) = dict { + // If ordered categorical (non-enum) we encode strings, otherwise physical. + if !ctx.is_enum && opt.is_ordered() { + return None; + } + } + Some(match dtype { D::Null => 0, D::Boolean => 1, D::UInt8 => u8::ENCODED_LEN, D::UInt16 => u16::ENCODED_LEN, - D::UInt32 => match dict { - None => u32::ENCODED_LEN, - Some(RowEncodingContext::Categorical(ctx)) => { - if ctx.is_enum { - packed_u32::len_from_num_bits(ctx.needed_num_bits()) - } else { - let mut num_bytes = u32::ENCODED_LEN; - if ctx.lexical_sort_idxs.is_some() { - num_bytes += u32::ENCODED_LEN; - } - num_bytes - } - }, - _ => return None, - }, + D::UInt32 => u32::ENCODED_LEN, D::UInt64 => u64::ENCODED_LEN, D::Int8 => i8::ENCODED_LEN, @@ -893,19 +960,19 @@ pub fn fixed_size(dtype: &ArrowDataType, dict: Option<&RowEncodingContext>) -> O D::Float32 => f32::ENCODED_LEN, D::Float64 => f64::ENCODED_LEN, - D::FixedSizeList(f, width) => 1 + width * fixed_size(f.dtype(), dict)?, + D::FixedSizeList(f, width) => 1 + width * fixed_size(f.dtype(), opt, dict)?, D::Struct(fs) => match dict { None => { let mut sum = 0; for f in fs { - sum += fixed_size(f.dtype(), None)?; + sum += fixed_size(f.dtype(), opt, None)?; } 1 + sum }, Some(RowEncodingContext::Struct(dicts)) => { let mut sum = 0; for (f, dict) in fs.iter().zip(dicts) { - sum += fixed_size(f.dtype(), dict.as_ref())?; + sum += fixed_size(f.dtype(), opt, dict.as_ref())?; } 1 + sum }, diff --git a/crates/polars-row/src/fixed/mod.rs b/crates/polars-row/src/fixed/mod.rs index 69fb28ebac75..3448eda90bc6 100644 --- a/crates/polars-row/src/fixed/mod.rs +++ b/crates/polars-row/src/fixed/mod.rs @@ -16,4 +16,3 @@ macro_rules! with_arms { pub mod boolean; pub mod decimal; pub mod numeric; -pub mod packed_u32; diff --git a/crates/polars-row/src/fixed/packed_u32.rs b/crates/polars-row/src/fixed/packed_u32.rs deleted file mode 100644 index 66866d9ca2a4..000000000000 --- a/crates/polars-row/src/fixed/packed_u32.rs +++ /dev/null @@ -1,181 +0,0 @@ -#![allow(unsafe_op_in_unsafe_fn)] -//! Row Encoding for Enum's -//! -//! This is a fixed-size encoding that takes a number of maximum bits that each value can take and -//! compresses such that a minimum amount of bytes are used for each value. - -use std::mem::MaybeUninit; - -use arrow::array::{Array, PrimitiveArray}; -use arrow::bitmap::BitmapBuilder; -use arrow::datatypes::ArrowDataType; -use polars_utils::slice::Slice2Uninit; - -use crate::row::RowEncodingOptions; - -pub fn len_from_num_bits(mut num_bits: usize) -> usize { - // 1 bit is used to indicate the nullability - num_bits += 1; - num_bits.div_ceil(8) -} - -macro_rules! with_constant_num_bytes { - ($num_bytes:ident, $block:block) => { - with_arms!($num_bytes, $block, (1, 2, 3, 4)) - }; -} - -pub unsafe fn encode( - buffer: &mut [MaybeUninit], - input: &PrimitiveArray, - opt: RowEncodingOptions, - offsets: &mut [usize], - - num_bits: usize, -) { - if input.null_count() == 0 { - unsafe { encode_slice(buffer, input.values(), opt, offsets, num_bits) } - } else { - unsafe { - encode_iter( - buffer, - input.iter().map(|v| v.copied()), - opt, - offsets, - num_bits, - ) - } - } -} - -fn get_invert_mask(opt: RowEncodingOptions, num_bits: usize) -> u32 { - if !opt.contains(RowEncodingOptions::DESCENDING) { - return 0; - } - - (1 << num_bits) - 1 -} - -pub unsafe fn encode_slice( - buffer: &mut [MaybeUninit], - input: &[u32], - opt: RowEncodingOptions, - offsets: &mut [usize], - - num_bits: usize, -) { - if num_bits == 32 { - super::numeric::encode_slice(buffer, input, opt, offsets); - return; - } - - let num_bytes = len_from_num_bits(num_bits); - let valid_mask = ((!opt.null_sentinel() & 0x80) as u32) << ((num_bytes - 1) * 8); - let invert_mask = get_invert_mask(opt, num_bits); - - with_constant_num_bytes!(num_bytes, { - for (offset, &v) in offsets.iter_mut().zip(input) { - let v = (v ^ invert_mask) | valid_mask; - unsafe { buffer.get_unchecked_mut(*offset..*offset + num_bytes) } - .copy_from_slice(v.to_be_bytes()[4 - num_bytes..].as_uninit()); - *offset += num_bytes; - } - }); -} - -pub unsafe fn encode_iter( - buffer: &mut [MaybeUninit], - input: impl Iterator>, - opt: RowEncodingOptions, - offsets: &mut [usize], - - num_bits: usize, -) { - if num_bits == 32 { - super::numeric::encode_iter(buffer, input, opt, offsets); - return; - } - - let num_bytes = len_from_num_bits(num_bits); - let null_value = (opt.null_sentinel() as u32) << ((num_bytes - 1) * 8); - let valid_mask = ((!opt.null_sentinel() & 0x80) as u32) << ((num_bytes - 1) * 8); - let invert_mask = get_invert_mask(opt, num_bits); - - with_constant_num_bytes!(num_bytes, { - for (offset, v) in offsets.iter_mut().zip(input) { - match v { - None => { - unsafe { buffer.get_unchecked_mut(*offset..*offset + num_bytes) } - .copy_from_slice(null_value.to_be_bytes()[4 - num_bytes..].as_uninit()); - }, - Some(v) => { - let v = (v ^ invert_mask) | valid_mask; - unsafe { buffer.get_unchecked_mut(*offset..*offset + num_bytes) } - .copy_from_slice(v.to_be_bytes()[4 - num_bytes..].as_uninit()); - }, - } - - *offset += num_bytes; - } - }); -} - -pub unsafe fn decode( - rows: &mut [&[u8]], - opt: RowEncodingOptions, - num_bits: usize, -) -> PrimitiveArray { - if num_bits == 32 { - return super::numeric::decode_primitive(rows, opt); - } - - let mut values = Vec::with_capacity(rows.len()); - let null_sentinel = opt.null_sentinel(); - - let num_bytes = len_from_num_bits(num_bits); - let mask = (1 << num_bits) - 1; - let invert_mask = get_invert_mask(opt, num_bits); - - with_constant_num_bytes!(num_bytes, { - values.extend( - rows.iter_mut() - .take_while(|row| *unsafe { row.get_unchecked(0) } != null_sentinel) - .map(|row| { - let mut value = 0u32; - let value_ref: &mut [u8; 4] = bytemuck::cast_mut(&mut value); - value_ref[4 - num_bytes..].copy_from_slice(row.get_unchecked(..num_bytes)); - - *row = &row[num_bytes..]; - ((value.swap_bytes()) & mask) ^ invert_mask - }), - ); - }); - - if values.len() == rows.len() { - return PrimitiveArray::new(ArrowDataType::UInt32, values.into(), None); - } - - let mut validity = BitmapBuilder::with_capacity(rows.len()); - validity.extend_constant(values.len(), true); - - let start_len = values.len(); - - with_constant_num_bytes!(num_bytes, { - values.extend(rows[start_len..].iter_mut().map(|row| { - validity.push(*unsafe { row.get_unchecked(0) } != null_sentinel); - - let mut value = 0u32; - let value_ref: &mut [u8; 4] = bytemuck::cast_mut(&mut value); - value_ref[4 - num_bytes..].copy_from_slice(row.get_unchecked(..num_bytes)); - - *row = &row[num_bytes..]; - ((value.swap_bytes()) & mask) ^ invert_mask - })); - }); - - PrimitiveArray::new( - ArrowDataType::UInt32, - values.into(), - validity.into_opt_validity(), - ) -} diff --git a/crates/polars-row/src/lib.rs b/crates/polars-row/src/lib.rs index 73bc61553a11..3107f2f652d9 100644 --- a/crates/polars-row/src/lib.rs +++ b/crates/polars-row/src/lib.rs @@ -284,4 +284,6 @@ pub use encode::{ convert_columns, convert_columns_amortized, convert_columns_amortized_no_order, convert_columns_no_order, }; -pub use row::{RowEncodingCategoricalContext, RowEncodingContext, RowEncodingOptions, RowsEncoded}; +pub use row::{ + NewRowEncodingCategoricalContext, RowEncodingContext, RowEncodingOptions, RowsEncoded, +}; diff --git a/crates/polars-row/src/row.rs b/crates/polars-row/src/row.rs index 8d74b46331d6..0d04d51155db 100644 --- a/crates/polars-row/src/row.rs +++ b/crates/polars-row/src/row.rs @@ -1,9 +1,12 @@ #![allow(unsafe_op_in_unsafe_fn)] +use std::sync::Arc; + use arrow::array::{BinaryArray, BinaryViewArray}; use arrow::datatypes::ArrowDataType; use arrow::ffi::mmap; use arrow::offset::{Offsets, OffsetsBuffer}; use polars_compute::cast::binary_to_binview; +use polars_dtype::categorical::CategoricalMapping; const BOOLEAN_TRUE_SENTINEL: u8 = 0x03; const BOOLEAN_FALSE_SENTINEL: u8 = 0x02; @@ -16,30 +19,15 @@ const BOOLEAN_FALSE_SENTINEL: u8 = 0x02; pub enum RowEncodingContext { Struct(Vec>), /// Categorical / Enum - Categorical(RowEncodingCategoricalContext), + Categorical(NewRowEncodingCategoricalContext), /// Decimal with given precision Decimal(usize), } #[derive(Debug, Clone)] -pub struct RowEncodingCategoricalContext { - /// The number of known categories in categorical / enum currently. - pub num_known_categories: u32, +pub struct NewRowEncodingCategoricalContext { pub is_enum: bool, - - /// The mapping from key to lexical sort index - pub lexical_sort_idxs: Option>, -} - -impl RowEncodingCategoricalContext { - pub fn needed_num_bits(&self) -> usize { - if self.num_known_categories == 0 { - 0 - } else { - let max_category_index = self.num_known_categories - 1; - (max_category_index.next_power_of_two().trailing_zeros() + 1) as usize - } - } + pub mapping: Arc, } bitflags::bitflags! { diff --git a/crates/polars-row/src/variable/utf8.rs b/crates/polars-row/src/variable/utf8.rs index cf6ad7c6aca6..24759b473da0 100644 --- a/crates/polars-row/src/variable/utf8.rs +++ b/crates/polars-row/src/variable/utf8.rs @@ -10,8 +10,10 @@ //! This allows the string row encoding to have a constant 1 byte overhead. use std::mem::MaybeUninit; -use arrow::array::{MutableBinaryViewArray, Utf8ViewArray}; +use arrow::array::{MutableBinaryViewArray, PrimitiveArray, Utf8ViewArray}; use arrow::bitmap::BitmapBuilder; +use arrow::types::NativeType; +use polars_dtype::categorical::{CatNative, CategoricalMapping}; use crate::row::RowEncodingOptions; @@ -127,3 +129,69 @@ pub unsafe fn decode_str(rows: &mut [&[u8]], opt: RowEncodingOptions) -> Utf8Vie let out: Utf8ViewArray = array.into(); out.with_validity(validity.into_opt_validity()) } + +/// The same as decode_str but inserts it into the given mapping, translating +/// it to physical type T. +pub unsafe fn decode_str_as_cat( + rows: &mut [&[u8]], + opt: RowEncodingOptions, + mapping: &CategoricalMapping, +) -> PrimitiveArray { + let null_sentinel = opt.null_sentinel(); + let descending = opt.contains(RowEncodingOptions::DESCENDING); + + let num_rows = rows.len(); + let mut out = Vec::::with_capacity(rows.len()); + + let mut scratch = Vec::new(); + for row in rows.iter_mut() { + let sentinel = *unsafe { row.get_unchecked(0) }; + if sentinel == null_sentinel { + *row = unsafe { row.get_unchecked(1..) }; + break; + } + + scratch.clear(); + if descending { + scratch.extend(row.iter().take_while(|&b| *b != 0xFE).map(|&v| !v - 2)); + } else { + scratch.extend(row.iter().take_while(|&b| *b != 0x01).map(|&v| v - 2)); + } + + *row = row.get_unchecked(1 + scratch.len()..); + let s = unsafe { std::str::from_utf8_unchecked(&scratch) }; + out.push(T::from_cat(mapping.insert_cat(s).unwrap())); + } + + if out.len() == num_rows { + return PrimitiveArray::from_vec(out); + } + + let mut validity = BitmapBuilder::with_capacity(num_rows); + validity.extend_constant(out.len(), true); + validity.push(false); + out.push(T::zeroed()); + + for row in rows[out.len()..].iter_mut() { + let sentinel = *unsafe { row.get_unchecked(0) }; + validity.push(sentinel != null_sentinel); + if sentinel == null_sentinel { + *row = unsafe { row.get_unchecked(1..) }; + out.push(T::zeroed()); + continue; + } + + scratch.clear(); + if descending { + scratch.extend(row.iter().take_while(|&b| *b != 0xFE).map(|&v| !v - 2)); + } else { + scratch.extend(row.iter().take_while(|&b| *b != 0x01).map(|&v| v - 2)); + } + + *row = row.get_unchecked(1 + scratch.len()..); + let s = unsafe { std::str::from_utf8_unchecked(&scratch) }; + out.push(T::from_cat(mapping.insert_cat(s).unwrap())); + } + + PrimitiveArray::from_vec(out).with_validity(validity.into_opt_validity()) +} diff --git a/crates/polars-stream/src/nodes/io_sources/csv.rs b/crates/polars-stream/src/nodes/io_sources/csv.rs index b2730d426dea..955e9bd6fe46 100644 --- a/crates/polars-stream/src/nodes/io_sources/csv.rs +++ b/crates/polars-stream/src/nodes/io_sources/csv.rs @@ -2,8 +2,6 @@ use std::ops::Range; use std::sync::Arc; use async_trait::async_trait; -#[cfg(feature = "dtype-categorical")] -use polars_core::StringCacheHolder; use polars_core::prelude::{Column, Field}; use polars_core::schema::{SchemaExt, SchemaRef}; use polars_error::{PolarsResult, polars_bail, polars_err, polars_warn}; @@ -590,8 +588,6 @@ struct ChunkReader { reader_schema: SchemaRef, parse_options: Arc, fields_to_cast: Vec, - #[cfg(feature = "dtype-categorical")] - _cat_lock: Option, ignore_errors: bool, projection: Vec, null_values: Option, @@ -610,10 +606,7 @@ impl ChunkReader { alt_count_lines: Option>, ) -> PolarsResult { let mut fields_to_cast: Vec = options.fields_to_cast.clone(); - let has_categorical = prepare_csv_schema(&mut reader_schema, &mut fields_to_cast)?; - - #[cfg(feature = "dtype-categorical")] - let _cat_lock = has_categorical.then(polars_core::StringCacheHolder::hold); + prepare_csv_schema(&mut reader_schema, &mut fields_to_cast)?; let parse_options = options.parse_options.clone(); @@ -632,8 +625,6 @@ impl ChunkReader { reader_schema, parse_options, fields_to_cast, - #[cfg(feature = "dtype-categorical")] - _cat_lock, ignore_errors: options.ignore_errors, projection, null_values, diff --git a/crates/polars-stream/src/nodes/io_sources/ndjson/chunk_reader.rs b/crates/polars-stream/src/nodes/io_sources/ndjson/chunk_reader.rs index ea6fbc8aafae..52239aa28647 100644 --- a/crates/polars-stream/src/nodes/io_sources/ndjson/chunk_reader.rs +++ b/crates/polars-stream/src/nodes/io_sources/ndjson/chunk_reader.rs @@ -10,8 +10,6 @@ use crate::nodes::compute_node_prelude::*; #[derive(Default)] pub(super) struct ChunkReader { projected_schema: SchemaRef, - #[cfg(feature = "dtype-categorical")] - _cat_lock: Option, ignore_errors: bool, } @@ -22,16 +20,8 @@ impl ChunkReader { ) -> PolarsResult { let projected_schema = projected_schema.clone(); - #[cfg(feature = "dtype-categorical")] - let _cat_lock = projected_schema - .iter_values() - .any(|x| x.is_categorical()) - .then(polars_core::StringCacheHolder::hold); - Ok(Self { projected_schema, - #[cfg(feature = "dtype-categorical")] - _cat_lock, ignore_errors: options.ignore_errors, }) } diff --git a/crates/polars-testing/src/asserts/frame.rs b/crates/polars-testing/src/asserts/frame.rs index c34427c46818..93f9e44820a3 100644 --- a/crates/polars-testing/src/asserts/frame.rs +++ b/crates/polars-testing/src/asserts/frame.rs @@ -55,7 +55,6 @@ macro_rules! assert_dataframe_equal { mod tests { #[allow(unused_imports)] use polars_core::prelude::*; - use polars_core::{disable_string_cache, enable_string_cache}; // Testing default struct implementation #[test] @@ -534,46 +533,21 @@ mod tests { #[test] #[should_panic(expected = "dtypes do not match")] fn test_dataframe_categorical_as_string_mismatch() { - enable_string_cache(); - - let mut categorical = Series::new("categories".into(), &["a", "b", "c", "a"]); - categorical = categorical - .cast(&DataType::Categorical(None, Default::default())) - .unwrap(); - - let df1 = DataFrame::new(vec![categorical.into()]).unwrap(); - - let df2 = DataFrame::new(vec![ - Series::new("categories".into(), &["a", "b", "c", "a"]).into(), - ]) - .unwrap(); - - assert_dataframe_equal!(&df1, &df2); - - disable_string_cache(); - } - - #[test] - fn test_dataframe_categorical_as_string_match() { - enable_string_cache(); - let mut categorical1 = Series::new("categories".into(), &["a", "b", "c", "a"]); categorical1 = categorical1 - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); let df1 = DataFrame::new(vec![categorical1.into()]).unwrap(); let mut categorical2 = Series::new("categories".into(), &["a", "b", "c", "a"]); categorical2 = categorical2 - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); let df2 = DataFrame::new(vec![categorical2.into()]).unwrap(); let options = crate::asserts::DataFrameEqualOptions::default().with_categorical_as_str(true); assert_dataframe_equal!(&df1, &df2, options); - - disable_string_cache(); } // Testing nested types diff --git a/crates/polars-testing/src/asserts/series.rs b/crates/polars-testing/src/asserts/series.rs index d1be349f5d09..906b5239852a 100644 --- a/crates/polars-testing/src/asserts/series.rs +++ b/crates/polars-testing/src/asserts/series.rs @@ -50,7 +50,6 @@ macro_rules! assert_series_equal { #[cfg(test)] mod tests { use polars_core::prelude::*; - use polars_core::{disable_string_cache, enable_string_cache}; // Testing default struct implementation #[test] @@ -314,27 +313,23 @@ mod tests { #[test] #[should_panic(expected = "exact value mismatch")] fn test_series_categorical_mismatch() { - enable_string_cache(); - let s1 = Series::new("".into(), &["apple", "banana", "cherry"]) - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); let s2 = Series::new("".into(), &["apple", "orange", "cherry"]) - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); assert_series_equal!(&s1, &s2); - - disable_string_cache(); } #[test] fn test_series_categorical_match() { let s1 = Series::new("".into(), &["apple", "banana", "cherry"]) - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); let s2 = Series::new("".into(), &["apple", "banana", "cherry"]) - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); assert_series_equal!(&s1, &s2); @@ -344,10 +339,10 @@ mod tests { #[should_panic(expected = "exact value mismatch")] fn test_series_categorical_str_mismatch() { let s1 = Series::new("".into(), &["apple", "banana", "cherry"]) - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); let s2 = Series::new("".into(), &["apple", "orange", "cherry"]) - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); let options = crate::asserts::SeriesEqualOptions::default().with_categorical_as_str(true); @@ -358,10 +353,10 @@ mod tests { #[test] fn test_series_categorical_str_match() { let s1 = Series::new("".into(), &["apple", "banana", "cherry"]) - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); let s2 = Series::new("".into(), &["apple", "banana", "cherry"]) - .cast(&DataType::Categorical(None, Default::default())) + .cast(&DataType::from_categories(Categories::global())) .unwrap(); let options = crate::asserts::SeriesEqualOptions::default().with_categorical_as_str(true); diff --git a/crates/polars-time/src/chunkedarray/date.rs b/crates/polars-time/src/chunkedarray/date.rs index 0ce083c997fb..f1f15475417b 100644 --- a/crates/polars-time/src/chunkedarray/date.rs +++ b/crates/polars-time/src/chunkedarray/date.rs @@ -18,20 +18,22 @@ pub trait DateMethods: AsDate { /// Returns the year number in the calendar date. fn year(&self) -> Int32Chunked { let ca = self.as_date(); - ca.apply_kernel_cast::(&date_to_year) + ca.physical().apply_kernel_cast::(&date_to_year) } /// Extract year from underlying NaiveDate representation. /// Returns whether the year is a leap year. fn is_leap_year(&self) -> BooleanChunked { let ca = self.as_date(); - ca.apply_kernel_cast::(&date_to_is_leap_year) + ca.physical() + .apply_kernel_cast::(&date_to_is_leap_year) } /// This year number might not match the calendar year number. fn iso_year(&self) -> Int32Chunked { let ca = self.as_date(); - ca.apply_kernel_cast::(&date_to_iso_year) + ca.physical() + .apply_kernel_cast::(&date_to_iso_year) } /// Extract month from underlying NaiveDateTime representation. @@ -47,14 +49,15 @@ pub trait DateMethods: AsDate { /// The return value ranges from 1 to 12. fn month(&self) -> Int8Chunked { let ca = self.as_date(); - ca.apply_kernel_cast::(&date_to_month) + ca.physical().apply_kernel_cast::(&date_to_month) } /// Returns the ISO week number starting from 1. /// The return value ranges from 1 to 53. (The last week of year differs by years.) fn week(&self) -> Int8Chunked { let ca = self.as_date(); - ca.apply_kernel_cast::(&date_to_iso_week) + ca.physical() + .apply_kernel_cast::(&date_to_iso_week) } /// Extract day from underlying NaiveDate representation. @@ -63,7 +66,7 @@ pub trait DateMethods: AsDate { /// The return value ranges from 1 to 31. (The last day of month differs by months.) fn day(&self) -> Int8Chunked { let ca = self.as_date(); - ca.apply_kernel_cast::(&date_to_day) + ca.physical().apply_kernel_cast::(&date_to_day) } /// Returns the day of year starting from 1. @@ -71,7 +74,8 @@ pub trait DateMethods: AsDate { /// The return value ranges from 1 to 366. (The last day of year differs by years.) fn ordinal(&self) -> Int16Chunked { let ca = self.as_date(); - ca.apply_kernel_cast::(&date_to_ordinal) + ca.physical() + .apply_kernel_cast::(&date_to_ordinal) } fn parse_from_str_slice(name: PlSmallStr, v: &[&str], fmt: &str) -> DateChunked; diff --git a/crates/polars-time/src/chunkedarray/datetime.rs b/crates/polars-time/src/chunkedarray/datetime.rs index 9149d03e8181..a7911d8e701e 100644 --- a/crates/polars-time/src/chunkedarray/datetime.rs +++ b/crates/polars-time/src/chunkedarray/datetime.rs @@ -15,7 +15,7 @@ fn cast_and_apply< func: F, ) -> ChunkedArray { let dtype = ca.dtype().to_arrow(CompatLevel::newest()); - let chunks = ca.downcast_iter().map(|arr| { + let chunks = ca.physical().downcast_iter().map(|arr| { let arr = cast( arr, &dtype, @@ -57,7 +57,7 @@ pub trait DatetimeMethods: AsDatetime { .expect("Removing time zone is infallible"), _ => ca, }; - ca_local.apply_kernel_cast::(&f) + ca_local.physical().apply_kernel_cast::(&f) } fn iso_year(&self) -> Int32Chunked { @@ -78,7 +78,7 @@ pub trait DatetimeMethods: AsDatetime { .expect("Removing time zone is infallible"), _ => ca, }; - ca_local.apply_kernel_cast::(&f) + ca_local.physical().apply_kernel_cast::(&f) } /// Extract quarter from underlying NaiveDateTime representation. @@ -162,7 +162,7 @@ pub trait DatetimeMethods: AsDatetime { .expect("Removing time zone is infallible"), _ => ca, }; - ca_local.apply_kernel_cast::(&f) + ca_local.physical().apply_kernel_cast::(&f) } fn parse_from_str_slice( @@ -299,7 +299,7 @@ mod test { 1_441_497_364_000_000_000, 1_356_048_000_000_000_000 ], - dt.cont_slice().unwrap() + dt.physical().cont_slice().unwrap() ); } } diff --git a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs index 7b82b1dc2f8a..2c08b05a5bb4 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs @@ -118,7 +118,7 @@ where let func = rolling_agg_fn_dynamic; let out: ArrayRef = if by_is_sorted { let arr = ca.downcast_iter().next().unwrap(); - let by_values = by.cont_slice().unwrap(); + let by_values = by.physical().cont_slice().unwrap(); let values = arr.values().as_slice(); func( values, @@ -132,9 +132,9 @@ where None, )? } else { - let sorting_indices = by.arg_sort(Default::default()); + let sorting_indices = by.physical().arg_sort(Default::default()); let ca = unsafe { ca.take_unchecked(&sorting_indices) }; - let by = unsafe { by.take_unchecked(&sorting_indices) }; + let by = unsafe { by.physical().take_unchecked(&sorting_indices) }; let arr = ca.downcast_iter().next().unwrap(); let by_values = by.cont_slice().unwrap(); let values = arr.values().as_slice(); diff --git a/crates/polars-time/src/chunkedarray/time.rs b/crates/polars-time/src/chunkedarray/time.rs index ea97abcac0e6..c62e0ea1c435 100644 --- a/crates/polars-time/src/chunkedarray/time.rs +++ b/crates/polars-time/src/chunkedarray/time.rs @@ -27,26 +27,29 @@ impl TimeMethods for TimeChunked { /// Extract hour from underlying NaiveDateTime representation. /// Returns the hour number from 0 to 23. fn hour(&self) -> Int8Chunked { - self.apply_kernel_cast::(&time_to_hour) + self.physical().apply_kernel_cast::(&time_to_hour) } /// Extract minute from underlying NaiveDateTime representation. /// Returns the minute number from 0 to 59. fn minute(&self) -> Int8Chunked { - self.apply_kernel_cast::(&time_to_minute) + self.physical() + .apply_kernel_cast::(&time_to_minute) } /// Extract second from underlying NaiveDateTime representation. /// Returns the second number from 0 to 59. fn second(&self) -> Int8Chunked { - self.apply_kernel_cast::(&time_to_second) + self.physical() + .apply_kernel_cast::(&time_to_second) } /// Extract second from underlying NaiveDateTime representation. /// Returns the number of nanoseconds since the whole non-leap second. /// The range from 1,000,000,000 to 1,999,999,999 represents the leap second. fn nanosecond(&self) -> Int32Chunked { - self.apply_kernel_cast::(&time_to_nanosecond) + self.physical() + .apply_kernel_cast::(&time_to_nanosecond) } fn parse_from_str_slice(name: PlSmallStr, v: &[&str], fmt: &str) -> TimeChunked { diff --git a/crates/polars-time/src/date_range.rs b/crates/polars-time/src/date_range.rs index 00913c0de299..6470e5460ea3 100644 --- a/crates/polars-time/src/date_range.rs +++ b/crates/polars-time/src/date_range.rs @@ -58,7 +58,7 @@ pub fn datetime_range_impl( _ => out.into_datetime(tu, None), }; - out.set_sorted_flag(IsSorted::Ascending); + out.physical_mut().set_sorted_flag(IsSorted::Ascending); Ok(out) } @@ -89,7 +89,7 @@ pub fn time_range_impl( ) .into_time(); - out.set_sorted_flag(IsSorted::Ascending); + out.physical_mut().set_sorted_flag(IsSorted::Ascending); Ok(out) } diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 83fd1b9bb80b..17a983478230 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -310,7 +310,7 @@ impl Wrap<&DataFrame> { }; let groups = if group_by.is_none() { - let vals = dt.downcast_iter().next().unwrap(); + let vals = dt.physical().downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); let (groups, lower, upper) = group_by_windows( w, @@ -328,7 +328,7 @@ impl Wrap<&DataFrame> { rolling: false, }) } else { - let vals = dt.downcast_iter().next().unwrap(); + let vals = dt.physical().downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); let groups = group_by.as_ref().unwrap(); @@ -383,7 +383,7 @@ impl Wrap<&DataFrame> { // upper column remain/are sorted let dt = unsafe { dt.clone().into_series().agg_first(&groups) }; - let mut dt = dt.datetime().unwrap().as_ref().clone(); + let mut dt = dt.datetime().unwrap().physical().clone(); let lower = lower_bound.map(|lower| Int64Chunked::new_vec(PlSmallStr::from_static(LB_NAME), lower)); @@ -438,7 +438,7 @@ impl Wrap<&DataFrame> { // so we can set this such that downstream code has this info dt.set_sorted_flag(IsSorted::Ascending); let dt = dt.datetime().unwrap(); - let vals = dt.downcast_iter().next().unwrap(); + let vals = dt.physical().downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); PolarsResult::Ok(GroupsType::Slice { groups: group_by_values( @@ -453,7 +453,7 @@ impl Wrap<&DataFrame> { }) } else { let dt = dt.datetime().unwrap(); - let vals = dt.downcast_iter().next().unwrap(); + let vals = dt.physical().downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); let groups = group_by.unwrap(); diff --git a/crates/polars-time/src/offset_by.rs b/crates/polars-time/src/offset_by.rs index 7ea64b2ab44d..4c2b4a6b2a42 100644 --- a/crates/polars-time/src/offset_by.rs +++ b/crates/polars-time/src/offset_by.rs @@ -44,15 +44,16 @@ fn apply_offsets_to_datetime( TimeUnit::Microseconds => Duration::add_us, TimeUnit::Nanoseconds => Duration::add_ns, }; - broadcast_try_binary_elementwise(datetime, offsets, |timestamp_opt, offset_opt| match ( - timestamp_opt, - offset_opt, - ) { - (Some(timestamp), Some(offset)) => { - offset_fn(&Duration::try_parse(offset)?, timestamp, time_zone).map(Some) + broadcast_try_binary_elementwise( + datetime.physical(), + offsets, + |timestamp_opt, offset_opt| match (timestamp_opt, offset_opt) { + (Some(timestamp), Some(offset)) => { + offset_fn(&Duration::try_parse(offset)?, timestamp, time_zone).map(Some) + }, + _ => Ok(None), }, - _ => Ok(None), - }) + ) }, } } diff --git a/crates/polars-time/src/replace.rs b/crates/polars-time/src/replace.rs index d02be3467c36..621117059cbd 100644 --- a/crates/polars-time/src/replace.rs +++ b/crates/polars-time/src/replace.rs @@ -139,7 +139,7 @@ pub fn replace_datetime( // Ensure nulls are propagated. if ca.has_nulls() { - out.merge_validities(ca.chunks()); + out.physical_mut().merge_validities(ca.physical().chunks()); } Ok(out) @@ -186,7 +186,7 @@ pub fn replace_date( // Ensure nulls are propagated. if ca.has_nulls() { - out.merge_validities(ca.chunks()); + out.physical_mut().merge_validities(ca.physical().chunks()); } Ok(out) diff --git a/crates/polars-time/src/round.rs b/crates/polars-time/src/round.rs index 551960396726..d27a3ff7f192 100644 --- a/crates/polars-time/src/round.rs +++ b/crates/polars-time/src/round.rs @@ -41,20 +41,21 @@ impl PolarsRound for DatetimeChunked { TimeUnit::Nanoseconds => every_parsed.duration_ns(), }; return Ok(self + .physical() .apply_values(|t| fast_round(t, every)) .into_datetime(self.time_unit(), time_zone.clone())); } else { let w = Window::new(every_parsed, every_parsed, offset); let out = match self.time_unit() { - TimeUnit::Milliseconds => { - self.try_apply_nonnull_values_generic(|t| w.round_ms(t, tz)) - }, - TimeUnit::Microseconds => { - self.try_apply_nonnull_values_generic(|t| w.round_us(t, tz)) - }, - TimeUnit::Nanoseconds => { - self.try_apply_nonnull_values_generic(|t| w.round_ns(t, tz)) - }, + TimeUnit::Milliseconds => self + .physical() + .try_apply_nonnull_values_generic(|t| w.round_ms(t, tz)), + TimeUnit::Microseconds => self + .physical() + .try_apply_nonnull_values_generic(|t| w.round_us(t, tz)), + TimeUnit::Nanoseconds => self + .physical() + .try_apply_nonnull_values_generic(|t| w.round_ns(t, tz)), }; return Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone())); } @@ -80,22 +81,23 @@ impl PolarsRound for DatetimeChunked { TimeUnit::Milliseconds => Window::round_ms, }; - let out = broadcast_try_binary_elementwise(self, every, |opt_timestamp, opt_every| match ( - opt_timestamp, - opt_every, - ) { - (Some(timestamp), Some(every)) => { - let every = *duration_cache.get_or_insert_with(every, Duration::parse); + let out = broadcast_try_binary_elementwise( + self.physical(), + every, + |opt_timestamp, opt_every| match (opt_timestamp, opt_every) { + (Some(timestamp), Some(every)) => { + let every = *duration_cache.get_or_insert_with(every, Duration::parse); - if every.negative { - polars_bail!(ComputeError: "cannot round a Datetime to a negative duration") - } + if every.negative { + polars_bail!(ComputeError: "cannot round a Datetime to a negative duration") + } - let w = Window::new(every, every, offset); - func(&w, timestamp, tz).map(Some) + let w = Window::new(every, every, offset); + func(&w, timestamp, tz).map(Some) + }, + _ => Ok(None), }, - _ => Ok(None), - }); + ); Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone())) } } @@ -111,7 +113,7 @@ impl PolarsRound for DateChunked { polars_bail!(ComputeError: "cannot round a Date to a negative duration") } let w = Window::new(every, every, offset); - self.try_apply_nonnull_values_generic(|t| { + self.physical().try_apply_nonnull_values_generic(|t| { Ok( (w.round_ms(MILLISECONDS_IN_DAY * t as i64, None)? / MILLISECONDS_IN_DAY) as i32, @@ -128,7 +130,7 @@ impl PolarsRound for DateChunked { self.len(), every.len() ); - broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| { + broadcast_try_binary_elementwise(self.physical(), every, |opt_t, opt_every| { // A sqrt(n) cache is not too small, not too large. let mut duration_cache = LruCache::with_capacity((every.len() as f64).sqrt() as usize); diff --git a/crates/polars-time/src/series/mod.rs b/crates/polars-time/src/series/mod.rs index 4c996f295c60..b075d8383aad 100644 --- a/crates/polars-time/src/series/mod.rs +++ b/crates/polars-time/src/series/mod.rs @@ -1,4 +1,4 @@ -use std::ops::{Deref, Div}; +use std::ops::Div; use arrow::temporal_conversions::{MICROSECONDS_IN_DAY, MILLISECONDS_IN_DAY, NANOSECONDS_IN_DAY}; use polars_core::prelude::arity::unary_elementwise_values; @@ -92,7 +92,7 @@ pub trait TemporalMethods: AsSeries { // Closed formula to find weekday, no need to go via Chrono. // The 4 comes from the fact that 1970-01-01 was a Thursday. // We do an extra `+ 7` then `% 7` to ensure the result is non-negative. - unary_elementwise_values(ca, |t| (((t - 4) % 7 + 7) % 7 + 1) as i8) + unary_elementwise_values(ca.physical(), |t| (((t - 4) % 7 + 7) % 7 + 1) as i8) }), #[cfg(feature = "dtype-datetime")] DataType::Datetime(time_unit, time_zone) => s.datetime().map(|ca| { @@ -106,7 +106,7 @@ pub trait TemporalMethods: AsSeries { TimeUnit::Microseconds => MICROSECONDS_IN_DAY, TimeUnit::Nanoseconds => NANOSECONDS_IN_DAY, }; - unary_elementwise_values(ca, |t| { + unary_elementwise_values(ca.physical(), |t| { let t = t / divisor - ((t < 0 && t % divisor != 0) as i64); (((t - 4) % 7 + 7) % 7 + 1) as i8 }) @@ -298,7 +298,7 @@ pub trait TemporalMethods: AsSeries { polars_bail!(opq = timestamp, s.dtype()); } else { s.cast(&DataType::Datetime(tu, None)) - .map(|s| s.datetime().unwrap().deref().clone()) + .map(|s| s.datetime().unwrap().physical().clone()) } } } diff --git a/crates/polars-time/src/truncate.rs b/crates/polars-time/src/truncate.rs index d117cfd06e40..5f6cef948d43 100644 --- a/crates/polars-time/src/truncate.rs +++ b/crates/polars-time/src/truncate.rs @@ -48,20 +48,21 @@ impl PolarsTruncate for DatetimeChunked { TimeUnit::Nanoseconds => every_parsed.duration_ns(), }; return Ok(self + .physical() .apply_values(|t| fast_truncate(t, every)) .into_datetime(self.time_unit(), time_zone.clone())); } else { let w = Window::new(every_parsed, every_parsed, offset); let out = match self.time_unit() { - TimeUnit::Milliseconds => { - self.try_apply_nonnull_values_generic(|t| w.truncate_ms(t, tz)) - }, - TimeUnit::Microseconds => { - self.try_apply_nonnull_values_generic(|t| w.truncate_us(t, tz)) - }, - TimeUnit::Nanoseconds => { - self.try_apply_nonnull_values_generic(|t| w.truncate_ns(t, tz)) - }, + TimeUnit::Milliseconds => self + .physical() + .try_apply_nonnull_values_generic(|t| w.truncate_ms(t, tz)), + TimeUnit::Microseconds => self + .physical() + .try_apply_nonnull_values_generic(|t| w.truncate_us(t, tz)), + TimeUnit::Nanoseconds => self + .physical() + .try_apply_nonnull_values_generic(|t| w.truncate_ns(t, tz)), }; return Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone())); } @@ -80,22 +81,23 @@ impl PolarsTruncate for DatetimeChunked { TimeUnit::Milliseconds => Window::truncate_ms, }; - let out = broadcast_try_binary_elementwise(self, every, |opt_timestamp, opt_every| match ( - opt_timestamp, - opt_every, - ) { - (Some(timestamp), Some(every)) => { - let every = *duration_cache.get_or_insert_with(every, Duration::parse); + let out = broadcast_try_binary_elementwise( + self.physical(), + every, + |opt_timestamp, opt_every| match (opt_timestamp, opt_every) { + (Some(timestamp), Some(every)) => { + let every = *duration_cache.get_or_insert_with(every, Duration::parse); - if every.negative { - polars_bail!(ComputeError: "cannot truncate a Datetime to a negative duration") - } + if every.negative { + polars_bail!(ComputeError: "cannot truncate a Datetime to a negative duration") + } - let w = Window::new(every, every, offset); - func(&w, timestamp, tz).map(Some) + let w = Window::new(every, every, offset); + func(&w, timestamp, tz).map(Some) + }, + _ => Ok(None), }, - _ => Ok(None), - }); + ); Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone())) } } @@ -118,7 +120,7 @@ impl PolarsTruncate for DateChunked { polars_bail!(ComputeError: "cannot truncate a Date to a negative duration") } let w = Window::new(every, every, offset); - self.try_apply_nonnull_values_generic(|t| { + self.physical().try_apply_nonnull_values_generic(|t| { Ok((w.truncate_ms(MILLISECONDS_IN_DAY * t as i64, None)? / MILLISECONDS_IN_DAY) as i32) }) @@ -126,7 +128,7 @@ impl PolarsTruncate for DateChunked { Ok(Int32Chunked::full_null(self.name().clone(), self.len())) } }, - _ => broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| { + _ => broadcast_try_binary_elementwise(self.physical(), every, |opt_t, opt_every| { // A sqrt(n) cache is not too small, not too large. let mut duration_cache = LruCache::with_capacity((every.len() as f64).sqrt() as usize); diff --git a/crates/polars-utils/src/pl_str.rs b/crates/polars-utils/src/pl_str.rs index 9e9917727606..e7a2cf9dd608 100644 --- a/crates/polars-utils/src/pl_str.rs +++ b/crates/polars-utils/src/pl_str.rs @@ -70,6 +70,12 @@ impl PlSmallStr { self.0.as_mut_str() } + #[inline(always)] + #[allow(clippy::inherent_to_string_shadow_display)] // This is faster. + pub fn to_string(&self) -> String { + self.0.as_str().to_owned() + } + #[inline(always)] pub fn into_string(self) -> String { self.0.into_string() diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index 9897b63e1d67..c3570f6302b2 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -421,8 +421,6 @@ pub use polars_core::{ apply_method_all_arrow_series, chunked_array, datatypes, df, error, frame, functions, series, testing, }; -#[cfg(feature = "dtype-categorical")] -pub use polars_core::{enable_string_cache, using_string_cache}; #[cfg(feature = "polars-io")] pub use polars_io as io; #[cfg(feature = "lazy")] diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs index 63f622295b0f..e9f0be5aaebb 100644 --- a/crates/polars/tests/it/core/joins.rs +++ b/crates/polars/tests/it/core/joins.rs @@ -1,6 +1,4 @@ use polars_core::utils::{accumulate_dataframes_vertical, split_df}; -#[cfg(feature = "dtype-categorical")] -use polars_core::{SINGLE_LOCK, disable_string_cache}; use super::*; @@ -313,18 +311,14 @@ fn test_join_multiple_columns() { #[cfg_attr(miri, ignore)] #[cfg(feature = "dtype-categorical")] fn test_join_categorical() { - let _guard = SINGLE_LOCK.lock(); - disable_string_cache(); - let _sc = StringCacheHolder::hold(); - let (mut df_a, mut df_b) = get_dfs(); df_a.try_apply("b", |s| { - s.cast(&DataType::Categorical(None, Default::default())) + s.cast(&DataType::from_categories(Categories::global())) }) .unwrap(); df_b.try_apply("bar", |s| { - s.cast(&DataType::Categorical(None, Default::default())) + s.cast(&DataType::from_categories(Categories::global())) }) .unwrap(); @@ -351,23 +345,23 @@ fn test_join_categorical() { let out = out.column("b").unwrap(); assert_eq!( out.dtype(), - &DataType::Categorical(None, Default::default()) + &DataType::from_categories(Categories::global()) ); } - // Test error when joining on different string cache + // Test error when joining on different categories. let (mut df_a, mut df_b) = get_dfs(); df_a.try_apply("b", |s| { - s.cast(&DataType::Categorical(None, Default::default())) + s.cast(&DataType::from_categories(Categories::global())) }) .unwrap(); - // Create a new string cache - drop(_sc); - let _sc = StringCacheHolder::hold(); - df_b.try_apply("bar", |s| { - s.cast(&DataType::Categorical(None, Default::default())) + s.cast(&DataType::from_categories(Categories::new( + PlSmallStr::from_static("test"), + PlSmallStr::EMPTY, + CategoricalPhysical::U32, + ))) }) .unwrap(); let out = df_a.join(&df_b, ["b"], ["bar"], JoinType::Left.into(), None); diff --git a/crates/polars/tests/it/core/pivot.rs b/crates/polars/tests/it/core/pivot.rs index 04ea549bdff3..57bed7e0021a 100644 --- a/crates/polars/tests/it/core/pivot.rs +++ b/crates/polars/tests/it/core/pivot.rs @@ -143,7 +143,7 @@ fn test_pivot_categorical() -> PolarsResult<()> { "values" => [8, 2, 3, 6, 3, 6, 2, 2], ]?; df.try_apply("columns", |s| { - s.cast(&DataType::Categorical(None, Default::default())) + s.cast(&DataType::from_categories(Categories::global())) })?; let out = pivot( diff --git a/crates/polars/tests/it/lazy/expressions/arity.rs b/crates/polars/tests/it/lazy/expressions/arity.rs index 6d4dc7e530e5..d04a9db37ef4 100644 --- a/crates/polars/tests/it/lazy/expressions/arity.rs +++ b/crates/polars/tests/it/lazy/expressions/arity.rs @@ -116,8 +116,6 @@ fn includes_null_predicate_3038() -> PolarsResult<()> { #[test] #[cfg(feature = "dtype-categorical")] fn test_when_then_otherwise_cats() -> PolarsResult<()> { - polars::enable_string_cache(); - let lf = df!["book" => [Some("bookA"), None, Some("bookB"), @@ -130,8 +128,8 @@ fn test_when_then_otherwise_cats() -> PolarsResult<()> { ]?.lazy(); let out = lf - .with_column(col("book").cast(DataType::Categorical(None, Default::default()))) - .with_column(col("user").cast(DataType::Categorical(None, Default::default()))) + .with_column(col("book").cast(DataType::from_categories(Categories::global()))) + .with_column(col("user").cast(DataType::from_categories(Categories::global()))) .with_column( when(col("book").is_null()) .then(col("user")) @@ -142,7 +140,7 @@ fn test_when_then_otherwise_cats() -> PolarsResult<()> { assert_eq!( out.column("a")? - .categorical()? + .cat32()? .iter_str() .flatten() .collect::>(), diff --git a/crates/polars/tests/it/lazy/predicate_queries.rs b/crates/polars/tests/it/lazy/predicate_queries.rs index 977b159abbe7..6372da6cfa27 100644 --- a/crates/polars/tests/it/lazy/predicate_queries.rs +++ b/crates/polars/tests/it/lazy/predicate_queries.rs @@ -1,7 +1,3 @@ -// used only if feature="is_in", feature="dtype-categorical" -#[cfg(all(feature = "is_in", feature = "dtype-categorical"))] -use polars_core::{SINGLE_LOCK, StringCacheHolder, disable_string_cache}; - use super::*; #[test] @@ -131,15 +127,11 @@ fn test_is_in_categorical_3420() -> PolarsResult<()> { "b" => [1, 2, 3, 4, 5] ]?; - let _guard = SINGLE_LOCK.lock(); - disable_string_cache(); - let _sc = StringCacheHolder::hold(); - let s = Series::new("x".into(), ["a", "b", "c"]) - .strict_cast(&DataType::Categorical(None, Default::default()))?; + .strict_cast(&DataType::from_categories(Categories::global()))?; let out = df .lazy() - .with_column(col("a").strict_cast(DataType::Categorical(None, Default::default()))) + .with_column(col("a").strict_cast(DataType::from_categories(Categories::global()))) .filter(col("a").is_in(lit(s).alias("x"), false)) .collect()?; @@ -148,7 +140,7 @@ fn test_is_in_categorical_3420() -> PolarsResult<()> { "b" => [1, 2, 3] ]?; expected.try_apply("a", |s| { - s.cast(&DataType::Categorical(None, Default::default())) + s.cast(&DataType::from_categories(Categories::global())) })?; assert!(out.equals(&expected)); Ok(()) diff --git a/docs/source/src/rust/user-guide/expressions/aggregation.rs b/docs/source/src/rust/user-guide/expressions/aggregation.rs index cb6ee70f6ebf..4ac3127eda7b 100644 --- a/docs/source/src/rust/user-guide/expressions/aggregation.rs +++ b/docs/source/src/rust/user-guide/expressions/aggregation.rs @@ -10,23 +10,23 @@ fn main() -> Result<(), Box> { let mut schema = Schema::default(); schema.with_column( "first_name".into(), - DataType::Categorical(None, Default::default()), + DataType::from_categories(Categories::global()), ); schema.with_column( "gender".into(), - DataType::Categorical(None, Default::default()), + DataType::from_categories(Categories::global()), ); schema.with_column( "type".into(), - DataType::Categorical(None, Default::default()), + DataType::from_categories(Categories::global()), ); schema.with_column( "state".into(), - DataType::Categorical(None, Default::default()), + DataType::from_categories(Categories::global()), ); schema.with_column( "party".into(), - DataType::Categorical(None, Default::default()), + DataType::from_categories(Categories::global()), ); schema.with_column("birthday".into(), DataType::Date); diff --git a/docs/source/src/rust/user-guide/transformations/time-series/rolling.rs b/docs/source/src/rust/user-guide/transformations/time-series/rolling.rs index 7378816b786a..dccdada0f622 100644 --- a/docs/source/src/rust/user-guide/transformations/time-series/rolling.rs +++ b/docs/source/src/rust/user-guide/transformations/time-series/rolling.rs @@ -90,6 +90,7 @@ fn main() -> Result<(), Box> { |s| { Ok(Some( s.duration()? + .physical() .into_iter() .map(|d| d.map(|v| v / 1000 / 24 / 60 / 60)) .collect::() diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index 9156d128edc6..d02015f9bd7c 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -648,7 +648,7 @@ class Categorical(DataType): def __init__( self, - ordering: CategoricalOrdering | None = "physical", + ordering: CategoricalOrdering | None = "lexical", ) -> None: self.ordering = ordering diff --git a/py-polars/polars/datatypes/group.py b/py-polars/polars/datatypes/group.py index 783e5f049720..ec31947ed3f1 100644 --- a/py-polars/polars/datatypes/group.py +++ b/py-polars/polars/datatypes/group.py @@ -106,7 +106,6 @@ def __contains__(self, item: Any) -> bool: CATEGORICAL_DTYPES: frozenset[PolarsDataType] = DataTypeGroup( [ Categorical, - Categorical("physical"), Categorical("lexical"), ] ) diff --git a/py-polars/polars/expr/categorical.py b/py-polars/polars/expr/categorical.py index 0d40107d623d..1cf36e62df92 100644 --- a/py-polars/polars/expr/categorical.py +++ b/py-polars/polars/expr/categorical.py @@ -26,7 +26,7 @@ def get_categories(self) -> Expr: >>> df = pl.Series( ... "cats", ["foo", "bar", "foo", "foo", "ham"], dtype=pl.Categorical ... ).to_frame() - >>> df.select(pl.col("cats").cat.get_categories()) + >>> df.select(pl.col("cats").cat.get_categories()) # doctest: +SKIP shape: (3, 1) ┌──────┐ │ cats │ diff --git a/py-polars/polars/interchange/column.py b/py-polars/polars/interchange/column.py index 65bc9a19fd46..a1ef4eaca5e3 100644 --- a/py-polars/polars/interchange/column.py +++ b/py-polars/polars/interchange/column.py @@ -69,7 +69,7 @@ def describe_categorical(self) -> CategoricalDescription: dtype = self._col.dtype if dtype == Categorical: categories = self._col.cat.get_categories() - is_ordered = dtype.ordering == "physical" # type: ignore[attr-defined] + is_ordered = False elif dtype == Enum: categories = dtype.categories # type: ignore[attr-defined] is_ordered = True @@ -155,11 +155,6 @@ def get_buffers(self) -> ColumnBuffers: if dtype == String and not self._allow_copy: msg = "string buffers must be converted" raise CopyNotAllowedError(msg) - elif dtype == Categorical and not self._col.cat.is_local(): - if not self._allow_copy: - msg = f"column {self._col.name!r} must be converted to a local categorical" - raise CopyNotAllowedError(msg) - self._col = self._col.cat.to_local() buffers = self._col._get_buffers() diff --git a/py-polars/polars/series/categorical.py b/py-polars/polars/series/categorical.py index f82e512ab81b..3dea6104d7ab 100644 --- a/py-polars/polars/series/categorical.py +++ b/py-polars/polars/series/categorical.py @@ -27,7 +27,7 @@ def get_categories(self) -> Series: Examples -------- >>> s = pl.Series(["foo", "bar", "foo", "foo", "ham"], dtype=pl.Categorical) - >>> s.cat.get_categories() + >>> s.cat.get_categories() # doctest: +SKIP shape: (3,) Series: '' [str] [ @@ -41,56 +41,12 @@ def is_local(self) -> bool: """ Return whether or not the column is a local categorical. - Examples - -------- - Categoricals constructed without a string cache are considered local. - - >>> s = pl.Series(["a", "b", "a"], dtype=pl.Categorical) - >>> s.cat.is_local() - True - - Categoricals constructed with a string cache are considered global. - - >>> with pl.StringCache(): - ... s = pl.Series(["a", "b", "a"], dtype=pl.Categorical) - >>> s.cat.is_local() - False + Always returns false. """ return self._s.cat_is_local() def to_local(self) -> Series: - """ - Convert a categorical column to its local representation. - - This may change the underlying physical representation of the column. - - See the documentation of :func:`StringCache` for more information on the - difference between local and global categoricals. - - Examples - -------- - Compare the global and local representations of a categorical. - - >>> with pl.StringCache(): - ... _ = pl.Series("x", ["a", "b", "a"], dtype=pl.Categorical) - ... s = pl.Series("y", ["c", "b", "d"], dtype=pl.Categorical) - >>> s.to_physical() - shape: (3,) - Series: 'y' [u32] - [ - 2 - 1 - 3 - ] - >>> s.cat.to_local().to_physical() - shape: (3,) - Series: 'y' [u32] - [ - 0 - 1 - 2 - ] - """ + """Simply returns the column as-is, local representations are deprecated.""" return wrap_s(self._s.cat_to_local()) @unstable() @@ -106,9 +62,6 @@ def uses_lexical_ordering(self) -> bool: -------- >>> s = pl.Series(["b", "a", "b"]).cast(pl.Categorical) >>> s.cat.uses_lexical_ordering() - False - >>> s = s.cast(pl.Categorical("lexical")) - >>> s.cat.uses_lexical_ordering() True """ return self._s.cat_uses_lexical_ordering() diff --git a/py-polars/polars/testing/parametric/strategies/dtype.py b/py-polars/polars/testing/parametric/strategies/dtype.py index ce4af6d4bcb0..8901de68bae8 100644 --- a/py-polars/polars/testing/parametric/strategies/dtype.py +++ b/py-polars/polars/testing/parametric/strategies/dtype.py @@ -39,7 +39,7 @@ from hypothesis.strategies import DrawFn, SearchStrategy - from polars._typing import CategoricalOrdering, PolarsDataType, TimeUnit + from polars._typing import PolarsDataType, TimeUnit from polars.datatypes import DataTypeClass @@ -233,8 +233,7 @@ def _instantiate_flat_dtype( time_unit = draw(_time_units()) return Duration(time_unit) elif dtype == Categorical: - ordering = draw(_categorical_orderings()) - return Categorical(ordering) + return Categorical() elif dtype == Enum: n_categories = draw( st.integers(min_value=1, max_value=_DEFAULT_ENUM_CATEGORIES_LIMIT) @@ -338,11 +337,6 @@ def _time_zones() -> SearchStrategy[str]: ) -def _categorical_orderings() -> SearchStrategy[CategoricalOrdering]: - """Create a strategy for generating valid ordering types for categorical data.""" - return st.sampled_from(["physical", "lexical"]) - - @st.composite def _instantiate_dtype( draw: DrawFn, diff --git a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py index a197f3705443..7078f60a5efa 100644 --- a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py +++ b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py @@ -382,9 +382,7 @@ def test_fallback_with_dtype_strict_failure_enum_casting() -> None: dtype = pl.Enum(["a", "b"]) values = ["a", "b", "c", None] - with pytest.raises( - TypeError, match="cannot append 'c' to enum without that variant" - ): + with pytest.raises(TypeError, match="attempted to insert 'c'"): PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True) diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index a5e0f71073d0..95ed03c4b37e 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -258,7 +258,7 @@ class TradeNT(NamedTuple): ) assert df.schema == { "ts": pl.Datetime("ms"), - "tk": pl.Categorical(ordering="physical"), + "tk": pl.Categorical(ordering="lexical"), "pc": pl.Decimal(scale=1), "sz": pl.UInt16, } diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index f92f361df801..624b67c3a710 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -1,25 +1,19 @@ from __future__ import annotations -import contextlib import io import operator -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import Callable import pytest import polars as pl from polars import StringCache from polars.exceptions import ( - CategoricalRemappingWarning, ComputeError, - StringCacheMismatchError, ) from polars.testing import assert_frame_equal, assert_series_equal from tests.unit.conftest import with_string_cache_if_auto_streaming -if TYPE_CHECKING: - from polars._typing import PolarsDataType - @StringCache() def test_categorical_full_outer_join() -> None: @@ -172,26 +166,22 @@ def test_categorical_equality_global_fastpath( @pytest.mark.parametrize( - ("op", "expected_phys", "expected_lexical"), + ("op", "expected_lexical"), [ ( operator.le, - pl.Series([True, True, True, True, False]), pl.Series([False, True, True, False, True]), ), ( operator.lt, - pl.Series([True, False, False, True, False]), pl.Series([False, False, False, False, True]), ), ( operator.ge, - pl.Series([False, True, True, False, True]), pl.Series([True, True, True, True, False]), ), ( operator.gt, - pl.Series([False, False, False, False, True]), pl.Series([True, False, False, True, False]), ), ], @@ -199,12 +189,11 @@ def test_categorical_equality_global_fastpath( @StringCache() def test_categorical_global_ordering( op: Callable[[pl.Series, pl.Series], pl.Series], - expected_phys: pl.Series, expected_lexical: pl.Series, ) -> None: s = pl.Series(["z", "b", "c", "c", "a"], dtype=pl.Categorical) s2 = pl.Series("b_cat", ["a", "b", "c", "a", "c"], dtype=pl.Categorical) - assert_series_equal(op(s, s2), expected_phys) + assert_series_equal(op(s, s2), expected_lexical) s = s.cast(pl.Categorical("lexical")) s2 = s2.cast(pl.Categorical("lexical")) @@ -212,27 +201,25 @@ def test_categorical_global_ordering( @pytest.mark.parametrize( - ("op", "expected_phys", "expected_lexical"), + ("op", "expected_lexical"), [ - (operator.le, pl.Series([True, True, False]), pl.Series([False, True, False])), + (operator.le, pl.Series([False, True, False])), ( operator.lt, - pl.Series([True, False, False]), pl.Series([False, False, False]), ), - (operator.ge, pl.Series([False, True, True]), pl.Series([True, True, True])), - (operator.gt, pl.Series([False, False, True]), pl.Series([True, False, True])), + (operator.ge, pl.Series([True, True, True])), + (operator.gt, pl.Series([True, False, True])), ], ) @StringCache() def test_categorical_global_ordering_broadcast_rhs( op: Callable[[pl.Series, pl.Series], pl.Series], - expected_phys: pl.Series, expected_lexical: pl.Series, ) -> None: s = pl.Series(["c", "a", "b"], dtype=pl.Categorical) s2 = pl.Series("b_cat", ["a"], dtype=pl.Categorical) - assert_series_equal(op(s, s2), expected_phys) + assert_series_equal(op(s, s2), expected_lexical) s = s.cast(pl.Categorical("lexical")) s2 = s2.cast(pl.Categorical("lexical")) @@ -241,14 +228,13 @@ def test_categorical_global_ordering_broadcast_rhs( @pytest.mark.parametrize( - ("op", "expected_phys", "expected_lexical"), + ("op", "expected_lexical"), [ - (operator.le, pl.Series([True, True, True]), pl.Series([True, False, True])), - (operator.lt, pl.Series([True, True, False]), pl.Series([True, False, False])), - (operator.ge, pl.Series([False, False, True]), pl.Series([False, True, True])), + (operator.le, pl.Series([True, False, True])), + (operator.lt, pl.Series([True, False, False])), + (operator.ge, pl.Series([False, True, True])), ( operator.gt, - pl.Series([False, False, False]), pl.Series([False, True, False]), ), ], @@ -256,12 +242,11 @@ def test_categorical_global_ordering_broadcast_rhs( @StringCache() def test_categorical_global_ordering_broadcast_lhs( op: Callable[[pl.Series, pl.Series], pl.Series], - expected_phys: pl.Series, expected_lexical: pl.Series, ) -> None: s = pl.Series(["b"], dtype=pl.Categorical) s2 = pl.Series(["c", "a", "b"], dtype=pl.Categorical) - assert_series_equal(op(s, s2), expected_phys) + assert_series_equal(op(s, s2), expected_lexical) s = s.cast(pl.Categorical("lexical")) s2 = s2.cast(pl.Categorical("lexical")) @@ -395,18 +380,14 @@ def test_compare_categorical_single_none( assert_series_equal(op(s, s2.cast(pl.String)), expected) -def test_categorical_error_on_local_cmp() -> None: +def test_categorical_cmp_noteq() -> None: df_cat = pl.DataFrame( [ pl.Series("a_cat", ["c", "a", "b", "c", "b"], dtype=pl.Categorical), pl.Series("b_cat", ["F", "G", "E", "G", "G"], dtype=pl.Categorical), ] ) - with pytest.raises( - StringCacheMismatchError, - match="cannot compare categoricals coming from different sources", - ): - df_cat.filter(pl.col("a_cat") == pl.col("b_cat")) + assert len(df_cat.filter(pl.col("a_cat") == pl.col("b_cat"))) == 0 @pytest.mark.usefixtures("test_global_and_local") @@ -431,39 +412,6 @@ def test_merge_lit_under_global_cache_4491() -> None: ).to_dict(as_series=False) == {"label": [None, "bar"], "value": [3, 9]} -def test_nested_cache_composition() -> None: - # very artificial example/test, but validates the behaviour - # of nested StringCache scopes, which we want to play well - # with each other when composing more complex pipelines. - - assert pl.using_string_cache() is False - - # function representing a composable stage of a pipeline; it implements - # an inner scope for the case where it is called by itself, but when - # called as part of a larger series of ops it should not invalidate - # the string cache (eg: the outermost scope should be respected). - def create_lazy(data: dict) -> pl.LazyFrame: # type: ignore[type-arg] - with pl.StringCache(): - df = pl.DataFrame({"a": ["foo", "bar", "ham"], "b": [1, 2, 3]}) - lf = df.with_columns(pl.col("a").cast(pl.Categorical)).lazy() - - # confirm that scope-exit does NOT invalidate the - # cache yet, as an outer context is still active - assert pl.using_string_cache() is True - return lf - - # this outer scope should be respected - with pl.StringCache(): - lf1 = create_lazy({"a": ["foo", "bar", "ham"], "b": [1, 2, 3]}) - lf2 = create_lazy({"a": ["spam", "foo", "eggs"], "c": [3, 2, 2]}) - - res = lf1.join(lf2, on="a", how="inner").collect().rows() - assert sorted(res) == [("bar", 2, 2), ("foo", 1, 1), ("ham", 3, 3)] - - # no other scope active; NOW we expect the cache to have been invalidated - assert pl.using_string_cache() is False - - @pytest.mark.usefixtures("test_global_and_local") def test_categorical_in_struct_nulls() -> None: s = pl.Series( @@ -505,28 +453,16 @@ def test_stringcache() -> None: } -@pytest.mark.parametrize( - ("dtype", "outcome"), - [ - (pl.Categorical, ["foo", "bar", "baz"]), - (pl.Categorical("physical"), ["foo", "bar", "baz"]), - (pl.Categorical("lexical"), ["bar", "baz", "foo"]), - ], -) -@pytest.mark.usefixtures("test_global_and_local") -def test_categorical_sort_order_by_parameter( - dtype: PolarsDataType, outcome: list[str] -) -> None: - s = pl.Series(["foo", "bar", "baz"], dtype=dtype) +def test_categorical_sort_single() -> None: + s = pl.Series(["foo", "bar", "baz"], dtype=pl.Categorical) df = pl.DataFrame({"cat": s}) - assert df.sort(["cat"])["cat"].to_list() == outcome + assert df.sort(["cat"])["cat"].to_list() == ["bar", "baz", "foo"] -@StringCache() -@pytest.mark.parametrize("row_fmt_sort_enabled", [False, True]) -def test_categorical_sort_order(row_fmt_sort_enabled: bool, monkeypatch: Any) -> None: +def test_categorical_sort_multiple() -> None: # create the categorical ordering first - pl.Series(["foo", "bar", "baz"], dtype=pl.Categorical) + _s = pl.Series(["foo", "bar", "baz"], dtype=pl.Categorical) + df = pl.DataFrame( { "n": [0, 0, 0], @@ -535,17 +471,11 @@ def test_categorical_sort_order(row_fmt_sort_enabled: bool, monkeypatch: Any) -> } ) - if row_fmt_sort_enabled: - monkeypatch.setenv("POLARS_ROW_FMT_SORT", "1") - - result = df.sort(["n", "x"]) - assert result["x"].to_list() == ["foo", "bar", "baz"] - result = df.with_columns(pl.col("x").cast(pl.Categorical("lexical"))).sort("n", "x") assert result["x"].to_list() == ["bar", "baz", "foo"] -def test_err_on_categorical_asof_join_by_arg() -> None: +def test_categorical_asof_join_by_arg() -> None: df1 = pl.DataFrame( [ pl.Series("cat", ["a", "foo", "bar", "foo", "bar"], dtype=pl.Categorical), @@ -563,11 +493,11 @@ def test_err_on_categorical_asof_join_by_arg() -> None: pl.Series("x", [1, 2, 3, 4] * 2, dtype=pl.Int32), ] ) - with pytest.raises( - StringCacheMismatchError, - match="cannot compare categoricals coming from different sources", - ): - df1.join_asof(df2, on=pl.col("time").set_sorted(), by="cat") + df1s = df1.with_columns(cat=pl.col.cat.cast(pl.String)) + df2s = df2.with_columns(cat=pl.col.cat.cast(pl.String)) + out1 = df1.join_asof(df2, on=pl.col("time").set_sorted(), by="cat") + out2 = df1s.join_asof(df2s, on=pl.col("time").set_sorted(), by="cat") + assert_frame_equal(out1, out2.with_columns(cat=pl.col.cat.cast(pl.Categorical))) @pytest.mark.usefixtures("test_global_and_local") @@ -623,7 +553,8 @@ def test_categorical_fill_null_existing_category() -> None: # ensure physical types align df = pl.DataFrame({"col": ["a", None, "a"]}, schema={"col": pl.Categorical}) result = df.fill_null("a").with_columns(pl.col("col").to_physical().alias("code")) - expected = {"col": ["a", "a", "a"], "code": [0, 0, 0]} + d = result.to_dict(as_series=False) + expected = {"col": ["a", "a", "a"], "code": [d["code"][0]] * 3} assert result.to_dict(as_series=False) == expected @@ -717,63 +648,54 @@ def test_categorical_update_lengths() -> None: assert s.len() == 5 -def test_categorical_zip_append_local_different_rev_map() -> None: +def test_categorical_zip_append() -> None: s1 = pl.Series(["cat1", "cat2", "cat1"], dtype=pl.Categorical) s2 = pl.Series(["cat2", "cat2", "cat3"], dtype=pl.Categorical) - with pytest.warns( - CategoricalRemappingWarning, - match="Local categoricals have different encodings", - ): - s3 = s1.append(s2) - categories = s3.cat.get_categories() - assert len(categories) == 3 - assert set(categories) == {"cat1", "cat2", "cat3"} + s3 = s1.append(s2) + assert_series_equal( + s3, + pl.Series( + ["cat1", "cat2", "cat1", "cat2", "cat2", "cat3"], dtype=pl.Categorical + ), + ) -def test_categorical_zip_extend_local_different_rev_map() -> None: +def test_categorical_zip_extend() -> None: s1 = pl.Series(["cat1", "cat2", "cat1"], dtype=pl.Categorical) s2 = pl.Series(["cat2", "cat2", "cat3"], dtype=pl.Categorical) - with pytest.warns( - CategoricalRemappingWarning, - match="Local categoricals have different encodings", - ): - s3 = s1.extend(s2) - categories = s3.cat.get_categories() - assert len(categories) == 3 - assert set(categories) == {"cat1", "cat2", "cat3"} + s3 = s1.extend(s2) + assert_series_equal( + s3, + pl.Series( + ["cat1", "cat2", "cat1", "cat2", "cat2", "cat3"], dtype=pl.Categorical + ), + ) -def test_categorical_zip_with_local_different_rev_map() -> None: +def test_categorical_zip() -> None: s1 = pl.Series(["cat1", "cat2", "cat1"], dtype=pl.Categorical) mask = pl.Series([True, False, False]) s2 = pl.Series(["cat2", "cat2", "cat3"], dtype=pl.Categorical) - with pytest.warns( - CategoricalRemappingWarning, - match="Local categoricals have different encodings", - ): - s3 = s1.zip_with(mask, s2) - categories = s3.cat.get_categories() - assert len(categories) == 3 - assert set(categories) == {"cat1", "cat2", "cat3"} + s3 = s1.zip_with(mask, s2) + assert_series_equal(s3, pl.Series(["cat1", "cat2", "cat3"], dtype=pl.Categorical)) -def test_categorical_vstack_with_local_different_rev_map() -> None: +def test_categorical_vstack() -> None: df1 = pl.DataFrame({"a": pl.Series(["a", "b", "c"], dtype=pl.Categorical)}) df2 = pl.DataFrame({"a": pl.Series(["d", "e", "f"], dtype=pl.Categorical)}) - with pytest.warns( - CategoricalRemappingWarning, - match="Local categoricals have different encodings", - ): - df3 = df1.vstack(df2) - assert df3.get_column("a").cat.get_categories().to_list() == [ + df3 = df1.vstack(df2) + expected = pl.DataFrame( + {"a": pl.Series(["a", "b", "c", "d", "e", "f"], dtype=pl.Categorical)} + ) + assert_frame_equal(df3, expected) + assert set(df3.get_column("a").cat.get_categories().to_list()) >= { "a", "b", "c", "d", "e", "f", - ] - assert df3.get_column("a").cast(pl.UInt32).to_list() == [0, 1, 2, 3, 4, 5] + } @pytest.mark.usefixtures("test_global_and_local") @@ -792,51 +714,28 @@ def test_shift_over_13041() -> None: } -@pytest.mark.parametrize("context", [pl.StringCache(), contextlib.nullcontext()]) -@pytest.mark.parametrize("ordering", ["physical", "lexical"]) -@pytest.mark.usefixtures("test_global_and_local") -def test_sort_categorical_retain_none( - context: contextlib.AbstractContextManager, # type: ignore[type-arg] - ordering: Literal["physical", "lexical"], -) -> None: - with context: - df = pl.DataFrame( - [ - pl.Series( - "e", - ["foo", None, "bar", "ham", None], - dtype=pl.Categorical(ordering=ordering), - ) - ] - ) - - df_sorted = df.with_columns(pl.col("e").sort()) - assert ( - df_sorted.get_column("e").null_count() - == df.get_column("e").null_count() - == 2 - ) - if ordering == "lexical": - assert df_sorted.get_column("e").to_list() == [ - None, - None, - "bar", - "foo", - "ham", - ] - - -@pytest.mark.usefixtures("test_global_and_local") -def test_cast_from_cat_to_numeric() -> None: - cat_series = pl.Series( - "cat_series", - ["0.69845702", "0.69317475", "2.43642724", "-0.95303469", "0.60684237"], - ).cast(pl.Categorical) - maximum = cat_series.cast(pl.Float32).max() - assert abs(maximum - 2.43642724) < 1e-6 # type: ignore[operator] +def test_sort_categorical_retain_none() -> None: + df = pl.DataFrame( + [ + pl.Series( + "e", + ["foo", None, "bar", "ham", None], + dtype=pl.Categorical(), + ) + ] + ) - s = pl.Series(["1", "2", "3"], dtype=pl.Categorical) - assert s.cast(pl.UInt8).sum() == 6 + df_sorted = df.with_columns(pl.col("e").sort()) + assert ( + df_sorted.get_column("e").null_count() == df.get_column("e").null_count() == 2 + ) + assert df_sorted.get_column("e").to_list() == [ + None, + None, + "bar", + "foo", + "ham", + ] @pytest.mark.usefixtures("test_global_and_local") @@ -875,16 +774,6 @@ def test_cat_append_lexical_sorted_flag() -> None: assert not (s1.is_sorted()) -@pytest.mark.usefixtures("test_global_and_local") -def test_cast_physical_lexical_sorted_flag_20864() -> None: - df = pl.DataFrame({"s": ["b", "a"], "v": [1, 2]}) - sorted_physically = df.cast({"s": pl.Categorical("physical")}).sort("s") - sorted_lexically = sorted_physically.cast({"s": pl.Categorical("lexical")}).sort( - "s" - ) - assert sorted_lexically["s"].to_list() == ["a", "b"] - - @pytest.mark.usefixtures("test_global_and_local") def test_get_cat_categories_multiple_chunks() -> None: df = pl.DataFrame( @@ -895,8 +784,8 @@ def test_get_cat_categories_multiple_chunks() -> None: df = pl.concat( [df for _ in range(100)], how="vertical", rechunk=False, parallel=True ) - df_cat = df.lazy().select(pl.col("e").cat.get_categories()).collect() - assert len(df_cat) == 2 + cats = df.lazy().select(pl.col("e").cat.get_categories()).collect()["e"].to_list() + assert set(cats) >= {"a", "b"} @pytest.mark.parametrize( @@ -913,9 +802,9 @@ def test_nested_categorical_concat( _, vb = f("b") a = pl.DataFrame({"x": [va]}, schema={"x": dt}) b = pl.DataFrame({"x": [vb]}, schema={"x": dt}) - - with pytest.raises(pl.exceptions.StringCacheMismatchError): - pl.concat([a, b]) + assert_frame_equal( + pl.concat([a, b]), pl.DataFrame({"x": [va, vb]}, schema={"x": dt}) + ) @with_string_cache_if_auto_streaming @@ -996,14 +885,12 @@ def test_categorical_prefill() -> None: def test_categorical_min_max() -> None: schema = pl.Schema( { - "a": pl.Categorical("physical"), "b": pl.Categorical("lexical"), "c": pl.Enum(["foo", "bar"]), } ) lf = pl.LazyFrame( { - "a": ["foo", "bar"], "b": ["foo", "bar"], "c": ["foo", "bar"], }, @@ -1014,10 +901,10 @@ def test_categorical_min_max() -> None: result = q.collect() assert q.collect_schema() == schema assert result.schema == schema - assert result.to_dict(as_series=False) == {"a": ["foo"], "b": ["bar"], "c": ["foo"]} + assert result.to_dict(as_series=False) == {"b": ["bar"], "c": ["foo"]} q = lf.select(pl.all().max()) result = q.collect() assert q.collect_schema() == schema assert result.schema == schema - assert result.to_dict(as_series=False) == {"a": ["bar"], "b": ["foo"], "c": ["bar"]} + assert result.to_dict(as_series=False) == {"b": ["foo"], "c": ["bar"]} diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index d19b3c8945fd..a8a1ee5086fb 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -15,9 +15,7 @@ import polars as pl from polars import StringCache from polars.exceptions import ( - ComputeError, InvalidOperationError, - OutOfBoundsError, SchemaError, ) from polars.testing import assert_frame_equal, assert_series_equal @@ -256,9 +254,7 @@ def test_casting_to_an_enum_from_integer() -> None: def test_casting_to_an_enum_oob_from_integer() -> None: dtype = pl.Enum(["a", "b", "c"]) s = pl.Series([None, 1, 0, 5], dtype=pl.UInt32) - with pytest.raises( - OutOfBoundsError, match=("index 5 is bigger than the number of categories 3") - ): + with pytest.raises(InvalidOperationError, match=("values: \\[5\\]")): s.cast(dtype) @@ -321,7 +317,7 @@ def test_append_to_an_enum() -> None: def test_append_to_an_enum_with_new_category() -> None: with pytest.raises( SchemaError, - match=("type Enum.*is incompatible with expected type Enum.*"), + match=("Enum mismatch"), ): pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])).append( pl.Series(["d", "a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"])) @@ -337,10 +333,7 @@ def test_extend_to_an_enum() -> None: def test_series_init_uninstantiated_enum() -> None: - with pytest.raises( - InvalidOperationError, - match="cannot cast / initialize Enum without categories present", - ): + with pytest.raises(InvalidOperationError): pl.Series(["a", "b", "a"], dtype=pl.Enum) @@ -447,9 +440,7 @@ def test_compare_enum_str_single_raise( with pytest.raises( InvalidOperationError, - match=re.escape( - "conversion from `str` to `enum` failed in column '' for 1 out of 1 values: [\"NOTEXIST\"]" - ), + match=re.escape("conversion from `str` to `enum` failed"), ): op(s, s2) # type: ignore[arg-type] @@ -463,7 +454,7 @@ def test_compare_enum_str_raise() -> None: for op in [operator.le, operator.gt, operator.ge, operator.lt]: with pytest.raises( InvalidOperationError, - match="conversion from `str` to `enum` failed in column", + match="conversion from `str` to `enum` failed", ): op(s, s_compare) @@ -480,10 +471,7 @@ def test_different_enum_comparison_order() -> None: ] ) for op in [operator.gt, operator.ge, operator.lt, operator.le]: - with pytest.raises( - ComputeError, - match="can only compare categoricals of the same type", - ): + with pytest.raises(SchemaError): df_enum.filter(op(pl.col("a_cat"), pl.col("b_cat"))) @@ -535,37 +523,18 @@ def test_enum_cast_from_other_integer_dtype_oob() -> None: enum_dtype = pl.Enum(["a", "b", "c", "d"]) series = pl.Series([-1, 2, 3, 3, 2, 1], dtype=pl.Int8) with pytest.raises( - InvalidOperationError, match="conversion from `i8` to `u32` failed in column" + InvalidOperationError, match="conversion from `i8` to `enum` failed in column" ): series.cast(enum_dtype) series = pl.Series([2**34, 2, 3, 3, 2, 1], dtype=pl.UInt64) with pytest.raises( InvalidOperationError, - match="conversion from `u64` to `u32` failed in column", + match="conversion from `u64` to `enum` failed in column", ): series.cast(enum_dtype) -def test_enum_creating_col_expr() -> None: - df = pl.DataFrame( - { - "col1": ["a", "b", "c"], - "col2": ["d", "e", "f"], - "col3": ["g", "h", "i"], - }, - schema={ - "col1": pl.Enum(["a", "b", "c"]), - "col2": pl.Categorical(), - "col3": pl.Enum(["g", "h", "i"]), - }, - ) - - out = df.select(pl.col(pl.Enum)) - expected = df.select("col1", "col3") - assert_frame_equal(out, expected) - - def test_enum_cse_eq() -> None: df = pl.DataFrame({"a": [1]}) diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index b905f05a71be..098c5162bddb 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -201,19 +201,13 @@ def test_inner_type_categorical_on_rechunk() -> None: assert pl.concat([df, df], rechunk=True).dtypes == [pl.List(pl.Categorical)] -def test_local_categorical_list() -> None: +def test_categorical_list() -> None: values = [["a", "b"], ["c"], ["a", "d", "d"]] s = pl.Series(values, dtype=pl.List(pl.Categorical)) assert s.dtype == pl.List assert s.dtype.inner == pl.Categorical # type: ignore[attr-defined] assert s.to_list() == values - - # Check that underlying physicals match - idx_df = pl.Series([[0, 1], [2], [0, 3, 3]], dtype=pl.List(pl.UInt32)) - assert_series_equal(s.cast(pl.List(pl.UInt32)), idx_df) - - # Check if the categories array does not overlap - assert s.list.explode().cat.get_categories().to_list() == ["a", "b", "c", "d"] + assert s.explode().to_list() == ["a", "b", "c", "a", "d", "d"] def test_group_by_list_column() -> None: diff --git a/py-polars/tests/unit/interchange/test_column.py b/py-polars/tests/unit/interchange/test_column.py index 12b9631f40e8..abe592fe3e83 100644 --- a/py-polars/tests/unit/interchange/test_column.py +++ b/py-polars/tests/unit/interchange/test_column.py @@ -44,20 +44,9 @@ def test_describe_categorical() -> None: out = col.describe_categorical - assert out["is_ordered"] is True - assert out["is_dictionary"] is True - - expected_categories = pl.Series(["b", "a", "c"]) - assert_series_equal(out["categories"]._col, expected_categories) - - -def test_describe_categorical_lexical_ordering() -> None: - s = pl.Series(["b", "a", "a", "c", None, "b"], dtype=pl.Categorical("lexical")) - col = PolarsColumn(s) - - out = col.describe_categorical - assert out["is_ordered"] is False + assert out["is_dictionary"] is True + assert set(out["categories"]._col) >= {"b", "a", "c"} def test_describe_categorical_enum() -> None: @@ -194,11 +183,12 @@ def test_get_chunks_subdivided_chunks() -> None: pl.Series([568080000000000, 0, 1670025600000000], dtype=pl.Int64), (DtypeKind.INT, 64, "l", "="), ), - ( - pl.Series(["a", "b", None, "a"], dtype=pl.Categorical), - pl.Series([0, 1, 0, 0], dtype=pl.UInt32), - (DtypeKind.UINT, 32, "I", "="), - ), + # TODO: cat-rework: re-enable this with a unique named categorical. + # ( + # pl.Series(["a", "b", None, "a"], dtype=pl.Categorical), + # pl.Series([0, 1, 0, 0], dtype=pl.UInt32), + # (DtypeKind.UINT, 32, "I", "="), + # ), ], ) def test_get_buffers_data( @@ -275,25 +265,16 @@ def test_get_buffers_string_zero_copy_fails() -> None: col.get_buffers() -def test_get_buffers_global_categorical() -> None: - with pl.StringCache(): - _ = pl.Series("a", ["a", "b"], dtype=pl.Categorical) - s = pl.Series("a", ["c", "b"], dtype=pl.Categorical) - - # Converted to local categorical - col = PolarsColumn(s, allow_copy=True) +@pytest.mark.parametrize("allow_copy", [False, True]) +def test_get_buffers_categorical(allow_copy: bool) -> None: + s = pl.Series("a", ["c", "b"], dtype=pl.Categorical) + col = PolarsColumn(s, allow_copy=allow_copy) result = col.get_buffers() data_buffer, _ = result["data"] - expected = pl.Series("a", [0, 1], dtype=pl.UInt32) - assert_series_equal(data_buffer._data, expected) - - # Zero copy fails - col = PolarsColumn(s, allow_copy=False) - - msg = "column 'a' must be converted to a local categorical" - with pytest.raises(CopyNotAllowedError, match=msg): - col.get_buffers() + assert len(data_buffer._data) == 2 + assert data_buffer._data[0] != data_buffer._data[1] + assert data_buffer._data.dtype == pl.UInt32 def test_get_buffers_chunked_zero_copy_fails() -> None: diff --git a/py-polars/tests/unit/interchange/test_roundtrip.py b/py-polars/tests/unit/interchange/test_roundtrip.py index 131173a343ac..6d604c073399 100644 --- a/py-polars/tests/unit/interchange/test_roundtrip.py +++ b/py-polars/tests/unit/interchange/test_roundtrip.py @@ -40,7 +40,7 @@ pl.Datetime, # This is broken for empty dataframes # TODO: Enable lexically ordered categoricals - # pl.Categorical("physical"), + # pl.Categorical("lexical"), # TODO: Add Enum # pl.Enum, ] @@ -288,8 +288,3 @@ def test_from_pyarrow_str_dict_with_null_values_20270() -> None: assert_series_equal( df.to_series(), pl.Series("col1", ["A", "A", None, None, "B"], pl.Categorical) ) - assert_series_equal( - df.select(pl.col.col1.cat.get_categories()).to_series(), - pl.Series(["A", "B"]), - check_names=False, - ) diff --git a/py-polars/tests/unit/interop/test_from_pandas.py b/py-polars/tests/unit/interop/test_from_pandas.py index 1c2c2eb47641..af905a17e943 100644 --- a/py-polars/tests/unit/interop/test_from_pandas.py +++ b/py-polars/tests/unit/interop/test_from_pandas.py @@ -61,7 +61,7 @@ def test_from_pandas() -> None: "floats_nulls": pl.Float64, "strings": pl.String, "strings_nulls": pl.String, - "strings-cat": pl.Categorical(ordering="physical"), + "strings-cat": pl.Categorical(ordering="lexical"), } assert out.rows() == [ (False, None, 1, 1.0, 1.0, 1.0, "foo", "foo", "foo"), diff --git a/py-polars/tests/unit/interop/test_interop.py b/py-polars/tests/unit/interop/test_interop.py index 707b4aa0eb14..bbc1e8420d3e 100644 --- a/py-polars/tests/unit/interop/test_interop.py +++ b/py-polars/tests/unit/interop/test_interop.py @@ -80,11 +80,11 @@ def test_arrow_list_chunked_array() -> None: # Test that polars convert Arrays of logical types correctly to arrow def test_arrow_array_logical() -> None: - # cast to large string and uint32 indices because polars converts to those + # cast to large string and uint8 indices because polars converts to those pa_data1 = ( pa.array(["a", "b", "c", "d"]) .dictionary_encode() - .cast(pa.dictionary(pa.uint32(), pa.large_string())) + .cast(pa.dictionary(pa.uint8(), pa.large_string())) ) pa_array_logical1 = pa.FixedSizeListArray.from_arrays(pa_data1, 2) @@ -425,7 +425,7 @@ def test_dataframe_from_repr() -> None: assert frame.schema == { "a": pl.Int64, "b": pl.Float64, - "c": pl.Categorical(ordering="physical"), + "c": pl.Categorical(ordering="lexical"), "d": pl.Boolean, "e": pl.String, "f": pl.Date, @@ -920,19 +920,11 @@ def test_arrow_roundtrip_lex_cat_20288() -> None: def test_from_arrow_string_cache_20271() -> None: with pl.StringCache(): - s = pl.Series("a", ["A", "B", "C"], pl.Categorical) df = pl.from_arrow( pa.table({"b": pa.DictionaryArray.from_arrays([0, 1], ["D", "E"])}) ) assert isinstance(df, pl.DataFrame) - - assert_series_equal( - s.to_physical(), pl.Series("a", [0, 1, 2]), check_dtypes=False - ) assert_series_equal(df.to_series(), pl.Series("b", ["D", "E"], pl.Categorical)) - assert_series_equal( - df.to_series().to_physical(), pl.Series("b", [3, 4]), check_dtypes=False - ) def test_to_arrow_empty_chunks_20627() -> None: diff --git a/py-polars/tests/unit/interop/test_to_pandas.py b/py-polars/tests/unit/interop/test_to_pandas.py index c803df572aa4..f74670242910 100644 --- a/py-polars/tests/unit/interop/test_to_pandas.py +++ b/py-polars/tests/unit/interop/test_to_pandas.py @@ -50,14 +50,15 @@ def test_to_pandas() -> None: np.dtype(np.object_), np.dtype(np.object_), np.dtype("datetime64[us]"), - pd.CategoricalDtype(categories=["a", "b", "c"], ordered=False), - pd.CategoricalDtype(categories=["e", "f"], ordered=False), ] - assert pd_out_dtypes_expected == pd_out.dtypes.to_list() + assert pd_out_dtypes_expected == pd_out.dtypes.to_list()[:-2] + assert all( + isinstance(dt, pd.CategoricalDtype) for dt in pd_out.dtypes.to_list()[-2:] + ) pd_out_dtypes_expected[3] = np.dtype("O") pd_out = df.to_pandas(date_as_object=True) - assert pd_out_dtypes_expected == pd_out.dtypes.to_list() + assert pd_out_dtypes_expected == pd_out.dtypes.to_list()[:-2] pd_pa_out = df.to_pandas(use_pyarrow_extension_array=True) pd_pa_dtypes_names = [dtype.name for dtype in pd_pa_out.dtypes] @@ -187,7 +188,8 @@ def test_series_to_pandas_categorical(polars_dtype: PolarsDataType) -> None: s = pl.Series("x", ["a", "b", "a"], dtype=polars_dtype) result = s.to_pandas() expected = pd.Series(["a", "b", "a"], name="x", dtype="category") - pd.testing.assert_series_equal(result, expected) + assert isinstance(result.dtype, pd.CategoricalDtype) + pd.testing.assert_series_equal(result, expected, check_categorical=False) @pytest.mark.parametrize("polars_dtype", [pl.Categorical, pl.Enum(["a", "b"])]) diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index c434f212d4af..d24fececb276 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -2556,9 +2556,7 @@ def test_csv_enum_raise() -> None: ENUM_DTYPE = pl.Enum(["foo", "bar"]) with ( io.StringIO("col\nfoo\nbaz\n") as csv, - pytest.raises( - pl.exceptions.ComputeError, match="category baz doesn't exist in Enum dtype" - ), + pytest.raises(pl.exceptions.ComputeError, match="could not parse `baz`"), ): pl.read_csv( csv, diff --git a/py-polars/tests/unit/io/test_ipc.py b/py-polars/tests/unit/io/test_ipc.py index 3e3c8a06633f..e6987c11f9f1 100644 --- a/py-polars/tests/unit/io/test_ipc.py +++ b/py-polars/tests/unit/io/test_ipc.py @@ -203,9 +203,7 @@ def test_ipc_schema_from_file( "datetime": pl.Datetime(), "time": pl.Time(), "cat": pl.Categorical(), - "enum": pl.Enum( - [] - ), # at schema inference categories are not read an empty Enum is returned + "enum": pl.Enum(["foo", "ham", "bar"]), } assert schema == expected diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py index f4259fcf0363..749094e62c34 100644 --- a/py-polars/tests/unit/io/test_lazy_parquet.py +++ b/py-polars/tests/unit/io/test_lazy_parquet.py @@ -296,7 +296,7 @@ def test_categorical(tmp_path: Path) -> None: .sort("name") ) expected = pl.DataFrame( - {"name": ["Bob", "Alice"], "amount": [400, 200]}, + {"name": ["Alice", "Bob"], "amount": [200, 400]}, schema_overrides={"name": pl.Categorical}, ) assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 88c7582d1480..17d1baed224d 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -2718,19 +2718,12 @@ def test_parquet_roundtrip_lex_cat_20288() -> None: def test_from_parquet_string_cache_20271() -> None: with pl.StringCache(): f = io.BytesIO() - s = pl.Series("a", ["A", "B", "C"], pl.Categorical) df = pl.Series("b", ["D", "E"], pl.Categorical).to_frame() df.write_parquet(f) + del df f.seek(0) df = pl.read_parquet(f) - - assert_series_equal( - s.to_physical(), pl.Series("a", [0, 1, 2]), check_dtypes=False - ) assert_series_equal(df.to_series(), pl.Series("b", ["D", "E"], pl.Categorical)) - assert_series_equal( - df.to_series().to_physical(), pl.Series("b", [3, 4]), check_dtypes=False - ) def test_boolean_slice_pushdown_20314() -> None: diff --git a/py-polars/tests/unit/operations/namespaces/test_categorical.py b/py-polars/tests/unit/operations/namespaces/test_categorical.py index 3ca4df90a109..c262b38d7961 100644 --- a/py-polars/tests/unit/operations/namespaces/test_categorical.py +++ b/py-polars/tests/unit/operations/namespaces/test_categorical.py @@ -63,18 +63,6 @@ def test_categorical_lexical_ordering_after_concat() -> None: } -@pytest.mark.usefixtures("test_global_and_local") -@pytest.mark.may_fail_auto_streaming -def test_sort_categoricals_6014_internal() -> None: - # create basic categorical - df = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]}).with_columns( - pl.col("key").cast(pl.Categorical) - ) - - out = df.sort("key") - assert out.to_dict(as_series=False) == {"key": ["bbb", "aaa", "ccc"]} - - @pytest.mark.usefixtures("test_global_and_local") def test_sort_categoricals_6014_lexical() -> None: # create lexically-ordered categorical @@ -88,71 +76,25 @@ def test_sort_categoricals_6014_lexical() -> None: @pytest.mark.usefixtures("test_global_and_local") def test_categorical_get_categories() -> None: - assert pl.Series( - "cats", ["foo", "bar", "foo", "foo", "ham"], dtype=pl.Categorical - ).cat.get_categories().to_list() == ["foo", "bar", "ham"] + s = pl.Series("cats", ["foo", "bar", "foo", "foo", "ham"], dtype=pl.Categorical) + assert set(s.cat.get_categories().to_list()) >= {"foo", "bar", "ham"} def test_cat_to_local() -> None: - with pl.StringCache(): - s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) - s2 = pl.Series(["c", "b", "d"], dtype=pl.Categorical) - - # s2 physical starts after s1 - assert s1.to_physical().to_list() == [0, 1, 0] - assert s2.to_physical().to_list() == [2, 1, 3] - - out = s2.cat.to_local() - - # Physical has changed and now starts at 0, string values are the same - assert out.cat.is_local() - assert out.to_physical().to_list() == [0, 1, 2] - assert out.to_list() == s2.to_list() - - # s2 should be unchanged after the operation - assert not s2.cat.is_local() - assert s2.to_physical().to_list() == [2, 1, 3] - assert s2.to_list() == ["c", "b", "d"] - - -def test_cat_to_local_missing_values() -> None: - with pl.StringCache(): - _ = pl.Series(["a", "b"], dtype=pl.Categorical) - s = pl.Series(["c", "b", None, "d"], dtype=pl.Categorical) - - out = s.cat.to_local() - assert out.to_physical().to_list() == [0, 1, None, 2] - - -def test_cat_to_local_already_local() -> None: - s = pl.Series(["a", "c", "a", "b"], dtype=pl.Categorical) - - assert s.cat.is_local() - out = s.cat.to_local() - - assert out.to_physical().to_list() == [0, 1, 0, 2] - assert out.to_list() == ["a", "c", "a", "b"] - - -def test_cat_is_local() -> None: - s = pl.Series(["a", "c", "a", "b"], dtype=pl.Categorical) - assert s.cat.is_local() - - with pl.StringCache(): - s2 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) - assert not s2.cat.is_local() + s = pl.Series(["a", "b", "a"], dtype=pl.Categorical) + assert_series_equal(s, s.cat.to_local()) @pytest.mark.usefixtures("test_global_and_local") def test_cat_uses_lexical_ordering() -> None: s = pl.Series(["a", "b", None, "b"]).cast(pl.Categorical) - assert s.cat.uses_lexical_ordering() is False + assert s.cat.uses_lexical_ordering() s = s.cast(pl.Categorical("lexical")) - assert s.cat.uses_lexical_ordering() is True + assert s.cat.uses_lexical_ordering() - s = s.cast(pl.Categorical("physical")) - assert s.cat.uses_lexical_ordering() is False + s = s.cast(pl.Categorical("physical")) # Deprecated. + assert s.cat.uses_lexical_ordering() @pytest.mark.usefixtures("test_global_and_local") diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index c3df06142d9b..967193bc94f8 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -561,11 +561,6 @@ def test_strict_cast_string( @pytest.mark.parametrize( "dtype_out", [ - *INTEGER_DTYPES, - pl.Date, - pl.Datetime, - pl.Time, - pl.Duration, pl.String, pl.Categorical, pl.Enum(["1", "2"]), @@ -764,14 +759,6 @@ def test_overflowing_cast_literals_21023() -> None: ) -def test_invalid_empty_cast_to_empty_enum() -> None: - with pytest.raises( - InvalidOperationError, - match="cannot cast / initialize Enum without categories present", - ): - pl.Series([], dtype=pl.Enum) - - @pytest.mark.parametrize("value", [True, False]) @pytest.mark.parametrize( "dtype", diff --git a/py-polars/tests/unit/operations/test_cut.py b/py-polars/tests/unit/operations/test_cut.py index a8e79bc7271f..9f1598ca2f21 100644 --- a/py-polars/tests/unit/operations/test_cut.py +++ b/py-polars/tests/unit/operations/test_cut.py @@ -88,25 +88,21 @@ def test_cut_bin_name_in_agg_context() -> None: qcut=pl.col("a").qcut([1], include_breaks=True).over(1), qcut_uniform=pl.col("a").qcut(1, include_breaks=True).over(1), ) - schema = pl.Struct( - {"breakpoint": pl.Float64, "category": pl.Categorical("physical")} - ) + schema = pl.Struct({"breakpoint": pl.Float64, "category": pl.Categorical()}) assert df.schema == {"cut": schema, "qcut": schema, "qcut_uniform": schema} @pytest.mark.parametrize( - ("breaks", "expected_labels", "expected_physical", "expected_unique"), + ("breaks", "expected_labels", "expected_unique"), [ ( [2, 4], pl.Series("x", ["(-inf, 2]", "(-inf, 2]", "(2, 4]", "(2, 4]", "(4, inf]"]), - pl.Series("x", [0, 0, 1, 1, 2], dtype=pl.UInt32), 3, ), ( [99, 101], pl.Series("x", 5 * ["(-inf, 99]"]), - pl.Series("x", 5 * [0], dtype=pl.UInt32), 1, ), ], @@ -115,7 +111,6 @@ def test_cut_bin_name_in_agg_context() -> None: def test_cut_fast_unique_15981( breaks: list[int], expected_labels: pl.Series, - expected_physical: pl.Series, expected_unique: int, ) -> None: s = pl.Series("x", [1, 2, 3, 4, 5]) @@ -124,7 +119,6 @@ def test_cut_fast_unique_15981( s_cut = s.cut(breaks, include_breaks=include_breaks) assert_series_equal(s_cut.cast(pl.String), expected_labels) - assert_series_equal(s_cut.to_physical(), expected_physical) assert s_cut.n_unique() == s_cut.to_physical().n_unique() == expected_unique s_cut.to_frame().group_by(s.name).len() @@ -134,6 +128,5 @@ def test_cut_fast_unique_15981( ) assert_series_equal(s_cut.cast(pl.String), expected_labels) - assert_series_equal(s_cut.to_physical(), expected_physical) assert s_cut.n_unique() == s_cut.to_physical().n_unique() == expected_unique s_cut.to_frame().group_by(s.name).len() diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index decdf0c5ece9..76668849e153 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -408,7 +408,7 @@ def test_group_by_sorted_empty_dataframe_3680() -> None: ) assert df.rows() == [] assert df.shape == (0, 2) - assert df.schema == {"key": pl.Categorical(ordering="physical"), "val": pl.Float64} + assert df.schema == {"key": pl.Categorical(ordering="lexical"), "val": pl.Float64} def test_group_by_custom_agg_empty_list() -> None: diff --git a/py-polars/tests/unit/operations/test_index_of.py b/py-polars/tests/unit/operations/test_index_of.py index 848bcd952ae7..b092b2dfe278 100644 --- a/py-polars/tests/unit/operations/test_index_of.py +++ b/py-polars/tests/unit/operations/test_index_of.py @@ -327,12 +327,7 @@ def test_enum(convert_to_literal: bool) -> None: @pytest.mark.parametrize( "convert_to_literal", [ - pytest.param( - True, - marks=pytest.mark.xfail( - reason="https://github.com/pola-rs/polars/issues/20318" - ), - ), + True, pytest.param( False, marks=pytest.mark.xfail( diff --git a/py-polars/tests/unit/operations/test_merge_sorted.py b/py-polars/tests/unit/operations/test_merge_sorted.py index 2ef2441fefd4..58c6d7b0ca0a 100644 --- a/py-polars/tests/unit/operations/test_merge_sorted.py +++ b/py-polars/tests/unit/operations/test_merge_sorted.py @@ -4,7 +4,6 @@ from hypothesis import given import polars as pl -from polars.exceptions import ComputeError from polars.testing import assert_frame_equal, assert_series_equal from polars.testing.parametric import series @@ -60,7 +59,6 @@ def test_merge_sorted_decimal_20990(precision: int) -> None: assert_series_equal(result, expected) -@pytest.mark.may_fail_auto_streaming def test_merge_sorted_categorical() -> None: left = pl.Series("a", ["a", "b"], pl.Categorical()).sort().to_frame() right = pl.Series("a", ["a", "b", "b"], pl.Categorical()).sort().to_frame() @@ -69,10 +67,8 @@ def test_merge_sorted_categorical() -> None: assert_series_equal(result, expected) right = pl.Series("a", ["b", "a"], pl.Categorical()).sort().to_frame() - with pytest.raises( - ComputeError, match="can only merge-sort categoricals with the same categories" - ): - left.merge_sorted(right, "a") + expected = pl.Series("a", ["a", "a", "b", "b"], pl.Categorical()) + assert_frame_equal(left.merge_sorted(right, "a"), expected.to_frame()) @pytest.mark.may_fail_auto_streaming @@ -235,39 +231,6 @@ def test_merge_time() -> None: assert df.merge_sorted(df, "a").get_column("a").dtype == pl.Time() -@pytest.mark.may_fail_auto_streaming -def test_merge_sorted_invalid_categorical_local() -> None: - df1 = pl.DataFrame({"a": pl.Series(["a", "b", "c"], dtype=pl.Categorical)}) - df2 = pl.DataFrame({"a": pl.Series(["a", "b", "d"], dtype=pl.Categorical)}) - - with pytest.raises( - ComputeError, match="can only merge-sort categoricals with the same categories" - ): - df1.merge_sorted(df2, key="a") - - -@pytest.mark.may_fail_auto_streaming -def test_merge_sorted_categorical_global_physical() -> None: - with pl.StringCache(): - df1 = pl.DataFrame( - {"a": pl.Series(["e", "a", "f"], dtype=pl.Categorical("physical"))} - ) - df2 = pl.DataFrame( - {"a": pl.Series(["a", "c", "d"], dtype=pl.Categorical("physical"))} - ) - expected = pl.DataFrame( - { - "a": pl.Series( - (["e", "a", "a", "f", "c", "d"]), - dtype=pl.Categorical("physical"), - ) - } - ) - result = df1.merge_sorted(df2, key="a") - assert_frame_equal(result, expected) - - -@pytest.mark.may_fail_auto_streaming def test_merge_sorted_categorical_global_lexical() -> None: with pl.StringCache(): df1 = pl.DataFrame( diff --git a/py-polars/tests/unit/operations/test_replace_strict.py b/py-polars/tests/unit/operations/test_replace_strict.py index 276e41f473de..88866c8ae205 100644 --- a/py-polars/tests/unit/operations/test_replace_strict.py +++ b/py-polars/tests/unit/operations/test_replace_strict.py @@ -1,12 +1,11 @@ from __future__ import annotations -import contextlib from typing import Any import pytest import polars as pl -from polars.exceptions import CategoricalRemappingWarning, InvalidOperationError +from polars.exceptions import InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal @@ -343,63 +342,46 @@ def test_replace_strict_str_to_int() -> None: assert_series_equal(result, expected) -@pytest.mark.parametrize( - ("context", "dtype"), - [ - (pl.StringCache(), pl.Categorical), - (pytest.warns(CategoricalRemappingWarning), pl.Categorical), - (contextlib.nullcontext(), pl.Enum(["a", "b", "OTHER"])), - ], -) -@pytest.mark.may_fail_auto_streaming +@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "OTHER"])]) def test_replace_strict_cat_str( - context: contextlib.AbstractContextManager, # type: ignore[type-arg] dtype: pl.DataType, ) -> None: - with context: - for old, new, expected in [ - ("a", "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), - (["a", "b"], ["c", "d"], pl.Series("s", ["c", "d"], dtype=pl.Utf8)), - (pl.lit("a", dtype=dtype), "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), - ( - pl.Series(["a", "b"], dtype=dtype), - ["c", "d"], - pl.Series("s", ["c", "d"], dtype=pl.Utf8), - ), - ]: - s = pl.Series("s", ["a", "b"], dtype=dtype) - s_replaced = s.replace_strict(old, new, default=None) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected) - - s = pl.Series("s", ["a", "b"], dtype=dtype) - s_replaced = s.replace_strict(old, new, default="OTHER") # type: ignore[arg-type] - assert_series_equal(s_replaced, expected.fill_null("OTHER")) - - -@pytest.mark.parametrize( - "context", [pl.StringCache(), pytest.warns(CategoricalRemappingWarning)] -) -@pytest.mark.may_fail_auto_streaming -def test_replace_strict_cat_cat( - context: contextlib.AbstractContextManager, # type: ignore[type-arg] -) -> None: - with context: - dt = pl.Categorical - for old, new, expected in [ - ("a", pl.lit("c", dtype=dt), pl.Series("s", ["c", None], dtype=dt)), - ( - ["a", "b"], - pl.Series(["c", "d"], dtype=dt), - pl.Series("s", ["c", "d"], dtype=dt), - ), - ]: - s = pl.Series("s", ["a", "b"], dtype=dt) - s_replaced = s.replace_strict(old, new, default=None) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected) - - s = pl.Series("s", ["a", "b"], dtype=dt) - s_replaced = s.replace_strict(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected.fill_null("OTHER")) + for old, new, expected in [ + ("a", "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), + (["a", "b"], ["c", "d"], pl.Series("s", ["c", "d"], dtype=pl.Utf8)), + (pl.lit("a", dtype=dtype), "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), + ( + pl.Series(["a", "b"], dtype=dtype), + ["c", "d"], + pl.Series("s", ["c", "d"], dtype=pl.Utf8), + ), + ]: + s = pl.Series("s", ["a", "b"], dtype=dtype) + s_replaced = s.replace_strict(old, new, default=None) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected) + + s = pl.Series("s", ["a", "b"], dtype=dtype) + s_replaced = s.replace_strict(old, new, default="OTHER") # type: ignore[arg-type] + assert_series_equal(s_replaced, expected.fill_null("OTHER")) + + +def test_replace_strict_cat_cat() -> None: + dt = pl.Categorical + for old, new, expected in [ + ("a", pl.lit("c", dtype=dt), pl.Series("s", ["c", None], dtype=dt)), + ( + ["a", "b"], + pl.Series(["c", "d"], dtype=dt), + pl.Series("s", ["c", "d"], dtype=dt), + ), + ]: + s = pl.Series("s", ["a", "b"], dtype=dt) + s_replaced = s.replace_strict(old, new, default=None) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected) + + s = pl.Series("s", ["a", "b"], dtype=dt) + s_replaced = s.replace_strict(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected.fill_null("OTHER")) def test_replace_strict_single_argument_not_mapping() -> None: diff --git a/py-polars/tests/unit/operations/test_sets.py b/py-polars/tests/unit/operations/test_sets.py index 8b13de7e4fd4..cfc92ec9bf03 100644 --- a/py-polars/tests/unit/operations/test_sets.py +++ b/py-polars/tests/unit/operations/test_sets.py @@ -3,7 +3,6 @@ import pytest import polars as pl -from polars.exceptions import CategoricalRemappingWarning from polars.testing import assert_series_equal @@ -47,26 +46,24 @@ def test_set_intersection_st_17129() -> None: ), ], ) -@pytest.mark.may_fail_auto_streaming def test_set_operations_cats(set_operation: str, outcome: list[set[str]]) -> None: - with pytest.warns(CategoricalRemappingWarning): - df = pl.DataFrame( - { - "a": [ - ["z1", "x", "y", "z"], - ["y", "z"], - ["x", "y"], - ["x", "y", "z", "x2"], - ["z", "x3"], - ] - }, - schema={"a": pl.List(pl.Categorical)}, - ) - df = df.with_columns( - getattr(pl.col("a").list, set_operation)(["x", "y"]).alias("b") - ) - assert df.get_column("b").dtype == pl.List(pl.Categorical) - assert [set(el) for el in df["b"].to_list()] == outcome + df = pl.DataFrame( + { + "a": [ + ["z1", "x", "y", "z"], + ["y", "z"], + ["x", "y"], + ["x", "y", "z", "x2"], + ["z", "x3"], + ] + }, + schema={"a": pl.List(pl.Categorical)}, + ) + df = df.with_columns( + getattr(pl.col("a").list, set_operation)(["x", "y"]).alias("b") + ) + assert df.get_column("b").dtype == pl.List(pl.Categorical) + assert [set(el) for el in df["b"].to_list()] == outcome def test_set_invalid_types() -> None: diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index d82de79a7b2a..d02249e241d5 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -1174,7 +1174,6 @@ def test_sort_bool_nulls_last() -> None: [ pl.Enum(["a", "b"]), pl.Categorical(ordering="lexical"), - pl.Categorical(ordering="physical"), ], ) def test_sort_cat_nulls_last(dtype: PolarsDataType) -> None: diff --git a/py-polars/tests/unit/operations/test_transpose.py b/py-polars/tests/unit/operations/test_transpose.py index 8bc691ed6acd..2c0458b4a597 100644 --- a/py-polars/tests/unit/operations/test_transpose.py +++ b/py-polars/tests/unit/operations/test_transpose.py @@ -8,7 +8,6 @@ from polars.exceptions import ( InvalidOperationError, SchemaError, - StringCacheMismatchError, ) from polars.testing import assert_frame_equal, assert_series_equal @@ -142,41 +141,23 @@ def name_generator() -> Iterator[str]: @pytest.mark.may_fail_auto_streaming def test_transpose_categorical_data() -> None: - with pl.StringCache(): - df = pl.DataFrame( - [ - pl.Series(["a", "b", "c"], dtype=pl.Categorical), - pl.Series(["c", "g", "c"], dtype=pl.Categorical), - pl.Series(["d", "b", "c"], dtype=pl.Categorical), - ] - ) - df_transposed = df.transpose( - include_header=False, column_names=["col1", "col2", "col3"] - ) - assert_series_equal( - df_transposed.get_column("col1"), - pl.Series("col1", ["a", "c", "d"], dtype=pl.Categorical), - ) - - # Without string Cache only works if they have the same categories in the same order df = pl.DataFrame( [ - pl.Series(["a", "b", "c", "c"], dtype=pl.Categorical), - pl.Series(["a", "b", "b", "c"], dtype=pl.Categorical), - pl.Series(["a", "a", "b", "c"], dtype=pl.Categorical), + pl.Series(["a", "b", "c", "d"], dtype=pl.Categorical), + pl.Series(["c", "g", "c", "d"], dtype=pl.Categorical), + pl.Series(["d", "b", "c", "d"], dtype=pl.Categorical), ] ) - df_transposed = df.transpose( - include_header=False, column_names=["col1", "col2", "col3", "col4"] + df_transposed = df.transpose(include_header=False) + expected = pl.DataFrame( + [ + pl.Series(["a", "c", "d"], dtype=pl.Categorical), + pl.Series(["b", "g", "b"], dtype=pl.Categorical), + pl.Series(["c", "c", "c"], dtype=pl.Categorical), + pl.Series(["d", "d", "d"], dtype=pl.Categorical), + ] ) - - with pytest.raises(StringCacheMismatchError): - pl.DataFrame( - [ - pl.Series(["a", "b", "c", "c"], dtype=pl.Categorical), - pl.Series(["c", "b", "b", "c"], dtype=pl.Categorical), - ] - ).transpose() + assert_frame_equal(df_transposed, expected) @pytest.mark.may_fail_auto_streaming diff --git a/py-polars/tests/unit/operations/test_unpivot.py b/py-polars/tests/unit/operations/test_unpivot.py index 26cfb98d97b4..a5c96a64b99c 100644 --- a/py-polars/tests/unit/operations/test_unpivot.py +++ b/py-polars/tests/unit/operations/test_unpivot.py @@ -2,7 +2,6 @@ import polars as pl import polars.selectors as cs -from polars import StringCache from polars.testing import assert_frame_equal @@ -97,8 +96,7 @@ def test_unpivot_empty_18170() -> None: ) -@StringCache() -def test_unpivot_categorical_global() -> None: +def test_unpivot_categorical() -> None: df = pl.DataFrame( { "index": [0, 1], @@ -107,15 +105,9 @@ def test_unpivot_categorical_global() -> None: } ) out = df.unpivot(["1", "2"], index="index") - assert out.dtypes == [pl.Int64, pl.String, pl.Categorical(ordering="physical")] + assert out.dtypes == [pl.Int64, pl.String, pl.Categorical(ordering="lexical")] assert out.to_dict(as_series=False) == { "index": [0, 1, 0, 1], "variable": ["1", "1", "2", "2"], "value": ["a", "b", "b", "c"], } - - -@pytest.mark.may_fail_auto_streaming -def test_unpivot_categorical_raise_19770() -> None: - with pytest.raises(pl.exceptions.ComputeError): - (pl.DataFrame({"x": ["foo"]}).cast(pl.Categorical).unpivot()) diff --git a/py-polars/tests/unit/operations/unique/test_unique.py b/py-polars/tests/unit/operations/unique/test_unique.py index f26ba6ee1597..3ae74086a796 100644 --- a/py-polars/tests/unit/operations/unique/test_unique.py +++ b/py-polars/tests/unit/operations/unique/test_unique.py @@ -157,16 +157,6 @@ def test_unique_categorical(input: list[str | None], output: list[str | None]) - assert_series_equal(result, expected, check_order=False) -def test_unique_categorical_global() -> None: - with pl.StringCache(): - pl.Series(["aaaa", "bbbb", "cccc"]) # pre-fill global cache - s = pl.Series(["a", "b", "c"], dtype=pl.Categorical) - s_empty = s.slice(0, 0) - - assert s_empty.unique().to_list() == [] - assert_series_equal(s_empty.cat.get_categories(), pl.Series(["a", "b", "c"])) - - def test_unique_with_null() -> None: df = pl.DataFrame( { diff --git a/py-polars/tests/unit/series/buffers/test_from_buffers.py b/py-polars/tests/unit/series/buffers/test_from_buffers.py index 43185b50a757..5840b8860710 100644 --- a/py-polars/tests/unit/series/buffers/test_from_buffers.py +++ b/py-polars/tests/unit/series/buffers/test_from_buffers.py @@ -124,7 +124,7 @@ def test_series_from_buffers_string() -> None: def test_series_from_buffers_enum() -> None: dtype = pl.Enum(["a", "b", "c"]) - data = pl.Series([0, 1, 0, 2], dtype=pl.UInt32) + data = pl.Series([0, 1, 0, 2], dtype=pl.UInt8) validity = pl.Series([True, True, False, True]) result = pl.Series._from_buffers(dtype, data=data, validity=validity) diff --git a/py-polars/tests/unit/series/test_scatter.py b/py-polars/tests/unit/series/test_scatter.py index 843fd1727e7d..69dafea3b402 100644 --- a/py-polars/tests/unit/series/test_scatter.py +++ b/py-polars/tests/unit/series/test_scatter.py @@ -94,3 +94,31 @@ def test_scatter_logical_all_null() -> None: result = s.scatter(0, date(2022, 2, 2)) expected = pl.Series("dt", [date(2022, 2, 2), None]) assert_series_equal(result, expected) + + +def test_scatter_categorical_21175() -> None: + s = pl.Series(["a", "b", "c"], dtype=pl.Categorical) + assert_series_equal( + s.scatter(0, "b"), pl.Series(["b", "b", "c"], dtype=pl.Categorical) + ) + v = pl.Series(["v"], dtype=pl.Categorical) + assert_series_equal( + s.scatter([0, 2], v), pl.Series(["v", "b", "v"], dtype=pl.Categorical) + ) + + with pytest.raises(InvalidOperationError): + s.scatter(1, 2) + + +def test_scatter_enum() -> None: + e = pl.Enum(["a", "b", "c", "v"]) + s = pl.Series(["a", "b", "c"], dtype=e) + assert_series_equal(s.scatter(0, "b"), pl.Series(["b", "b", "c"], dtype=e)) + v = pl.Series(["v"], dtype=pl.Categorical) + assert_series_equal(s.scatter([0, 2], v), pl.Series(["v", "b", "v"], dtype=e)) + + with pytest.raises(InvalidOperationError): + s.scatter(1, "d") + + with pytest.raises(InvalidOperationError): + s.scatter(1, 2) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index f414d89f1fa9..4ce41eb480a5 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -357,9 +357,7 @@ def test_date_agg() -> None: ("s", "min", "max"), [ (pl.Series(["c", "b", "a"], dtype=pl.Categorical("lexical")), "a", "c"), - (pl.Series(["a", "c", "b"], dtype=pl.Categorical), "a", "b"), (pl.Series([None, "a", "c", "b"], dtype=pl.Categorical("lexical")), "a", "c"), - (pl.Series([None, "c", "a", "b"], dtype=pl.Categorical), "c", "b"), (pl.Series([], dtype=pl.Categorical("lexical")), None, None), (pl.Series(["c", "b", "a"], dtype=pl.Enum(["c", "b", "a"])), "c", "a"), (pl.Series(["c", "b", "a"], dtype=pl.Enum(["c", "b", "a", "d"])), "c", "a"), @@ -634,25 +632,27 @@ def test_arrow() -> None: ) assert s.dtype == pl.List + +def test_arrow_cat() -> None: # categorical dtype tests (including various forms of empty pyarrow array) - with pl.StringCache(): - arr0 = pa.array(["foo", "bar"], pa.dictionary(pa.int32(), pa.utf8())) - assert_series_equal( - pl.Series("arr", ["foo", "bar"], pl.Categorical), pl.Series("arr", arr0) - ) - arr1 = pa.array(["xxx", "xxx", None, "yyy"]).dictionary_encode() - arr2 = pa.array([]).dictionary_encode() - arr3 = pa.chunked_array([], arr1.type) - arr4 = pa.array([], arr1.type) + arr0 = pa.array(["foo", "bar"], pa.dictionary(pa.int32(), pa.utf8())) + assert_series_equal( + pl.Series("arr", ["foo", "bar"], pl.Categorical), pl.Series("arr", arr0) + ) + arr1 = pa.array(["xxx", "xxx", None, "yyy"]).dictionary_encode() + arr2 = pa.chunked_array([], arr1.type) + arr3 = pa.array([], arr1.type) + arr4 = pa.array([]).dictionary_encode() + assert_series_equal( + pl.Series("arr", ["xxx", "xxx", None, "yyy"], dtype=pl.Categorical), + pl.Series("arr", arr1), + ) + for arr in (arr2, arr3): assert_series_equal( - pl.Series("arr", ["xxx", "xxx", None, "yyy"], dtype=pl.Categorical), - pl.Series("arr", arr1), + pl.Series("arr", [], dtype=pl.Categorical), pl.Series("arr", arr) ) - for arr in (arr2, arr3, arr4): - assert_series_equal( - pl.Series("arr", [], dtype=pl.Categorical), pl.Series("arr", arr) - ) + assert_series_equal(pl.Series("arr", [], dtype=pl.Null), pl.Series("arr", arr4)) def test_pycapsule_interface() -> None: @@ -1454,8 +1454,6 @@ def test_arg_sort() -> None: (pl.Series(["a", "c", "b"]), 0, 1), (pl.Series([None, "a", None, "b"]), 1, 3), # Categorical - (pl.Series(["c", "b", "a"], dtype=pl.Categorical), 0, 2), - (pl.Series([None, "c", "b", None, "a"], dtype=pl.Categorical), 1, 4), (pl.Series(["c", "b", "a"], dtype=pl.Categorical(ordering="lexical")), 2, 0), (pl.Series("s", [None, "c", "b", None, "a"], pl.Categorical("lexical")), 4, 1), ], @@ -1694,13 +1692,19 @@ def test_to_physical() -> None: # casting a categorical results in a UInt32 s = pl.Series(["cat1"]).cast(pl.Categorical) - expected = pl.Series([0], dtype=UInt32) - assert_series_equal(s.to_physical(), expected) + assert s.to_physical().dtype == pl.UInt32 + + # casting a small enum results in a UInt8 + s = pl.Series(["cat1"]).cast(pl.Enum(["cat1"])) + assert s.to_physical().dtype == pl.UInt8 # casting a List(Categorical) results in a List(UInt32) s = pl.Series([["cat1"]]).cast(pl.List(pl.Categorical)) - expected = pl.Series([[0]], dtype=pl.List(UInt32)) - assert_series_equal(s.to_physical(), expected) + assert s.to_physical().dtype == pl.List(pl.UInt32) + + # casting a List(Enum) with a small enum results in a List(UInt8) + s = pl.Series(["cat1"]).cast(pl.List(pl.Enum(["cat1"]))) + assert s.to_physical().dtype == pl.List(pl.UInt8) def test_to_physical_rechunked_21285() -> None: diff --git a/py-polars/tests/unit/streaming/test_streaming_categoricals.py b/py-polars/tests/unit/streaming/test_streaming_categoricals.py index 679e133412b0..01d9a72465fa 100644 --- a/py-polars/tests/unit/streaming/test_streaming_categoricals.py +++ b/py-polars/tests/unit/streaming/test_streaming_categoricals.py @@ -26,7 +26,7 @@ def test_streaming_cat_14933() -> None: df2 = pl.LazyFrame( [ pl.Series("a", [0, 1], dtype=pl.UInt32), - pl.Series("l", [None, None], dtype=pl.Categorical(ordering="physical")), + pl.Series("l", [None, None], dtype=pl.Categorical()), ] ) result = df1.join(df2, on="a", how="left") diff --git a/py-polars/tests/unit/test_row_encoding_sort.py b/py-polars/tests/unit/test_row_encoding_sort.py index 335ac7159ab4..08260ca51427 100644 --- a/py-polars/tests/unit/test_row_encoding_sort.py +++ b/py-polars/tests/unit/test_row_encoding_sort.py @@ -38,9 +38,7 @@ def elem_order_sign( if isinstance(lhs, pl.Series) and isinstance(rhs, pl.Series): assert lhs.dtype == rhs.dtype - if isinstance(lhs.dtype, pl.Enum) or lhs.dtype == pl.Categorical( - ordering="physical" - ): + if isinstance(lhs.dtype, pl.Enum): lhs = cast(Element, lhs.to_physical()) rhs = cast(Element, rhs.to_physical()) assert isinstance(lhs, pl.Series) diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 13ff15fdfc67..69ba47d566f1 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -33,7 +33,7 @@ def test_schema() -> None: pl.Schema( { "foo": pl.UInt32(), - "bar": pl.Categorical("physical"), + "bar": pl.Categorical(), "baz": pl.Struct({"x": pl.Int64(), "y": pl.Float64()}), } ), diff --git a/py-polars/tests/unit/test_string_cache.py b/py-polars/tests/unit/test_string_cache.py deleted file mode 100644 index 3fa20730b57e..000000000000 --- a/py-polars/tests/unit/test_string_cache.py +++ /dev/null @@ -1,185 +0,0 @@ -from collections.abc import Iterator - -import pytest - -import polars as pl -from polars.exceptions import CategoricalRemappingWarning -from polars.testing import assert_frame_equal - - -@pytest.fixture(autouse=True) -def _disable_string_cache() -> Iterator[None]: - """Fixture to make sure the string cache is disabled before and after each test.""" - pl.disable_string_cache() - yield - pl.disable_string_cache() - - -def sc(set: bool) -> None: - """Short syntax for asserting whether the global string cache is being used.""" - assert pl.using_string_cache() is set - - -def test_string_cache_enable_disable() -> None: - sc(False) - pl.enable_string_cache() - sc(True) - pl.disable_string_cache() - sc(False) - - -def test_string_cache_enable_disable_repeated() -> None: - sc(False) - pl.enable_string_cache() - sc(True) - pl.enable_string_cache() - sc(True) - pl.disable_string_cache() - sc(False) - pl.disable_string_cache() - sc(False) - - -def test_string_cache_context_manager() -> None: - sc(False) - with pl.StringCache(): - sc(True) - sc(False) - - -def test_string_cache_context_manager_nested() -> None: - sc(False) - with pl.StringCache(): - sc(True) - with pl.StringCache(): - sc(True) - sc(True) - sc(False) - - -def test_string_cache_context_manager_mixed_with_enable_disable() -> None: - sc(False) - with pl.StringCache(): - sc(True) - pl.enable_string_cache() - sc(True) - sc(True) - - with pl.StringCache(): - sc(True) - sc(True) - - with pl.StringCache(): - sc(True) - with pl.StringCache(): - sc(True) - pl.disable_string_cache() - sc(True) - sc(True) - sc(False) - - with pl.StringCache(): - sc(True) - pl.disable_string_cache() - sc(True) - sc(False) - - -def test_string_cache_decorator() -> None: - @pl.StringCache() - def my_function() -> None: - sc(True) - - sc(False) - my_function() - sc(False) - - -def test_string_cache_decorator_mixed_with_enable() -> None: - @pl.StringCache() - def my_function() -> None: - sc(True) - pl.enable_string_cache() - sc(True) - - sc(False) - my_function() - sc(True) - - -@pytest.mark.may_fail_auto_streaming -def test_string_cache_join() -> None: - df1 = pl.DataFrame({"a": ["foo", "bar", "ham"], "b": [1, 2, 3]}) - df2 = pl.DataFrame({"a": ["eggs", "spam", "foo"], "c": [2, 2, 3]}) - - # ensure cache is off when casting to categorical; the join will fail - pl.disable_string_cache() - assert pl.using_string_cache() is False - - with pytest.warns( - CategoricalRemappingWarning, - match="Local categoricals have different encodings", - ): - df1a = df1.with_columns(pl.col("a").cast(pl.Categorical)) - df2a = df2.with_columns(pl.col("a").cast(pl.Categorical)) - out = df1a.join(df2a, on="a", how="inner") - - expected = pl.DataFrame( - {"a": ["foo"], "b": [1], "c": [3]}, schema_overrides={"a": pl.Categorical} - ) - - # Can not do equality checks on local categoricals with different categories - assert_frame_equal(out, expected, categorical_as_str=True) - - # now turn on the cache - pl.enable_string_cache() - assert pl.using_string_cache() is True - - df1b = df1.with_columns(pl.col("a").cast(pl.Categorical)) - df2b = df2.with_columns(pl.col("a").cast(pl.Categorical)) - out = df1b.join(df2b, on="a", how="inner") - - expected = pl.DataFrame( - {"a": ["foo"], "b": [1], "c": [3]}, schema_overrides={"a": pl.Categorical} - ) - assert_frame_equal(out, expected) - - -def test_string_cache_eager_lazy() -> None: - # tests if the global string cache is really global and not interfered by the lazy - # execution. first the global settings was thread-local and this breaks with the - # parallel execution of lazy - with pl.StringCache(): - df1 = pl.DataFrame( - {"region_ids": ["reg1", "reg2", "reg3", "reg4", "reg5"]} - ).select([pl.col("region_ids").cast(pl.Categorical)]) - - df2 = pl.DataFrame( - {"seq_name": ["reg4", "reg2", "reg1"], "score": [3.0, 1.0, 2.0]} - ).select([pl.col("seq_name").cast(pl.Categorical), pl.col("score")]) - - expected = pl.DataFrame( - { - "region_ids": ["reg1", "reg2", "reg3", "reg4", "reg5"], - "score": [2.0, 1.0, None, 3.0, None], - } - ).with_columns(pl.col("region_ids").cast(pl.Categorical)) - - result = df1.join(df2, left_on="region_ids", right_on="seq_name", how="left") - assert_frame_equal(result, expected, check_row_order=False) - - # also check row-wise categorical insert. - # (column-wise is preferred, but this shouldn't fail) - for params in ( - {"schema": [("region_ids", pl.Categorical)]}, - { - "schema": ["region_ids"], - "schema_overrides": {"region_ids": pl.Categorical}, - }, - ): - df3 = pl.DataFrame( - data=[["reg1"], ["reg2"], ["reg3"], ["reg4"], ["reg5"]], - orient="row", - **params, - ) - assert_frame_equal(df1, df3) diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py index efa6b04a0a0c..7f9df3ca6fab 100644 --- a/py-polars/tests/unit/testing/test_assert_series_equal.py +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -731,22 +731,20 @@ def test_assert_series_equal_check_dtype_deprecated() -> None: assert_series_not_equal(s1, s3, check_dtype=False) # type: ignore[call-arg] -def test_assert_series_equal_nested_categorical_as_str_global() -> None: +def test_assert_series_equal_nested_categorical_as_str_independently_constructed() -> ( + None +): # https://github.com/pola-rs/polars/issues/16196 + s1 = pl.Series(["c0"], dtype=pl.Categorical) + s2 = pl.Series(["c1"], dtype=pl.Categorical) + a = pl.DataFrame([s1, s2]).to_struct("col0") - # Global - with pl.StringCache(): - s1 = pl.Series(["c0"], dtype=pl.Categorical) - s2 = pl.Series(["c1"], dtype=pl.Categorical) - s_global = pl.DataFrame([s1, s2]).to_struct("col0") - - # Local s1 = pl.Series(["c0"], dtype=pl.Categorical) s2 = pl.Series(["c1"], dtype=pl.Categorical) - s_local = pl.DataFrame([s1, s2]).to_struct("col0") + b = pl.DataFrame([s1, s2]).to_struct("col0") - assert_series_equal(s_global, s_local, categorical_as_str=True) - assert_series_not_equal(s_global, s_local, categorical_as_str=False) + assert_series_equal(a, b, categorical_as_str=True) + assert_series_equal(a, b, categorical_as_str=False) @pytest.mark.parametrize(