Skip to content

Commit

Permalink
[backport] Make sure input numpy array is aligned. (#8690) (#8696) (#…
Browse files Browse the repository at this point in the history
…8734)

* [backport] Make sure input numpy array is aligned. (#8690)

- use `np.require` to specify that the alignment is required.
- apply the same alignment requirement to scipy CSR inputs as well.
- validate input pointer in `ArrayInterface`.

* Workaround CUDA warning. (#8696)

* Backport the alignment handling from the half type support change.

* fix import.
  • Loading branch information
trivialfis authored Feb 6, 2023
1 parent 68d8633 commit 2f22f8d
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 25 deletions.
17 changes: 10 additions & 7 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2172,6 +2172,7 @@ def assign_type(t: int) -> None:
)
return _prediction_output(shape, dims, preds, False)

# pylint: disable=too-many-statements
def inplace_predict(
self,
data: DataType,
Expand All @@ -2192,10 +2193,10 @@ def inplace_predict(
.. code-block:: python
booster.set_param({'predictor': 'gpu_predictor'})
booster.set_param({"predictor": "gpu_predictor"})
booster.inplace_predict(cupy_array)
booster.set_param({'predictor': 'cpu_predictor})
booster.set_param({"predictor": "cpu_predictor"})
booster.inplace_predict(numpy_array)
.. versionadded:: 1.1.0
Expand Down Expand Up @@ -2301,14 +2302,16 @@ def inplace_predict(
)
return _prediction_output(shape, dims, preds, False)
if isinstance(data, scipy.sparse.csr_matrix):
csr = data
from .data import _transform_scipy_csr

data = _transform_scipy_csr(data)
_check_call(
_LIB.XGBoosterPredictFromCSR(
self.handle,
_array_interface(csr.indptr),
_array_interface(csr.indices),
_array_interface(csr.data),
c_bst_ulong(csr.shape[1]),
_array_interface(data.indptr),
_array_interface(data.indices),
_array_interface(data.data),
c_bst_ulong(data.shape[1]),
from_pystr_to_cstr(json.dumps(args)),
p_handle,
ctypes.byref(shape),
Expand Down
33 changes: 24 additions & 9 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
c_array,
c_str,
from_pystr_to_cstr,
make_jcargs,
)

DispatchedDataBackendReturnType = Tuple[
Expand Down Expand Up @@ -80,6 +81,21 @@ def _array_interface(data: np.ndarray) -> bytes:
return interface_str


# Pass the three arrays backing a scipy CSR matrix (`indptr`, `indices`, `data`)
# through `_ensure_np_dtype`, each with its own current dtype, and return a CSR
# matrix built from the results.
#
# Parameters: `data` must expose `.indptr`, `.indices`, `.data` and `.shape`.
# Returns: the input object itself when `_ensure_np_dtype` returned the same
# array objects for all three components; otherwise a new `csr_matrix` with the
# same shape wrapping the returned arrays.
def _transform_scipy_csr(data: DataType) -> DataType:
# Imported at function scope rather than at module level.
from scipy.sparse import csr_matrix

# The second element of each returned tuple (the dtype) is not needed here.
indptr, _ = _ensure_np_dtype(data.indptr, data.indptr.dtype)
indices, _ = _ensure_np_dtype(data.indices, data.indices.dtype)
values, _ = _ensure_np_dtype(data.data, data.data.dtype)
# Identity comparison (`is not`), not value comparison: a rebuild is needed only
# when at least one component came back as a different array object.
if (
indptr is not data.indptr
or indices is not data.indices
or values is not data.data
):
# Re-wrap the returned arrays, keeping the original matrix shape.
data = csr_matrix((values, indices, indptr), shape=data.shape)
return data


def _from_scipy_csr(
data: DataType,
missing: FloatCompatible,
Expand All @@ -93,18 +109,14 @@ def _from_scipy_csr(
f"length mismatch: {len(data.indices)} vs {len(data.data)}"
)
handle = ctypes.c_void_p()
args = {
"missing": float(missing),
"nthread": int(nthread),
}
config = bytes(json.dumps(args), "utf-8")
data = _transform_scipy_csr(data)
_check_call(
_LIB.XGDMatrixCreateFromCSR(
_array_interface(data.indptr),
_array_interface(data.indices),
_array_interface(data.data),
c_bst_ulong(data.shape[1]),
config,
make_jcargs(missing=float(missing), nthread=int(nthread)),
ctypes.byref(handle),
)
)
Expand Down Expand Up @@ -153,12 +165,13 @@ def _is_numpy_array(data: DataType) -> bool:


def _ensure_np_dtype(
data: DataType,
dtype: Optional[NumpyDType]
data: DataType, dtype: Optional[NumpyDType]
) -> Tuple[np.ndarray, Optional[NumpyDType]]:
if data.dtype.hasobject or data.dtype in [np.float16, np.bool_]:
data = data.astype(np.float32, copy=False)
dtype = np.float32
data = data.astype(dtype, copy=False)
if not data.flags.aligned:
data = np.require(data, requirements="A")
return data, dtype


Expand Down Expand Up @@ -1197,11 +1210,13 @@ def _proxy_transform(
data, _ = _ensure_np_dtype(data, data.dtype)
return data, None, feature_names, feature_types
if _is_scipy_csr(data):
data = _transform_scipy_csr(data)
return data, None, feature_names, feature_types
if _is_pandas_df(data):
arr, feature_names, feature_types = _transform_pandas_df(
data, enable_categorical, feature_names, feature_types
)
arr, _ = _ensure_np_dtype(arr, arr.dtype)
return arr, None, feature_names, feature_types
raise TypeError("Value type is not supported for data iterator:" + str(type(data)))

Expand Down
25 changes: 19 additions & 6 deletions src/data/array_interface.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
/*!
* Copyright 2019-2021 by Contributors
/**
* Copyright 2019-2023 by XGBoost Contributors
* \file array_interface.h
* \brief View of __array_interface__
*/
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
#define XGBOOST_DATA_ARRAY_INTERFACE_H_

#include <algorithm>
#include <cinttypes>
#include <cstddef> // std::size_t
#include <cstdint>
#include <map>
#include <string>
#include <type_traits> // std::alignment_of,std::remove_pointer_t
#include <utility>
#include <vector>

Expand Down Expand Up @@ -394,6 +396,11 @@ class ArrayInterface {

data = ArrayInterfaceHandler::ExtractData(array, n);
static_assert(allow_mask ? D == 1 : D >= 1, "Masked ndarray is not supported.");

auto alignment = this->ElementAlignment();
auto ptr = reinterpret_cast<uintptr_t>(this->data);
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";

if (allow_mask) {
common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask);
Expand Down Expand Up @@ -512,9 +519,15 @@ class ArrayInterface {
return func(reinterpret_cast<uint64_t const *>(data));
}

XGBOOST_DEVICE size_t ElementSize() {
return this->DispatchCall(
[](auto *p_values) { return sizeof(std::remove_pointer_t<decltype(p_values)>); });
/**
 * \brief Size in bytes of a single element, computed with `sizeof` on the pointee
 *        type of the typed pointer that `DispatchCall` hands to the callback.
 */
XGBOOST_DEVICE std::size_t ElementSize() const {
// Only the static type of the pointer is inspected (`decltype` / `sizeof`); the
// pointer itself is never dereferenced.
return this->DispatchCall([](auto *typed_data_ptr) {
return sizeof(std::remove_pointer_t<decltype(typed_data_ptr)>);
});
}
/**
 * \brief Alignment requirement in bytes of a single element, taken from
 *        `std::alignment_of` on the pointee type of the typed pointer that
 *        `DispatchCall` hands to the callback.
 */
XGBOOST_DEVICE std::size_t ElementAlignment() const {
// Only the static type of the pointer is inspected; the pointer itself is never
// dereferenced, so this does not read through the data pointer.
return this->DispatchCall([](auto *typed_data_ptr) {
return std::alignment_of<std::remove_pointer_t<decltype(typed_data_ptr)>>::value;
});
}

template <typename T = float, typename... Index>
Expand Down
14 changes: 12 additions & 2 deletions tests/cpp/data/test_array_interface.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
/*!
* Copyright 2020-2021 by XGBoost Contributors
/**
* Copyright 2020-2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/host_device_vector.h>
#include "../helpers.h"
#include "../../../src/data/array_interface.h"
#include "dmlc/logging.h"
#include "xgboost/json.h"

namespace xgboost {
TEST(ArrayInterface, Initialize) {
Expand Down Expand Up @@ -71,6 +73,14 @@ TEST(ArrayInterface, Error) {
column["mask"]["data"] = Null{};
common::Span<RBitField8::value_type> s_mask;
EXPECT_THROW(ArrayInterfaceHandler::ExtractMask(column_obj, &s_mask), dmlc::Error);

get<Object>(column).erase("mask");
// misaligned.
j_data = {Json(Integer(reinterpret_cast<Integer::Int>(
reinterpret_cast<char const*>(storage.ConstHostPointer()) + 1))),
Json(Boolean(false))};
column["data"] = j_data;
EXPECT_THROW({ ArrayInterface<1> arr{column}; }, dmlc::Error);
}

TEST(ArrayInterface, GetElement) {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_dmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def test_sparse_dmatrix_csr(self):
nrow = 100
ncol = 1000
x = rand(nrow, ncol, density=0.0005, format='csr', random_state=rng)
assert x.indices.max() < ncol - 1
assert x.indices.max() < ncol
x.data[:] = 1
dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
Expand Down

0 comments on commit 2f22f8d

Please sign in to comment.