Skip to content

Commit a2a7205

Browse files
sourcery-ai[bot]pre-commit-ci[bot]henryiii
authored
Sourcery refactored develop branch (#544)
* 'Refactored by Sourcery' * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: cleanup tests and sets Co-authored-by: Sourcery AI <> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Henry Schreiner <[email protected]>
1 parent 871fbc5 commit a2a7205

File tree

6 files changed

+62
-53
lines changed

6 files changed

+62
-53
lines changed

src/boost_histogram/_internal/axis.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,13 @@ def __init__(
345345
ax = ca.regular_uflow(bins, start, stop)
346346
elif options == {"overflow"}:
347347
ax = ca.regular_oflow(bins, start, stop)
348-
elif options == {"circular", "underflow", "overflow"} or options == {
349-
"circular",
350-
"overflow",
351-
}:
348+
elif options in [
349+
{"circular", "underflow", "overflow"},
350+
{
351+
"circular",
352+
"overflow",
353+
},
354+
]:
352355
# growth=True, underflow=False is also correct
353356
ax = ca.regular_circular(bins, start, stop)
354357

@@ -449,10 +452,17 @@ def __init__(
449452
ax = ca.variable_uflow(edges)
450453
elif options == {"overflow"}:
451454
ax = ca.variable_oflow(edges)
452-
elif options == {"circular", "underflow", "overflow",} or options == {
453-
"circular",
454-
"overflow",
455-
}:
455+
elif options in [
456+
{
457+
"circular",
458+
"underflow",
459+
"overflow",
460+
},
461+
{
462+
"circular",
463+
"overflow",
464+
},
465+
]:
456466
# growth=True, underflow=False is also correct
457467
ax = ca.variable_circular(edges)
458468
elif options == set():

src/boost_histogram/_internal/hist.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def _fill_cast(value: T, *, inner: bool = False) -> Union[T, np.ndarray, Tuple[T
7575

7676

7777
def _arg_shortcut(item: Union[Tuple[int, float, float], Axis, CppAxis]) -> CppAxis:
78-
msg = "Developer shortcut: will be removed in a future version"
7978
if isinstance(item, tuple) and len(item) == 3:
79+
msg = "Developer shortcut: will be removed in a future version"
8080
warnings.warn(msg, FutureWarning)
8181
return _core.axis.regular_uoflow(item[0], item[1], item[2]) # type: ignore
8282
elif isinstance(item, Axis):
@@ -364,10 +364,10 @@ def _compute_inplace_op(
364364
len(other.shape), self.ndim
365365
)
366366
)
367-
elif all((a == b or a == 1) for a, b in zip(other.shape, self.shape)):
367+
elif all(a in {b, 1} for a, b in zip(other.shape, self.shape)):
368368
view = self.view(flow=False)
369369
getattr(view, name)(other)
370-
elif all((a == b or a == 1) for a, b in zip(other.shape, self.axes.extent)):
370+
elif all(a in {b, 1} for a, b in zip(other.shape, self.axes.extent)):
371371
view = self.view(flow=True)
372372
getattr(view, name)(other)
373373
else:
@@ -494,13 +494,11 @@ def __str__(self) -> str:
494494
"""
495495
# TODO check the terminal width and adjust the presentation
496496
# only use for 1D, fall back to repr for ND
497-
if self._hist.rank() == 1:
498-
s = str(self._hist)
499-
# get rid of first line and last character
500-
s = s[s.index("\n") + 1 : -1]
501-
else:
502-
s = repr(self)
503-
return s
497+
if self._hist.rank() != 1:
498+
return repr(self)
499+
s = str(self._hist)
500+
# get rid of first line and last character
501+
return s[s.index("\n") + 1 : -1]
504502

505503
def _axis(self, i: int = 0) -> Axis:
506504
"""
@@ -547,15 +545,14 @@ def __setstate__(self, state: Any) -> None:
547545
msg = "Cannot open boost-histogram pickle v{}".format(state[0])
548546
raise RuntimeError(msg)
549547

550-
self.axes = self._generate_axes_()
551-
552548
else: # Classic (0.10 and before) state
553549
self._hist = state["_hist"]
554550
self._variance_known = True
555551
self.metadata = state.get("metadata", None)
556552
for i in range(self._hist.rank()):
557553
self._hist.axis(i).metadata = {"metadata": self._hist.axis(i).metadata}
558-
self.axes = self._generate_axes_()
554+
555+
self.axes = self._generate_axes_()
559556

560557
def __repr__(self) -> str:
561558
newline = "\n "
@@ -779,14 +776,13 @@ def __getitem__( # noqa: C901
779776

780777
if not integrations:
781778
return self._new_hist(reduced)
782-
else:
783-
projections = [i for i in range(self.ndim) if i not in integrations]
779+
projections = [i for i in range(self.ndim) if i not in integrations]
784780

785-
return (
786-
self._new_hist(reduced.project(*projections))
787-
if projections
788-
else reduced.sum(flow=True)
789-
)
781+
return (
782+
self._new_hist(reduced.project(*projections))
783+
if projections
784+
else reduced.sum(flow=True)
785+
)
790786

791787
def __setitem__(
792788
self, index: IndexingExpr, value: Union[ArrayLike, Accumulator]

src/boost_histogram/_internal/view.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,16 @@ def __array_ufunc__(
111111
return ufunc(np.asarray(inputs[0]), np.asarray(inputs[1]), **kwargs) # type: ignore
112112

113113
# Support unary + and -
114-
if method == "__call__" and len(inputs) == 1:
115-
if ufunc in {np.negative, np.positive}:
116-
(result,) = kwargs.pop("out", [np.empty(self.shape, self.dtype)])
114+
if (
115+
method == "__call__"
116+
and len(inputs) == 1
117+
and ufunc in {np.negative, np.positive}
118+
):
119+
(result,) = kwargs.pop("out", [np.empty(self.shape, self.dtype)])
117120

118-
ufunc(inputs[0]["value"], out=result["value"], **kwargs)
119-
result["variance"] = inputs[0]["variance"]
120-
return result.view(self.__class__) # type: ignore
121+
ufunc(inputs[0]["value"], out=result["value"], **kwargs)
122+
result["variance"] = inputs[0]["variance"]
123+
return result.view(self.__class__) # type: ignore
121124

122125
if method == "__call__" and len(inputs) == 2:
123126
input_0 = np.asarray(inputs[0])

src/boost_histogram/numpy.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,12 @@ def histogram2d(
114114
threads=threads,
115115
)
116116

117-
if isinstance(result, tuple):
118-
data, (edgesx, edgesy) = result
119-
return data, edgesx, edgesy
120-
else:
117+
if not isinstance(result, tuple):
121118
return result
122119

120+
data, (edgesx, edgesy) = result
121+
return data, edgesx, edgesy
122+
123123

124124
def histogram(
125125
a: ArrayLike,
@@ -162,12 +162,12 @@ def histogram(
162162
storage=storage,
163163
threads=threads,
164164
)
165-
if isinstance(result, tuple):
166-
data, (edges,) = result
167-
return data, edges
168-
else:
165+
if not isinstance(result, tuple):
169166
return result
170167

168+
data, (edges,) = result
169+
return data, edges
170+
171171

172172
# Process docstrings
173173
for f, n in zip(

tests/test_accumulators.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,12 @@ def test_sum_mean(list1, list2):
115115

116116
ab = a + b
117117
assert ab.value == approx(c.value)
118-
assert ab.variance == approx(c.variance, nan_ok=True, abs=1e-9, rel=1e-9)
118+
assert ab.variance == approx(c.variance, nan_ok=True, abs=1e-7, rel=1e-3)
119119
assert ab.count == approx(c.count)
120120

121121
a += b
122122
assert a.value == approx(c.value)
123-
assert a.variance == approx(c.variance, nan_ok=True, abs=1e-9, rel=1e-9)
123+
assert a.variance == approx(c.variance, nan_ok=True, abs=1e-7, rel=1e-3)
124124
assert a.count == approx(c.count)
125125

126126

@@ -129,7 +129,7 @@ def test_sum_mean(list1, list2):
129129
st.lists(float_st, min_size=n, max_size=n),
130130
st.lists(
131131
st.floats(
132-
allow_nan=False, allow_infinity=False, min_value=1e-4, max_value=1e5
132+
allow_nan=False, allow_infinity=False, min_value=1e-2, max_value=1e3
133133
),
134134
min_size=n,
135135
max_size=n,
@@ -151,12 +151,12 @@ def test_sum_weighed_mean(pair1, pair2):
151151

152152
ab = a + b
153153
assert ab.value == approx(c.value)
154-
assert ab.variance == approx(c.variance, nan_ok=True, abs=1e-9, rel=1e-9)
154+
assert ab.variance == approx(c.variance, nan_ok=True, abs=1e-7, rel=1e-3)
155155
assert ab.sum_of_weights == approx(c.sum_of_weights)
156156
assert ab.sum_of_weights_squared == approx(c.sum_of_weights_squared)
157157

158158
a += b
159159
assert a.value == approx(c.value)
160-
assert a.variance == approx(c.variance, nan_ok=True, abs=1e-9, rel=1e-9)
160+
assert a.variance == approx(c.variance, nan_ok=True, abs=1e-7, rel=1e-3)
161161
assert a.sum_of_weights == approx(c.sum_of_weights)
162162
assert a.sum_of_weights_squared == approx(c.sum_of_weights_squared)

tests/test_minihist_title.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ class NamedAxesTuple(bh.axis.AxesTuple):
1414
__slots__ = ()
1515

1616
def _get_index_by_name(self, name):
17-
if isinstance(name, str):
18-
for i, ax in enumerate(self):
19-
if ax.name == name:
20-
return i
21-
raise KeyError(f"{name} not found in axes")
22-
else:
17+
if not isinstance(name, str):
2318
return name
2419

20+
for i, ax in enumerate(self):
21+
if ax.name == name:
22+
return i
23+
raise KeyError(f"{name} not found in axes")
24+
2525
def __getitem__(self, item):
2626
if isinstance(item, slice):
2727
item = slice(

0 commit comments

Comments
 (0)