Skip to content

Commit 7b9b22a

Browse files
committed
Adding array setting
1 parent d8a1f4e commit 7b9b22a

File tree

7 files changed

+237
-72
lines changed

7 files changed

+237
-72
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ Python bindings for [Boost::Histogram][] ([source][Boost::Histogram source]), a
1717
> Please feel free to try out boost-histogram and give feedback.
1818
> Join the [discussion on gitter][gitter-link] or [open an issue](https://github.com/scikit-hep/boost-histogram/issues)!
1919
>
20-
> #### Known issues (develop):
21-
> * Setting with an array is not yet supported (`h[...] = np.array(...)`).
2220
2321

2422
## Installation

boost_histogram/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@
2525
try:
2626
from . import _core
2727
except ImportError as err:
28-
err.msg += (
29-
"\nDid you forget to compile? Use CMake or Setuptools to build, see the readme"
30-
)
28+
if "_core" in err.msg and "boost_histogram" in err.msg:
29+
err.msg += "\nDid you forget to compile? Use CMake or Setuptools to build, see the readme"
3130
raise err
3231

3332

boost_histogram/_internal/hist.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,86 @@ def __getitem__(self, index):
465465
)
466466

467467
def __setitem__(self, index, value):
468-
indexes = self._compute_commonindex(index, expand_ellipsis=False)
469-
self._hist._at_set(value, *indexes)
468+
"""
469+
There are several supported possibilities:
470+
471+
h[slice] = array # same size
472+
473+
If an array is given to a compatible slice, it is set.
474+
475+
h[a:] = array # One larger
476+
477+
If an array is given that does not match, if it does match the
478+
with-overflow size, it fills that.
479+
480+
PLANNED (not yet supported):
481+
482+
h[a:] = h2
483+
484+
If another histogram is given, that must either match with or without
485+
overflow, where the overflow bins must be overflow bins (that is,
486+
you cannot set a histogram's flow bins from another histogram that
487+
is 2 larger). Bin edges must be a close match, as well. If you don't
488+
want this level of type safety, just use ``h[...] = h2.view()``.
489+
"""
490+
indexes = self._compute_commonindex(index, expand_ellipsis=True)
491+
492+
if isinstance(value, BaseHistogram):
493+
raise TypeError("Not supported yet")
494+
495+
value = np.asarray(value)
496+
view = self.view(flow=True)
497+
498+
# Disallow mismatched data types
499+
if len(value.dtype) != len(view.dtype):
500+
raise ValueError("Mismatched data types; matching types required")
501+
502+
# Numpy does not broadcast partial slices, but we would need
503+
# to allow it (because we do allow broadcasting up dimensions)
504+
# Instead, we simply require matching dimensions.
505+
if value.ndim > 0 and value.ndim != len(indexes):
506+
raise ValueError(
507+
"Setting a histogram with an array must have a matching number of dimensions"
508+
)
509+
510+
for n in range(len(indexes)):
511+
request = indexes[n]
512+
has_underflow = self.axes[n].options.underflow
513+
has_overflow = self.axes[n].options.overflow
514+
515+
if isinstance(request, slice):
516+
# Only consider underflow/overflow if the endpoints are not given
517+
use_underflow = has_underflow and request.start is None
518+
use_overflow = has_overflow and request.stop is None
519+
520+
# Make the limits explicit since we may need to shift them
521+
start = 0 if request.start is None else request.start
522+
stop = len(self.axes[n]) if request.stop is None else request.stop
523+
request_len = stop - start
524+
525+
# If there are not enough dimensions, then treat it like broadcasting
526+
if value.ndim == 0 or value.shape[n] == 1:
527+
start = 0 + has_overflow
528+
stop = len(self.axes[n]) + has_underflow
529+
elif request_len == value.shape[n]:
530+
start += has_underflow
531+
stop += has_underflow
532+
elif request_len + use_underflow + use_overflow == value.shape[n]:
533+
start += has_underflow and not use_underflow
534+
stop += has_underflow + (has_overflow and use_overflow)
535+
else:
536+
msg = "Mismatched shapes in dimension {0}".format(n)
537+
msg += ", {0} != {1}".format(value.shape[n], request_len)
538+
if use_underflow or use_overflow:
539+
msg += " or {0}".format(
540+
request_len + use_underflow + use_overflow
541+
)
542+
raise ValueError(msg)
543+
indexes[n] = slice(start, stop, request.step)
544+
else:
545+
indexes[n] = request + has_underflow
546+
547+
view[tuple(indexes)] = value
470548

