Skip to content

Commit c4c33bd

Browse files
Merge pull request #9 from pcmoritz/base
Introduce base object
2 parents 3fc9196 + 84cfce3 commit c4c33bd

File tree

5 files changed

+24 -19 lines changed

5 files changed

+24 -19 lines changed

python/src/pynumbuf/adapters/numpy.cc

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,11 +25,15 @@ namespace numbuf {
2525
type* data = const_cast<type*>(values->raw_data()) \
2626
+ content->offset(offset); \
2727
*out = PyArray_SimpleNewFromData(num_dims, dim.data(), NPY_##TYPE, \
28-
reinterpret_cast<void*>(data)); \
28+
reinterpret_cast<void*>(data)); \
29+
if (base != Py_None) { \
30+
PyArray_SetBaseObject((PyArrayObject*) *out, base); \
31+
} \
32+
Py_XINCREF(base); \
2933
} \
3034
return Status::OK();
3135

32-
Status DeserializeArray(std::shared_ptr<Array> array, int32_t offset, PyObject** out) {
36+
Status DeserializeArray(std::shared_ptr<Array> array, int32_t offset, PyObject* base, PyObject** out) {
3337
DCHECK(array);
3438
auto tensor = std::dynamic_pointer_cast<StructArray>(array);
3539
DCHECK(tensor);

python/src/pynumbuf/adapters/numpy.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
namespace numbuf {
1616

1717
arrow::Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder, std::vector<PyObject*>& subdicts);
18-
arrow::Status DeserializeArray(std::shared_ptr<arrow::Array> array, int32_t offset, PyObject** out);
18+
arrow::Status DeserializeArray(std::shared_ptr<arrow::Array> array, int32_t offset, PyObject* base, PyObject** out);
1919

2020
}
2121

python/src/pynumbuf/adapters/python.cc

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ extern PyObject* numbuf_deserialize_callback;
1111

1212
namespace numbuf {
1313

14-
PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) {
14+
PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type, PyObject* base) {
1515
PyObject* result;
1616
switch (arr->type()->type) {
1717
case Type::BOOL:
@@ -36,13 +36,13 @@ PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) {
3636
auto s = std::static_pointer_cast<StructArray>(arr);
3737
auto l = std::static_pointer_cast<ListArray>(s->field(0));
3838
if (s->type()->child(0)->name == "list") {
39-
ARROW_CHECK_OK(DeserializeList(l->values(), l->value_offset(index), l->value_offset(index+1), &result));
39+
ARROW_CHECK_OK(DeserializeList(l->values(), l->value_offset(index), l->value_offset(index+1), base, &result));
4040
} else if (s->type()->child(0)->name == "tuple") {
41-
ARROW_CHECK_OK(DeserializeTuple(l->values(), l->value_offset(index), l->value_offset(index+1), &result));
41+
ARROW_CHECK_OK(DeserializeTuple(l->values(), l->value_offset(index), l->value_offset(index+1), base, &result));
4242
} else if (s->type()->child(0)->name == "dict") {
43-
ARROW_CHECK_OK(DeserializeDict(l->values(), l->value_offset(index), l->value_offset(index+1), &result));
43+
ARROW_CHECK_OK(DeserializeDict(l->values(), l->value_offset(index), l->value_offset(index+1), base, &result));
4444
} else {
45-
ARROW_CHECK_OK(DeserializeArray(arr, index, &result));
45+
ARROW_CHECK_OK(DeserializeArray(arr, index, base, &result));
4646
}
4747
return result;
4848
}
@@ -181,17 +181,17 @@ Status SerializeSequences(std::vector<PyObject*> sequences, std::shared_ptr<Arra
181181
int32_t offset = offsets->Value(i); \
182182
int8_t type = types->Value(i); \
183183
ArrayPtr arr = data->child(type); \
184-
SET_ITEM(result, i-start_idx, get_value(arr, offset, type)); \
184+
SET_ITEM(result, i-start_idx, get_value(arr, offset, type, base)); \
185185
} \
186186
} \
187187
*out = result; \
188188
return Status::OK();
189189

190-
Status DeserializeList(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out) {
190+
Status DeserializeList(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out) {
191191
DESERIALIZE_SEQUENCE(PyList_New, PyList_SetItem)
192192
}
193193

194-
Status DeserializeTuple(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out) {
194+
Status DeserializeTuple(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out) {
195195
DESERIALIZE_SEQUENCE(PyTuple_New, PyTuple_SetItem)
196196
}
197197

@@ -227,13 +227,13 @@ Status SerializeDict(std::vector<PyObject*> dicts, std::shared_ptr<Array>* out)
227227
return Status::OK();
228228
}
229229

230-
Status DeserializeDict(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out) {
230+
Status DeserializeDict(std::shared_ptr<Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out) {
231231
auto data = std::dynamic_pointer_cast<StructArray>(array);
232232
// TODO(pcm): error handling, get rid of the temporary copy of the list
233233
PyObject *keys, *vals;
234234
PyObject* result = PyDict_New();
235-
ARROW_RETURN_NOT_OK(DeserializeList(data->field(0), start_idx, stop_idx, &keys));
236-
ARROW_RETURN_NOT_OK(DeserializeList(data->field(1), start_idx, stop_idx, &vals));
235+
ARROW_RETURN_NOT_OK(DeserializeList(data->field(0), start_idx, stop_idx, base, &keys));
236+
ARROW_RETURN_NOT_OK(DeserializeList(data->field(1), start_idx, stop_idx, base, &vals));
237237
for (size_t i = start_idx; i < stop_idx; ++i) {
238238
PyDict_SetItem(result, PyList_GetItem(keys, i - start_idx), PyList_GetItem(vals, i - start_idx));
239239
}

python/src/pynumbuf/adapters/python.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,9 @@ namespace numbuf {
1313

1414
arrow::Status SerializeSequences(std::vector<PyObject*> sequences, std::shared_ptr<arrow::Array>* out);
1515
arrow::Status SerializeDict(std::vector<PyObject*> dicts, std::shared_ptr<arrow::Array>* out);
16-
arrow::Status DeserializeList(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
17-
arrow::Status DeserializeTuple(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
18-
arrow::Status DeserializeDict(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
16+
arrow::Status DeserializeList(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
17+
arrow::Status DeserializeTuple(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
18+
arrow::Status DeserializeDict(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject* base, PyObject** out);
1919

2020
arrow::Status python_error_to_status();
2121

python/src/pynumbuf/numbuf.cc

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -126,11 +126,12 @@ static PyObject* read_from_buffer(PyObject* self, PyObject* args) {
126126
/* Documented in doc/numbuf.rst in ray-core */
127127
static PyObject* deserialize_list(PyObject* self, PyObject* args) {
128128
std::shared_ptr<RowBatch>* data;
129-
if (!PyArg_ParseTuple(args, "O&", &PyObjectToArrow, &data)) {
129+
PyObject* base = Py_None;
130+
if (!PyArg_ParseTuple(args, "O&|O", &PyObjectToArrow, &data, &base)) {
130131
return NULL;
131132
}
132133
PyObject* result;
133-
ARROW_CHECK_OK(DeserializeList((*data)->column(0), 0, (*data)->num_rows(), &result));
134+
ARROW_CHECK_OK(DeserializeList((*data)->column(0), 0, (*data)->num_rows(), base, &result));
134135
return result;
135136
}
136137

0 commit comments

Comments
 (0)