Skip to content

Commit

Permalink
Avoid in-place multiplication of a large value to an array with small…
Browse files Browse the repository at this point in the history
… integer dtype (#8867)

* Avoid inplace multiplication

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_plot.py

* Update test_plot.py

* Update dataarray_plot.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Illviljan and pre-commit-ci[bot] authored Mar 29, 2024
1 parent ffb30a8 commit afce18f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
5 changes: 3 additions & 2 deletions xarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1848,9 +1848,10 @@ def _center_pixels(x):
# missing data transparent. We therefore add an alpha channel if
# there isn't one, and set it to transparent where data is masked.
if z.shape[-1] == 3:
alpha = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype)
safe_dtype = np.promote_types(z.dtype, np.uint8)
alpha = np.ma.ones(z.shape[:2] + (1,), dtype=safe_dtype)
if np.issubdtype(z.dtype, np.integer):
alpha *= 255
alpha[:] = 255
z = np.ma.concatenate((z, alpha), axis=2)
else:
z = z.copy()
Expand Down
12 changes: 7 additions & 5 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2040,15 +2040,17 @@ def test_normalize_rgb_one_arg_error(self) -> None:
for vmin2, vmax2 in ((-1.2, -1), (2, 2.1)):
da.plot.imshow(vmin=vmin2, vmax=vmax2)

def test_imshow_rgb_values_in_valid_range(self) -> None:
da = DataArray(np.arange(75, dtype="uint8").reshape((5, 5, 3)))
@pytest.mark.parametrize("dtype", [np.uint8, np.int8, np.int16])
def test_imshow_rgb_values_in_valid_range(self, dtype) -> None:
da = DataArray(np.arange(75, dtype=dtype).reshape((5, 5, 3)))
_, ax = plt.subplots()
out = da.plot.imshow(ax=ax).get_array()
assert out is not None
dtype = out.dtype
assert dtype is not None
assert dtype == np.uint8
actual_dtype = out.dtype
assert actual_dtype is not None
assert actual_dtype == np.uint8
assert (out[..., :3] == da.values).all() # Compare without added alpha
assert (out[..., -1] == 255).all() # Compare alpha

@pytest.mark.filterwarnings("ignore:Several dimensions of this array")
def test_regression_rgb_imshow_dim_size_one(self) -> None:
Expand Down

0 comments on commit afce18f

Please sign in to comment.