471549
def reduce(self, *args):
472550
"""

include/histogram.hpp

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -119,62 +119,3 @@ const F& pyarray_at(const py::array_t<F, Opt>& input,
119119
}
120120
return *(input.data() + flat_offset);
121121
}
122-
123-
template <class Histogram>
124-
void copy_in(Histogram&, const py::array_t<double>&) {}
125-
126-
template <>
127-
inline void copy_in<>(bh::histogram<vector_axis_variant, bh::dense_storage<double>>& h,
128-
const py::array_t<double>& input) {
129-
// Works on simple datatypes only
130-
// TODO: Add other types
131-
if(h.rank() != input.ndim())
132-
throw py::value_error(
133-
"The input array dimensions must match the histogram rank");
134-
135-
// Quick check to ensure the input array is valid
136-
for(unsigned r = 0; r < h.rank(); r++) {
137-
auto input_shape = input.shape(static_cast<py::ssize_t>(r));
138-
if(input_shape != bh::axis::traits::extent(h.axis(r))
139-
&& input_shape != h.axis(r).size() && input_shape != 1)
140-
throw py::value_error("The input array sizes must match the histogram "
141-
"(with or without flow), or be broadcastable to it");
142-
}
143-
144-
std::vector<py::ssize_t> indexes;
145-
indexes.resize(h.rank());
146-
147-
for(auto&& ind : bh::indexed(h, bh::coverage::all)) {
148-
bool skip = false;
149-
150-
for(unsigned r = 0; r < h.rank(); r++) {
151-
auto input_shape = input.shape(static_cast<py::ssize_t>(r));
152-
bool use_flow = input_shape == bh::axis::traits::extent(h.axis(r));
153-
bool has_underflow = h.axis(r).options() & bh::axis::option::underflow;
154-
155-
// Broadcast size 1
156-
if(input_shape == 1)
157-
indexes[r] = 0;
158-
159-
// If this is using flow bins and has an underflow bin, convert -1 to 0
160-
// (etc)
161-
else if(use_flow && has_underflow)
162-
indexes[r] = ind.index(r) + 1;
163-
164-
// If not using flow bins, skip the flow bins
165-
else if(!use_flow
166-
&& (ind.index(r) < 0 || ind.index(r) >= h.axis(r).size())) {
167-
skip = true;
168-
break;
169-
170-
// Otherwise, this is normal
171-
} else
172-
indexes[r] = ind.index(r);
173-
}
174-
175-
if(skip)
176-
continue;
177-
178-
*ind = pyarray_at(input, indexes);
179-
}
180-
}

include/register_histogram.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,6 @@ auto register_histogram(py::module& m, const char* name, const char* desc) {
108108
},
109109
"flow"_a = false)
110110

111-
.def("_copy_in",
112-
[](histogram_t& h, py::array_t<double> input) { copy_in(h, input); })
113-
114111
.def(
115112
"view",
116113
[](py::object self, bool flow) {

src/register_accumulators.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,21 @@ decltype(auto) make_mean_call() {
5151
};
5252
}
5353

54+
template <class T>
55+
decltype(auto) make_buffer() {
56+
return [](T& self) -> py::buffer_info {
57+
return py::buffer_info(
58+
&self, // Pointer to buffer
59+
sizeof(T), // Size of one scalar
60+
py::format_descriptor<T>::format(), // Format registered with
61+
// PYBIND11_NUMPY_DTYPE
62+
0, // Number of dimensions
63+
{}, // Buffer dimensions
64+
{} // Stride
65+
);
66+
};
67+
}
68+
5469
void register_accumulators(py::module& accumulators) {
5570
// Naming convention:
5671
// If a value is publically available in Boost.Histogram accumulators
@@ -62,7 +77,10 @@ void register_accumulators(py::module& accumulators) {
6277

6378
PYBIND11_NUMPY_DTYPE(weighted_sum, value, variance);
6479

65-
register_accumulator<weighted_sum>(accumulators, "WeightedSum")
80+
register_accumulator<weighted_sum>(
81+
accumulators, "WeightedSum", py::buffer_protocol())
82+
83+
.def_buffer(make_buffer<weighted_sum>())
6684

6785
.def(py::init<const double&>(), "value"_a)
6886
.def(py::init<const double&, const double&>(), "value"_a, "variance"_a)
@@ -158,7 +176,11 @@ void register_accumulators(py::module& accumulators) {
158176
value,
159177
sum_of_weighted_deltas_squared);
160178

161-
register_accumulator<weighted_mean>(accumulators, "WeightedMean")
179+
register_accumulator<weighted_mean>(
180+
accumulators, "WeightedMean", py::buffer_protocol())
181+
182+
.def_buffer(make_buffer<weighted_mean>())
183+
162184
.def(py::init<const double&, const double&, const double&, const double&>(),
163185
"sum_of_weights"_a,
164186
"sum_of_weights_squared"_a,
@@ -238,7 +260,9 @@ void register_accumulators(py::module& accumulators) {
238260
using mean = accumulators::mean<double>;
239261
PYBIND11_NUMPY_DTYPE(mean, count, value, sum_of_deltas_squared);
240262

241-
register_accumulator<mean>(accumulators, "Mean")
263+
register_accumulator<mean>(accumulators, "Mean", py::buffer_protocol())
264+
.def_buffer(make_buffer<mean>())
265+
242266
.def(py::init<const double&, const double&, const double&>(),
243267
"count"_a,
244268
"value"_a,

tests/test_histogram_set.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import boost_histogram as bh
2+
import numpy as np
3+
4+
from numpy.testing import assert_array_equal
5+
6+
import pytest
7+
8+
9+
def test_1D_set_bin():
10+
11+
h = bh.Histogram(bh.axis.Regular(10, 0, 1))
12+
13+
h[2] = 2
14+
assert h[2] == 2.0
15+
16+
h[bh.underflow] = 3
17+
assert h[bh.underflow] == 3.0
18+
19+
h[bh.overflow] = 4
20+
assert h[bh.overflow] == 4.0
21+
22+
23+
def test_2d_set_bin():
24+
25+
h = bh.Histogram(bh.axis.Regular(10, 0, 1), bh.axis.Regular(10, 0, 1))
26+
27+
h[2, 4] = 2
28+
assert h[2, 4] == 2.0
29+
30+
h[bh.underflow, 5] = 3
31+
assert h[bh.underflow, 5] == 3.0
32+
33+
h[bh.overflow, bh.overflow] = 4
34+
assert h[bh.overflow, bh.overflow] == 4.0
35+
36+
37+
def test_1d_set_array():
38+
h = bh.Histogram(bh.axis.Regular(10, 0, 1))
39+
40+
h[...] = np.arange(10)
41+
assert_array_equal(h.view(), np.arange(10))
42+
43+
h[...] = np.arange(12)
44+
assert_array_equal(h.view(flow=True), np.arange(12))
45+
46+
with pytest.raises(ValueError):
47+
h[...] = np.arange(9)
48+
with pytest.raises(ValueError):
49+
h[...] = np.arange(11)
50+
with pytest.raises(ValueError):
51+
h[...] = np.arange(13)
52+
53+
h[...] = 1
54+
assert_array_equal(h.view(), np.ones(10))
55+
56+
57+
def test_2d_set_array():
58+
h = bh.Histogram(bh.axis.Regular(10, 0, 1), bh.axis.Regular(10, 0, 1))
59+
60+
h[...] = np.arange(10).reshape(-1, 1)
61+
assert_array_equal(h.view()[:, 2], np.arange(10))
62+
63+
h[...] = np.arange(12).reshape(-1, 1)
64+
assert_array_equal(h.view(flow=True)[:, 3], np.arange(12))
65+
66+
with pytest.raises(ValueError):
67+
h[...] = np.arange(9).reshape(-1, 1)
68+
with pytest.raises(ValueError):
69+
h[...] = np.arange(11).reshape(-1, 1)
70+
with pytest.raises(ValueError):
71+
h[...] = np.arange(13).reshape(-1, 1)
72+
73+
h[...] = 1
74+
assert_array_equal(h.view(), np.ones((10, 10)))
75+
76+
77+
@pytest.mark.parametrize(
78+
"storage, default",
79+
(
80+
(bh.storage.Mean, bh.accumulators.Mean(1.0, 2.0, 3.0)),
81+
(bh.storage.WeightedMean, bh.accumulators.WeightedMean(1.0, 2.0, 3.0, 4.0)),
82+
(bh.storage.Weight, bh.accumulators.WeightedSum(1.0, 2)),
83+
),
84+
)
85+
def test_set_special_dtype(storage, default):
86+
h = bh.Histogram(
87+
bh.axis.Regular(10, 0, 1), bh.axis.Regular(10, 0, 1), storage=storage()
88+
)
89+
90+
arr = np.full((10, 1), default)
91+
h[...] = arr
92+
assert_array_equal(h.view()[:, 1:2], arr)
93+
94+
arr = np.full((12, 1), default)
95+
h[...] = arr
96+
assert_array_equal(h.view(flow=True)[:, 2:3], arr)
97+
98+
arr = np.full((10, 10), default)
99+
h[...] = arr
100+
assert_array_equal(h.view(), arr)
101+
102+
arr = np.full((10, 12), default)
103+
h[...] = arr
104+
assert_array_equal(h.view(flow=True)[1:11, :], arr)
105+
106+
arr = np.full((12, 10), default)
107+
h[...] = arr
108+
assert_array_equal(h.view(flow=True)[:, 1:11], arr)
109+
110+
arr = np.full((12, 12), default)
111+
h[...] = arr
112+
assert_array_equal(h.view(flow=True), arr)
113+
114+
with pytest.raises(ValueError):
115+
arr = np.full((9, 1), default)
116+
h[...] = arr
117+
with pytest.raises(ValueError):
118+
arr = np.full((11, 1), default)
119+
h[...] = arr
120+
with pytest.raises(ValueError):
121+
arr = np.full((13, 1), default)
122+
h[...] = arr
123+
124+
with pytest.raises(ValueError):
125+
h[...] = 1
126+
127+
with pytest.raises(ValueError):
128+
h[1, 1] = 1

0 commit comments

Comments
 (0)