Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Allow to pass Arrow table with boolean columns to dataset #6353

Merged
merged 11 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions include/LightGBM/arrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ class ArrowChunkedArray {
const ArrowSchema* schema_;
/* List of length `n + 1` for `n` chunks containing the offsets for each chunk. */
std::vector<int64_t> chunk_offsets_;
/* Indicator whether this chunked array needs to call the arrays' release callbacks.
NOTE: This is MUST only be set to `true` if this chunked array is not part of a
`ArrowTable` as children arrays may not be released by the consumer (see below). */
const bool releases_arrow_;

inline void construct_chunk_offsets() {
chunk_offsets_.reserve(chunks_.size() + 1);
Expand All @@ -100,7 +104,8 @@ class ArrowChunkedArray {
* @param chunks A list with the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowChunkedArray(std::vector<const ArrowArray*> chunks, const ArrowSchema* schema) {
inline ArrowChunkedArray(std::vector<const ArrowArray*> chunks, const ArrowSchema* schema)
: releases_arrow_(false) {
chunks_ = chunks;
schema_ = schema;
construct_chunk_offsets();
Expand All @@ -113,9 +118,9 @@ class ArrowChunkedArray {
* @param chunks A C-style array containing the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowChunkedArray(int64_t n_chunks,
const struct ArrowArray* chunks,
const struct ArrowSchema* schema) {
inline ArrowChunkedArray(int64_t n_chunks, const struct ArrowArray* chunks,
const struct ArrowSchema* schema)
: releases_arrow_(true) {
chunks_.reserve(n_chunks);
for (auto k = 0; k < n_chunks; ++k) {
if (chunks[k].length == 0) continue;
Expand All @@ -125,6 +130,21 @@ class ArrowChunkedArray {
construct_chunk_offsets();
}

~ArrowChunkedArray() {
if (!releases_arrow_) {
return;
}
for (size_t i = 0; i < chunks_.size(); ++i) {
auto chunk = chunks_[i];
if (chunk->release) {
chunk->release(const_cast<ArrowArray*>(chunk));
}
}
if (schema_->release) {
schema_->release(const_cast<ArrowSchema*>(schema_));
}
}

/**
* @brief Get the length of the chunked array.
* This method returns the cumulative length of all chunks.
Expand Down Expand Up @@ -219,7 +239,7 @@ class ArrowTable {
* @param chunks A C-style array containing the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowTable(int64_t n_chunks, const ArrowArray *chunks, const ArrowSchema *schema)
inline ArrowTable(int64_t n_chunks, const ArrowArray* chunks, const ArrowSchema* schema)
: n_chunks_(n_chunks), chunks_ptr_(chunks), schema_ptr_(schema) {
columns_.reserve(schema->n_children);
for (int64_t j = 0; j < schema->n_children; ++j) {
Expand All @@ -236,7 +256,8 @@ class ArrowTable {
~ArrowTable() {
// As consumer of the Arrow array, the Arrow table must release all Arrow arrays it receives
// as well as the schema. As per the specification, children arrays are released by the
// producer. See: https://arrow.apache.org/docs/format/CDataInterface.html#release-callback-semantics-for-consumers
// producer. See:
// https://arrow.apache.org/docs/format/CDataInterface.html#release-callback-semantics-for-consumers
for (int64_t i = 0; i < n_chunks_; ++i) {
auto chunk = &chunks_ptr_[i];
if (chunk->release) {
Expand Down
30 changes: 24 additions & 6 deletions include/LightGBM/arrow.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ inline ArrowChunkedArray::Iterator<T> ArrowChunkedArray::end() const {
/* ---------------------------------- ITERATOR IMPLEMENTATION ---------------------------------- */

template <typename T>
ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array,
getter_fn get,
ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array, getter_fn get,
int64_t ptr_chunk)
: array_(array), get_(get), ptr_chunk_(ptr_chunk) {
this->ptr_offset_ = 0;
Expand All @@ -41,7 +40,7 @@ ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array,
template <typename T>
T ArrowChunkedArray::Iterator<T>::operator*() const {
auto chunk = array_.chunks_[ptr_chunk_];
return static_cast<T>(get_(chunk, ptr_offset_));
return get_(chunk, ptr_offset_);
}

template <typename T>
Expand All @@ -54,7 +53,7 @@ T ArrowChunkedArray::Iterator<T>::operator[](I idx) const {
auto chunk = array_.chunks_[chunk_idx];

auto ptr_offset = static_cast<int64_t>(idx) - array_.chunk_offsets_[chunk_idx];
return static_cast<T>(get_(chunk, ptr_offset));
return get_(chunk, ptr_offset);
}

template <typename T>
Expand Down Expand Up @@ -147,11 +146,28 @@ struct ArrayIndexAccessor {
if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) {
// In case the index is valid, we take it from the data buffer
auto data = static_cast<const T*>(array->buffers[1]);
return static_cast<double>(data[buffer_idx]);
return static_cast<V>(data[buffer_idx]);
}

// In case the index is not valid, we return a default value
return arrow_primitive_missing_value<T>();
return arrow_primitive_missing_value<V>();
}
};

template <typename V>
struct ArrayIndexAccessor<bool, V> {
V operator()(const ArrowArray* array, size_t idx) {
// Custom implementation for booleans as values are bit-packed:
// https://arrow.apache.org/docs/cpp/api/datatype.html#_CPPv4N5arrow4Type4type4BOOLE
auto buffer_idx = idx + array->offset;
auto validity = static_cast<const char*>(array->buffers[0]);
if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) {
// In case the index is valid, we have to take the appropriate bit from the buffer
auto data = static_cast<const char*>(array->buffers[1]);
auto value = (data[buffer_idx / 8] & (1 << (buffer_idx % 8))) >> (buffer_idx % 8);
return static_cast<V>(value);
}
return arrow_primitive_missing_value<V>();
}
};

Expand Down Expand Up @@ -180,6 +196,8 @@ std::function<T(const ArrowArray*, size_t)> get_index_accessor(const char* dtype
return ArrayIndexAccessor<float, T>();
case 'g':
return ArrayIndexAccessor<double, T>();
case 'b':
return ArrayIndexAccessor<bool, T>();
default:
throw std::invalid_argument("unsupported Arrow datatype");
}
Expand Down
5 changes: 3 additions & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PANDAS_INSTALLED,
PYARROW_INSTALLED,
arrow_cffi,
arrow_is_boolean,
arrow_is_floating,
arrow_is_integer,
concat,
Expand Down Expand Up @@ -1688,7 +1689,7 @@ def __pred_for_pyarrow_table(
raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.")

# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
raise ValueError("Arrow table may only have integer or floating point datatypes")

# Prepare prediction output array
Expand Down Expand Up @@ -2435,7 +2436,7 @@ def __init_from_pyarrow_table(
raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` installed.")

# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
raise ValueError("Arrow table may only have integer or floating point datatypes")

# Export Arrow table to C
Expand Down
2 changes: 2 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def __init__(self, *args: Any, **kwargs: Any):
from pyarrow import Table as pa_Table
from pyarrow import chunked_array as pa_chunked_array
from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_boolean as arrow_is_boolean
from pyarrow.types import is_floating as arrow_is_floating
from pyarrow.types import is_integer as arrow_is_integer

Expand Down Expand Up @@ -265,6 +266,7 @@ class pa_compute: # type: ignore
equal = None

pa_chunked_array = None
arrow_is_boolean = None
arrow_is_integer = None
arrow_is_floating = None

Expand Down
Loading
Loading