diff --git a/include/LightGBM/arrow.h b/include/LightGBM/arrow.h index 6c6ca0d5a615..7161e11fbcdc 100644 --- a/include/LightGBM/arrow.h +++ b/include/LightGBM/arrow.h @@ -84,6 +84,10 @@ class ArrowChunkedArray { const ArrowSchema* schema_; /* List of length `n + 1` for `n` chunks containing the offsets for each chunk. */ std::vector chunk_offsets_; + /* Indicator whether this chunked array needs to call the arrays' release callbacks. + NOTE: This is MUST only be set to `true` if this chunked array is not part of a + `ArrowTable` as children arrays may not be released by the consumer (see below). */ + const bool releases_arrow_; inline void construct_chunk_offsets() { chunk_offsets_.reserve(chunks_.size() + 1); @@ -100,7 +104,8 @@ class ArrowChunkedArray { * @param chunks A list with the chunks. * @param schema The schema for all chunks. */ - inline ArrowChunkedArray(std::vector chunks, const ArrowSchema* schema) { + inline ArrowChunkedArray(std::vector chunks, const ArrowSchema* schema) + : releases_arrow_(false) { chunks_ = chunks; schema_ = schema; construct_chunk_offsets(); @@ -113,9 +118,9 @@ class ArrowChunkedArray { * @param chunks A C-style array containing the chunks. * @param schema The schema for all chunks. */ - inline ArrowChunkedArray(int64_t n_chunks, - const struct ArrowArray* chunks, - const struct ArrowSchema* schema) { + inline ArrowChunkedArray(int64_t n_chunks, const struct ArrowArray* chunks, + const struct ArrowSchema* schema) + : releases_arrow_(true) { chunks_.reserve(n_chunks); for (auto k = 0; k < n_chunks; ++k) { if (chunks[k].length == 0) continue; @@ -125,6 +130,21 @@ class ArrowChunkedArray { construct_chunk_offsets(); } + ~ArrowChunkedArray() { + if (!releases_arrow_) { + return; + } + for (size_t i = 0; i < chunks_.size(); ++i) { + auto chunk = chunks_[i]; + if (chunk->release) { + chunk->release(const_cast(chunk)); + } + } + if (schema_->release) { + schema_->release(const_cast(schema_)); + } + } + /** * @brief Get the length of the chunked array. * This method returns the cumulative length of all chunks. @@ -219,7 +239,7 @@ class ArrowTable { * @param chunks A C-style array containing the chunks. * @param schema The schema for all chunks. */ - inline ArrowTable(int64_t n_chunks, const ArrowArray *chunks, const ArrowSchema *schema) + inline ArrowTable(int64_t n_chunks, const ArrowArray* chunks, const ArrowSchema* schema) : n_chunks_(n_chunks), chunks_ptr_(chunks), schema_ptr_(schema) { columns_.reserve(schema->n_children); for (int64_t j = 0; j < schema->n_children; ++j) { @@ -236,7 +256,8 @@ class ArrowTable { ~ArrowTable() { // As consumer of the Arrow array, the Arrow table must release all Arrow arrays it receives // as well as the schema. As per the specification, children arrays are released by the - // producer. See: https://arrow.apache.org/docs/format/CDataInterface.html#release-callback-semantics-for-consumers + // producer. See: + // https://arrow.apache.org/docs/format/CDataInterface.html#release-callback-semantics-for-consumers for (int64_t i = 0; i < n_chunks_; ++i) { auto chunk = &chunks_ptr_[i]; if (chunk->release) { diff --git a/include/LightGBM/arrow.tpp b/include/LightGBM/arrow.tpp index 8d1ce4f4c0c1..1a1f609e32ff 100644 --- a/include/LightGBM/arrow.tpp +++ b/include/LightGBM/arrow.tpp @@ -31,8 +31,7 @@ inline ArrowChunkedArray::Iterator ArrowChunkedArray::end() const { /* ---------------------------------- ITERATOR IMPLEMENTATION ---------------------------------- */ template -ArrowChunkedArray::Iterator::Iterator(const ArrowChunkedArray& array, - getter_fn get, +ArrowChunkedArray::Iterator::Iterator(const ArrowChunkedArray& array, getter_fn get, int64_t ptr_chunk) : array_(array), get_(get), ptr_chunk_(ptr_chunk) { this->ptr_offset_ = 0; @@ -41,7 +40,7 @@ ArrowChunkedArray::Iterator::Iterator(const ArrowChunkedArray& array, template T ArrowChunkedArray::Iterator::operator*() const { auto chunk = array_.chunks_[ptr_chunk_]; - return static_cast(get_(chunk, ptr_offset_)); + return get_(chunk, ptr_offset_); } template @@ -54,7 +53,7 @@ T ArrowChunkedArray::Iterator::operator[](I idx) const { auto chunk = array_.chunks_[chunk_idx]; auto ptr_offset = static_cast(idx) - array_.chunk_offsets_[chunk_idx]; - return static_cast(get_(chunk, ptr_offset)); + return get_(chunk, ptr_offset); } template @@ -147,11 +146,28 @@ struct ArrayIndexAccessor { if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) { // In case the index is valid, we take it from the data buffer auto data = static_cast(array->buffers[1]); - return static_cast(data[buffer_idx]); + return static_cast(data[buffer_idx]); } // In case the index is not valid, we return a default value - return arrow_primitive_missing_value(); + return arrow_primitive_missing_value(); + } +}; + +template +struct ArrayIndexAccessor { + V operator()(const ArrowArray* array, size_t idx) { + // Custom implementation for booleans as values are bit-packed: + // https://arrow.apache.org/docs/cpp/api/datatype.html#_CPPv4N5arrow4Type4type4BOOLE + auto buffer_idx = idx + array->offset; + auto validity = static_cast(array->buffers[0]); + if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) { + // In case the index is valid, we have to take the appropriate bit from the buffer + auto data = static_cast(array->buffers[1]); + auto value = (data[buffer_idx / 8] & (1 << (buffer_idx % 8))) >> (buffer_idx % 8); + return static_cast(value); + } + return arrow_primitive_missing_value(); } }; @@ -180,6 +196,8 @@ std::function get_index_accessor(const char* dtype return ArrayIndexAccessor(); case 'g': return ArrayIndexAccessor(); + case 'b': + return ArrayIndexAccessor(); default: throw std::invalid_argument("unsupported Arrow datatype"); } diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index bb7dfb3b73eb..ee55b642ffa0 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -22,6 +22,7 @@ PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, + arrow_is_boolean, arrow_is_floating, arrow_is_integer, concat, @@ -1688,7 +1689,7 @@ def __pred_for_pyarrow_table( raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.") # Check that the input is valid: we only handle numbers (for now) - if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types): + if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types): raise ValueError("Arrow table may only have integer or floating point datatypes") # Prepare prediction output array @@ -2435,7 +2436,7 @@ def __init_from_pyarrow_table( raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` installed.") # Check that the input is valid: we only handle numbers (for now) - if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types): + if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types): raise ValueError("Arrow table may only have integer or floating point datatypes") # Export Arrow table to C diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 7656a0458315..9eed61a66a6c 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -222,6 +222,7 @@ def __init__(self, *args: Any, **kwargs: Any): from pyarrow import Table as pa_Table from pyarrow import chunked_array as pa_chunked_array from pyarrow.cffi import ffi as arrow_cffi + from pyarrow.types import is_boolean as arrow_is_boolean from pyarrow.types import is_floating as arrow_is_floating from pyarrow.types import is_integer as arrow_is_integer @@ -265,6 +266,7 @@ class pa_compute: # type: ignore equal = None pa_chunked_array = None + arrow_is_boolean = None arrow_is_integer = None arrow_is_floating = None diff --git a/tests/cpp_tests/test_arrow.cpp b/tests/cpp_tests/test_arrow.cpp index e975b6ba374b..c2dbd6cec6bf 100644 --- a/tests/cpp_tests/test_arrow.cpp +++ b/tests/cpp_tests/test_arrow.cpp @@ -5,87 +5,151 @@ * Author: Oliver Borchert */ -#include #include +#include -#include #include +#include using LightGBM::ArrowChunkedArray; using LightGBM::ArrowTable; +/* --------------------------------------------------------------------------------------------- */ +/* UTILS */ +/* --------------------------------------------------------------------------------------------- */ +// This code is copied and adapted from the official Arrow producer examples: +// https://arrow.apache.org/docs/format/CDataInterface.html#exporting-a-struct-float32-utf8-array + +static void release_schema(struct ArrowSchema* schema) { + // Free children + if (schema->children) { + for (int64_t i = 0; i < schema->n_children; ++i) { + struct ArrowSchema* child = schema->children[i]; + if (child->release) { + child->release(child); + } + free(child); + } + free(schema->children); + } + + // Finalize + schema->release = nullptr; +} + +static void release_array(struct ArrowArray* array) { + // Free children + if (array->children) { + for (int64_t i = 0; i < array->n_children; ++i) { + struct ArrowArray* child = array->children[i]; + if (child->release) { + child->release(child); + } + free(child); + } + free(array->children); + } + + // Free buffers + for (int64_t i = 0; i < array->n_buffers; ++i) { + if (array->buffers[i]) { + free(const_cast(array->buffers[i])); + } + } + free(array->buffers); + + // Finalize + array->release = nullptr; +} + +/* ------------------------------------------ PRODUCER ----------------------------------------- */ + class ArrowChunkedArrayTest : public testing::Test { protected: void SetUp() override {} - ArrowArray created_nested_array(const std::vector& arrays) { + /* -------------------------------------- ARRAY CREATION ------------------------------------- */ + + char* build_validity_bitmap(int64_t size, std::vector null_indices = {}) { + if (null_indices.empty()) { + return nullptr; + } + auto num_bytes = (size + 7) / 8; + auto validity = static_cast(malloc(num_bytes * sizeof(char))); + memset(validity, 0xff, num_bytes * sizeof(char)); + for (auto idx : null_indices) { + validity[idx / 8] &= ~(1 << (idx % 8)); + } + return validity; + } + + ArrowArray build_primitive_array(void* data, int64_t size, int64_t offset, + std::vector null_indices) { + const void** buffers = (const void**)malloc(sizeof(void*) * 2); + buffers[0] = build_validity_bitmap(size, null_indices); + buffers[1] = data; + ArrowArray arr; - arr.buffers = nullptr; - arr.children = (ArrowArray**)arrays.data(); // NOLINT + arr.length = size - offset; + arr.null_count = static_cast(null_indices.size()); + arr.offset = offset; + arr.n_buffers = 2; + arr.n_children = 0; + arr.buffers = buffers; + arr.children = nullptr; arr.dictionary = nullptr; - arr.length = arrays[0]->length; - arr.n_buffers = 0; - arr.n_children = arrays.size(); - arr.null_count = 0; - arr.offset = 0; + arr.release = &release_array; arr.private_data = nullptr; - arr.release = nullptr; return arr; } template - ArrowArray create_primitive_array(const std::vector& values, - int64_t offset = 0, + ArrowArray create_primitive_array(const std::vector& values, int64_t offset = 0, std::vector null_indices = {}) { // NOTE: Arrow arrays have 64-bit alignment but we can safely ignore this in tests - // 1) Create validity bitmap - char* validity = nullptr; - if (!null_indices.empty()) { - auto num_bytes = (values.size() + 7) / 8; - validity = static_cast(calloc(num_bytes, sizeof(char))); - memset(validity, 0xff, num_bytes * sizeof(char)); - for (size_t i = 0; i < values.size(); ++i) { - if (std::find(null_indices.begin(), null_indices.end(), i) != null_indices.end()) { - validity[i / 8] &= ~(1 << (i % 8)); - } + auto buffer = static_cast(malloc(sizeof(T) * values.size())); + for (size_t i = 0; i < values.size(); ++i) { + buffer[i] = values[i]; + } + return build_primitive_array(buffer, values.size(), offset, null_indices); + } + + ArrowArray create_primitive_array(const std::vector& values, int64_t offset = 0, + std::vector null_indices = {}) { + auto num_bytes = (values.size() + 7) / 8; + auto buffer = static_cast(calloc(sizeof(char), num_bytes)); + for (size_t i = 0; i < values.size(); ++i) { + // By using `calloc` above, we only need to set 'true' values + if (values[i]) { + buffer[i / 8] |= (1 << (i % 8)); } } + return build_primitive_array(buffer, values.size(), offset, null_indices); + } - // 2) Create buffers - const void** buffers = (const void**)malloc(sizeof(void*) * 2); - buffers[0] = validity; - buffers[1] = values.data() + offset; + ArrowArray created_nested_array(const std::vector& arrays) { + auto children = static_cast(malloc(sizeof(ArrowArray*) * arrays.size())); + for (size_t i = 0; i < arrays.size(); ++i) { + auto child = static_cast(malloc(sizeof(ArrowArray))); + *child = *arrays[i]; + children[i] = child; + } - // Create arrow array ArrowArray arr; - arr.buffers = buffers; - arr.children = nullptr; - arr.dictionary = nullptr; - arr.length = values.size() - offset; + arr.length = children[0]->length; arr.null_count = 0; arr.offset = 0; + arr.n_buffers = 0; + arr.n_children = static_cast(arrays.size()); + arr.buffers = nullptr; + arr.children = children; + arr.dictionary = nullptr; + arr.release = &release_array; arr.private_data = nullptr; - arr.release = [](ArrowArray* arr) { - if (arr->buffers[0] != nullptr) - free((void*)(arr->buffers[0])); // NOLINT - free((void*)(arr->buffers)); // NOLINT - }; return arr; } - ArrowSchema create_nested_schema(const std::vector& arrays) { - ArrowSchema schema; - schema.format = "+s"; - schema.name = nullptr; - schema.metadata = nullptr; - schema.flags = 0; - schema.n_children = arrays.size(); - schema.children = (ArrowSchema**)arrays.data(); // NOLINT - schema.dictionary = nullptr; - schema.private_data = nullptr; - schema.release = nullptr; - return schema; - } + /* ------------------------------------- SCHEMA CREATION ------------------------------------- */ template ArrowSchema create_primitive_schema() { @@ -102,27 +166,71 @@ class ArrowChunkedArrayTest : public testing::Test { schema.n_children = 0; schema.children = nullptr; schema.dictionary = nullptr; + schema.release = nullptr; schema.private_data = nullptr; + return schema; + } + + template <> + ArrowSchema create_primitive_schema() { + ArrowSchema schema; + schema.format = "b"; + schema.name = nullptr; + schema.metadata = nullptr; + schema.flags = 0; + schema.n_children = 0; + schema.children = nullptr; + schema.dictionary = nullptr; schema.release = nullptr; + schema.private_data = nullptr; + return schema; + } + + ArrowSchema create_nested_schema(const std::vector& arrays) { + auto children = static_cast(malloc(sizeof(ArrowSchema*) * arrays.size())); + for (size_t i = 0; i < arrays.size(); ++i) { + auto child = static_cast(malloc(sizeof(ArrowSchema))); + *child = *arrays[i]; + children[i] = child; + } + + ArrowSchema schema; + schema.format = "+s"; + schema.name = nullptr; + schema.metadata = nullptr; + schema.flags = 0; + schema.n_children = static_cast(arrays.size()); + schema.children = children; + schema.dictionary = nullptr; + schema.release = &release_schema; + schema.private_data = nullptr; return schema; } }; +/* --------------------------------------------------------------------------------------------- */ +/* TESTS */ +/* --------------------------------------------------------------------------------------------- */ + TEST_F(ArrowChunkedArrayTest, GetLength) { + auto schema = create_primitive_schema(); + std::vector dat1 = {1, 2}; auto arr1 = create_primitive_array(dat1); - - ArrowChunkedArray ca1(1, &arr1, nullptr); + ArrowChunkedArray ca1(1, &arr1, &schema); ASSERT_EQ(ca1.get_length(), 2); std::vector dat2 = {3, 4, 5, 6}; - auto arr2 = create_primitive_array(dat2); - ArrowArray arrs[2] = {arr1, arr2}; - ArrowChunkedArray ca2(2, arrs, nullptr); + auto arr2 = create_primitive_array(dat1); + auto arr3 = create_primitive_array(dat2); + ArrowArray arrs[2] = {arr2, arr3}; + ArrowChunkedArray ca2(2, arrs, &schema); ASSERT_EQ(ca2.get_length(), 6); - arr1.release(&arr1); - arr2.release(&arr2); + std::vector dat3 = {true, false, true, true}; + auto arr4 = create_primitive_array(dat3, 1); + ArrowChunkedArray ca3(1, &arr4, &schema); + ASSERT_EQ(ca3.get_length(), 3); } TEST_F(ArrowChunkedArrayTest, GetColumns) { @@ -149,18 +257,15 @@ TEST_F(ArrowChunkedArrayTest, GetColumns) { auto ca2 = table.get_column(1); ASSERT_EQ(ca2.get_length(), 3); ASSERT_EQ(*ca2.begin(), 4); - - arr1.release(&arr1); - arr2.release(&arr2); } TEST_F(ArrowChunkedArrayTest, IteratorArithmetic) { std::vector dat1 = {1, 2}; - auto arr1 = create_primitive_array(dat1); + auto arr1 = create_primitive_array(dat1); std::vector dat2 = {3, 4, 5, 6}; - auto arr2 = create_primitive_array(dat2); + auto arr2 = create_primitive_array(dat2); std::vector dat3 = {7}; - auto arr3 = create_primitive_array(dat3); + auto arr3 = create_primitive_array(dat3); auto schema = create_primitive_schema(); ArrowArray arrs[3] = {arr1, arr2, arr3}; @@ -190,15 +295,39 @@ TEST_F(ArrowChunkedArrayTest, IteratorArithmetic) { auto end = ca.end(); ASSERT_EQ(end - it, 2); ASSERT_EQ(end - ca.begin(), 7); +} + +TEST_F(ArrowChunkedArrayTest, BooleanIterator) { + std::vector dat1 = {false, true, false}; + auto arr1 = create_primitive_array(dat1, 0, {2}); + std::vector dat2 = {false, false, false, false, true, true, true, true, false, true}; + auto arr2 = create_primitive_array(dat2, 1); + auto schema = create_primitive_schema(); + + ArrowArray arrs[2] = {arr1, arr2}; + ArrowChunkedArray ca(2, arrs, &schema); + + // Check for values in first chunk + auto it = ca.begin(); + ASSERT_EQ(*it, 0); + ASSERT_EQ(*(++it), 1); + ASSERT_TRUE(std::isnan(*(++it))); + + // Check for some values in second chunk + ASSERT_EQ(*(++it), 0); + it += 3; + ASSERT_EQ(*it, 1); + it += 4; + ASSERT_EQ(*it, 0); + ASSERT_EQ(*(++it), 1); - arr1.release(&arr1); - arr2.release(&arr2); - arr2.release(&arr3); + // Check end + ASSERT_EQ(++it, ca.end()); } TEST_F(ArrowChunkedArrayTest, OffsetAndValidity) { std::vector dat = {0, 1, 2, 3, 4, 5, 6}; - auto arr = create_primitive_array(dat, 2, {0, 1}); + auto arr = create_primitive_array(dat, 2, {2, 3}); auto schema = create_primitive_schema(); ArrowChunkedArray ca(1, &arr, &schema); diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py index b8b90e1d051d..855300dda1ef 100644 --- a/tests/python_package_test/test_arrow.py +++ b/tests/python_package_test/test_arrow.py @@ -1,5 +1,6 @@ # coding: utf-8 import filecmp +from pathlib import Path from typing import Any, Dict, Optional import numpy as np @@ -43,16 +44,17 @@ def generate_simple_arrow_table(empty_chunks: bool = False) -> pa.Table: pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.int64()), pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.float32()), pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.float64()), + pa.chunked_array(c + [[True, True, False]] + c + [[False, True]] + c, type=pa.bool_()), ] return pa.Table.from_arrays(columns, names=[f"col_{i}" for i in range(len(columns))]) -def generate_nullable_arrow_table() -> pa.Table: +def generate_nullable_arrow_table(dtype: Any) -> pa.Table: columns = [ - pa.chunked_array([[1, None, 3, 4, 5]], type=pa.float32()), - pa.chunked_array([[None, 2, 3, 4, 5]], type=pa.float32()), - pa.chunked_array([[1, 2, 3, 4, None]], type=pa.float32()), - pa.chunked_array([[None, None, None, None, None]], type=pa.float32()), + pa.chunked_array([[1, None, 3, 4, 5]], type=dtype), + pa.chunked_array([[None, 2, 3, 4, 5]], type=dtype), + pa.chunked_array([[1, 2, 3, 4, None]], type=dtype), + pa.chunked_array([[None, None, None, None, None]], type=dtype), ] return pa.Table.from_arrays(columns, names=[f"col_{i}" for i in range(len(columns))]) @@ -120,13 +122,20 @@ def dummy_dataset_params() -> Dict[str, Any]: # ------------------------------------------- DATASET ------------------------------------------- # +def assert_datasets_equal(tmp_path: Path, lhs: lgb.Dataset, rhs: lgb.Dataset): + lhs._dump_text(tmp_path / "arrow.txt") + rhs._dump_text(tmp_path / "pandas.txt") + assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt") + + @pytest.mark.parametrize( ("arrow_table_fn", "dataset_params"), [ # Use lambda functions here to minimize memory consumption (lambda: generate_simple_arrow_table(), dummy_dataset_params()), (lambda: generate_simple_arrow_table(empty_chunks=True), dummy_dataset_params()), (lambda: generate_dummy_arrow_table(), dummy_dataset_params()), - (lambda: generate_nullable_arrow_table(), dummy_dataset_params()), + (lambda: generate_nullable_arrow_table(pa.float32()), dummy_dataset_params()), + (lambda: generate_nullable_arrow_table(pa.int32()), dummy_dataset_params()), (lambda: generate_random_arrow_table(3, 1000, 42), {}), (lambda: generate_random_arrow_table(100, 10000, 43), {}), ], @@ -140,9 +149,22 @@ def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params): pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), params=dataset_params) pandas_dataset.construct() - arrow_dataset._dump_text(tmp_path / "arrow.txt") - pandas_dataset._dump_text(tmp_path / "pandas.txt") - assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt") + assert_datasets_equal(tmp_path, arrow_dataset, pandas_dataset) + + +def test_dataset_construct_fuzzy_boolean(tmp_path): + boolean_data = generate_random_arrow_table(10, 10000, 42, generate_nulls=False, values=np.array([True, False])) + + float_schema = pa.schema([pa.field(f"col_{i}", pa.float32()) for i in range(len(boolean_data.columns))]) + float_data = boolean_data.cast(float_schema) + + arrow_dataset = lgb.Dataset(boolean_data) + arrow_dataset.construct() + + pandas_dataset = lgb.Dataset(float_data.to_pandas()) + pandas_dataset.construct() + + assert_datasets_equal(tmp_path, arrow_dataset, pandas_dataset) # -------------------------------------------- FIELDS ------------------------------------------- # @@ -195,6 +217,25 @@ def test_dataset_construct_labels(array_type, label_data, arrow_type): np_assert_array_equal(expected, dataset.get_label(), strict=True) +@pytest.mark.parametrize( + ["array_type", "label_data"], + [ + (pa.array, [False, True, False, False, True]), + (pa.chunked_array, [[False], [True, False, False, True]]), + (pa.chunked_array, [[], [False], [True, False, False, True]]), + (pa.chunked_array, [[False], [], [True, False], [], [], [False, True], []]), + ], +) +def test_dataset_construct_labels_boolean(array_type, label_data): + data = generate_dummy_arrow_table() + labels = array_type(label_data, type=pa.bool_()) + dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params()) + dataset.construct() + + expected = np.array([0, 1, 0, 0, 1], dtype=np.float32) + np_assert_array_equal(expected, dataset.get_label(), strict=True) + + # ------------------------------------------- WEIGHTS ------------------------------------------- # @@ -317,7 +358,10 @@ def assert_equal_predict_arrow_pandas(booster: lgb.Booster, data: pa.Table): def test_predict_regression(): - data = generate_random_arrow_table(10, 10000, 42) + data_float = generate_random_arrow_table(10, 10000, 42) + data_bool = generate_random_arrow_table(1, 10000, 42, generate_nulls=False, values=np.array([True, False])) + data = pa.Table.from_arrays(data_float.columns + data_bool.columns, names=data_float.schema.names + ["col_bool"]) + dataset = lgb.Dataset( data, label=generate_random_arrow_array(10000, 43, generate_nulls=False),