Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
gdal-version: ['3.10.0']
include:
- python-version: '3.10'
rasterio-version: '==1.3.7'
rasterio-version: ''
xarray-version: '==2024.7.0'
numpy-version: '<2'
run-with-scipy: 'YES'
Expand Down
2 changes: 2 additions & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ History

Latest
------
- BUG:merge: Revert `rasterio.io.MemoryFile` code (#850)
- DEP: pin rasterio >= 1.4.3 (#850)

0.18.2
------
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ classifiers = [
requires-python = ">=3.10"
dependencies = [
"packaging",
"rasterio>=1.3.7",
"rasterio>=1.4.3",
"xarray>=2024.7.0",
"pyproj>=3.3",
"numpy>=1.23",
Expand Down
169 changes: 97 additions & 72 deletions rioxarray/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
This module allows you to merge xarray Datasets/DataArrays
geospatially with the `rasterio.merge` module.
"""

from collections.abc import Sequence
from typing import Callable, Optional, Union

import numpy
from rasterio.crs import CRS
from rasterio.io import MemoryFile
from rasterio.merge import merge as _rio_merge
from xarray import DataArray, Dataset
from xarray import DataArray, Dataset, IndexVariable

from rioxarray._io import open_rasterio
from rioxarray.rioxarray import _get_nonspatial_coords
from rioxarray.rioxarray import _get_nonspatial_coords, _make_coords


class RasterioDatasetDuck:
Expand All @@ -31,29 +30,13 @@ def __init__(self, xds: DataArray):
self.count = int(xds.rio.count)
self.dtypes = [xds.dtype]
self.name = xds.name
if xds.rio.encoded_nodata is not None:
self.nodatavals = [xds.rio.encoded_nodata]
else:
self.nodatavals = [xds.rio.nodata]
self.nodatavals = [xds.rio.nodata]
res = xds.rio.resolution(recalc=True)
self.res = (abs(res[0]), abs(res[1]))
self.transform = xds.rio.transform(recalc=True)
self.profile: dict = {
"crs": self.crs,
"nodata": self.nodatavals[0],
}
valid_scale_factor = self._xds.encoding.get("scale_factor", 1) != 1 or any(
scale != 1 for scale in self._xds.encoding.get("scales", (1,))
)
valid_offset = self._xds.encoding.get("add_offset", 0.0) != 0 or any(
offset != 0 for offset in self._xds.encoding.get("offsets", (0,))
)
self._mask_and_scale = (
self._xds.rio.encoded_nodata is not None
or valid_scale_factor
or valid_offset
or self._xds.encoding.get("_Unsigned") is not None
)
# profile is only used for writing to a file.
# This never happens with rioxarray merge.
self.profile: dict = {}

def colormap(self, *args, **kwargs) -> None:
"""
Expand All @@ -63,21 +46,44 @@ def colormap(self, *args, **kwargs) -> None:
# pylint: disable=unused-argument
return None

def read(self, *args, **kwargs) -> numpy.ma.MaskedArray:
def read(self, window, out_shape, *args, **kwargs) -> numpy.ma.MaskedArray:
# pylint: disable=unused-argument
"""
This method is meant to be used by the rasterio.merge.merge function.
"""
with MemoryFile() as memfile:
self._xds.rio.to_raster(memfile.name)
with memfile.open() as dataset:
if self._mask_and_scale:
kwargs["masked"] = True
out = dataset.read(*args, **kwargs)
if self._mask_and_scale:
out = out.astype(self._xds.dtype)
for iii in range(self.count):
out[iii] = out[iii] * dataset.scales[iii] + dataset.offsets[iii]
return out
data_window = self._xds.rio.isel_window(window)
if data_window.shape != out_shape:
# In this branch the data covers the same geographic area
# as requested, but its dimensions differ from the requested
# shape, so it must be resampled to the requested shape.
if len(out_shape) == 3:
_, out_height, out_width = out_shape
else:
out_height, out_width = out_shape
data_window = self._xds.rio.reproject(
self._xds.rio.crs,
transform=self.transform,
shape=(out_height, out_width),
)

nodata = self.nodatavals[0]
mask = False
fill_value = None
if nodata is not None and numpy.isnan(nodata):
mask = numpy.isnan(data_window)
elif nodata is not None:
mask = data_window == nodata
fill_value = nodata

# Make sure the returned shape matches the expected shape.
# A mismatch can occur when the xarray object was
# squeezed to 2D beforehand.
if len(out_shape) == 3 and len(data_window.shape) == 2:
data_window = data_window.values.reshape((1, out_height, out_width))

return numpy.ma.array(
data_window, mask=mask, fill_value=fill_value, dtype=self.dtypes[0]
)


def merge_arrays(
Expand Down Expand Up @@ -155,47 +161,66 @@ def merge_arrays(
rioduckarrays.append(RasterioDatasetDuck(dataarray))

# use rasterio to merge
merged_data, merged_transform = _rio_merge(
rioduckarrays,
**{key: val for key, val in input_kwargs.items() if val is not None},
)
# generate merged data array
representative_array = rioduckarrays[0]._xds
with MemoryFile() as memfile:
_rio_merge(
rioduckarrays,
**{key: val for key, val in input_kwargs.items() if val is not None},
dst_path=memfile.name,
if parse_coordinates:
coords = _make_coords(
src_data_array=representative_array,
dst_affine=merged_transform,
dst_width=merged_data.shape[-1],
dst_height=merged_data.shape[-2],
)
with open_rasterio( # type: ignore
memfile.name,
parse_coordinates=parse_coordinates,
mask_and_scale=rioduckarrays[0]._mask_and_scale,
) as merged_data:
merged_data = merged_data.load()

# make sure old & new coordinate names match & dimensions are correct
rename_map = {}
original_extra_dim = representative_array.rio._check_dimensions()
new_extra_dim = merged_data.rio._check_dimensions()
# make sure the output merged data shape is 2D if the
# original data was 2D. this can happen if the
# xarray DataArray was squeezed.
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
merged_data = merged_data.squeeze(
dim=new_extra_dim, drop=original_extra_dim is None
if (
representative_array.rio.x_dim != "x"
and "x" in coords
and coords["x"].ndim == 1
):
coords[representative_array.rio.x_dim] = IndexVariable(
representative_array.rio.x_dim, coords.pop("x")
)
new_extra_dim = merged_data.rio._check_dimensions()
if (
original_extra_dim is not None
and new_extra_dim is not None
and original_extra_dim != new_extra_dim
representative_array.rio.y_dim != "y"
and "y" in coords
and coords["y"].ndim == 1
):
rename_map[new_extra_dim] = original_extra_dim
if representative_array.rio.x_dim != merged_data.rio.x_dim:
rename_map[merged_data.rio.x_dim] = representative_array.rio.x_dim
if representative_array.rio.y_dim != merged_data.rio.y_dim:
rename_map[merged_data.rio.y_dim] = representative_array.rio.y_dim
if rename_map:
merged_data = merged_data.rename(rename_map)
merged_data.coords.update(_get_nonspatial_coords(representative_array))
return merged_data # type: ignore
coords[representative_array.rio.y_dim] = IndexVariable(
representative_array.rio.y_dim, coords.pop("y")
)
else:
coords = _get_nonspatial_coords(representative_array)

# make sure the output merged data shape is 2D if the
# original data was 2D. this can happen if the
# xarray DataArray was squeezed.
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
merged_data = merged_data.squeeze()

xda = DataArray(
name=representative_array.name,
data=merged_data,
coords=coords,
dims=tuple(representative_array.dims),
attrs=representative_array.attrs,
)
xda.encoding = representative_array.encoding.copy()
xda.rio.write_nodata(
nodata if nodata is not None else representative_array.rio.nodata, inplace=True
)
xda.rio.write_crs(
representative_array.rio.crs,
grid_mapping_name=representative_array.rio.grid_mapping,
inplace=True,
)
xda.rio.write_transform(
merged_transform,
grid_mapping_name=representative_array.rio.grid_mapping,
inplace=True,
)
return xda


def merge_datasets(
Expand All @@ -218,7 +243,7 @@ def merge_datasets(
Parameters
----------
datasets: list[xarray.Dataset]
List of multiple xarray.Dataset with all geo attributes.
List of xarray.Dataset's with all geo attributes.
The first one is assumed to have the same
CRS, dtype, dimensions, and data_vars as the others in the array.
bounds: tuple, optional
Expand Down
2 changes: 0 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "test_data")
TEST_INPUT_DATA_DIR = os.path.join(TEST_DATA_DIR, "input")
TEST_COMPARE_DATA_DIR = os.path.join(TEST_DATA_DIR, "compare")
RASTERIO_GE_14 = version.parse(rasterio.__version__) >= version.parse("1.4.0")
RASTERIO_GE_143 = version.parse(rasterio.__version__) >= version.parse("1.4.3")
GDAL_GE_36 = version.parse(rasterio.__gdal_version__) >= version.parse("3.6.0")
GDAL_GE_361 = version.parse(rasterio.__gdal_version__) >= version.parse("3.6.1")
GDAL_GE_364 = version.parse(rasterio.__gdal_version__) >= version.parse("3.6.4")
Expand Down
19 changes: 4 additions & 15 deletions test/integration/test_integration_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from rioxarray import open_rasterio
from rioxarray.merge import merge_arrays, merge_datasets
from test.conftest import RASTERIO_GE_14, RASTERIO_GE_143, TEST_INPUT_DATA_DIR
from test.conftest import TEST_INPUT_DATA_DIR


@pytest.mark.parametrize("squeeze", [True, False])
Expand Down Expand Up @@ -52,12 +52,7 @@ def test_merge_arrays(squeeze):
assert sorted(merged.coords) == sorted(rds.coords)
assert merged.coords["band"].values == [1]
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == {
"AREA_OR_POINT": "Area",
"add_offset": 0.0,
"scale_factor": 1.0,
**rds.attrs,
}
assert merged.attrs == rds.attrs
assert merged.encoding["grid_mapping"] == "spatial_ref"


Expand Down Expand Up @@ -90,10 +85,7 @@ def test_merge__different_crs(dataset):
(-7300984.0238134, 5003618.5908794, -7224054.1109682, 5050108.6101528),
)
assert merged.rio.shape == (84, 139)
if RASTERIO_GE_14 and not RASTERIO_GE_143:
assert_almost_equal(test_sum, -126821853)
else:
assert_almost_equal(test_sum, -131734881)
assert_almost_equal(test_sum, -131734881)

assert_almost_equal(
tuple(merged.rio.transform()),
Expand All @@ -113,10 +105,7 @@ def test_merge__different_crs(dataset):
assert merged.rio.crs == rds.rio.crs
if not dataset:
assert merged.attrs == {
"AREA_OR_POINT": "Area",
"_FillValue": -28672,
"add_offset": 0.0,
"scale_factor": 1.0,
}
assert merged.encoding["grid_mapping"] == "spatial_ref"

Expand Down Expand Up @@ -284,7 +273,7 @@ def test_merge_datasets__mask_and_scale(mask_and_scale):
rds.isel(x=slice(100, None), y=slice(100)),
]
merged = merge_datasets(datasets)
assert sorted(merged.coords) == sorted(list(rds.coords) + ["spatial_ref"])
assert sorted(merged.coords) == sorted(list(rds.coords))
total = merged.air_temperature.sum()
if mask_and_scale:
assert_almost_equal(total, 133376696)
Expand Down
17 changes: 3 additions & 14 deletions test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from rioxarray.rioxarray import _generate_spatial_coords, _make_coords
from test.conftest import (
GDAL_GE_361,
RASTERIO_GE_14,
TEST_COMPARE_DATA_DIR,
TEST_INPUT_DATA_DIR,
_assert_xarrays_equal,
Expand Down Expand Up @@ -2225,12 +2224,9 @@ def test_reproject_transform_missing_shape():
"dtype, expected_nodata",
[
(numpy.uint8, 255),
pytest.param(
(
numpy.int8,
-128,
marks=pytest.mark.xfail(
not RASTERIO_GE_14, reason="Not worried about it if it works on latest."
),
),
(numpy.uint16, 65535),
(numpy.int16, -32768),
Expand All @@ -2240,19 +2236,13 @@ def test_reproject_transform_missing_shape():
(numpy.float64, numpy.nan),
(numpy.complex64, numpy.nan),
(numpy.complex128, numpy.nan),
pytest.param(
(
numpy.uint64,
18446744073709551615,
marks=pytest.mark.xfail(
not RASTERIO_GE_14, reason="Not worried about it if it works on latest."
),
),
pytest.param(
(
numpy.int64,
-9223372036854775808,
marks=pytest.mark.xfail(
not RASTERIO_GE_14, reason="Not worried about it if it works on latest."
),
),
],
)
Expand Down Expand Up @@ -3225,7 +3215,6 @@ def test_bounds__ordered__dataset():
assert xds.rio.bounds() == (-0.5, -0.5, 4.5, 4.5)


@pytest.mark.skipif(not RASTERIO_GE_14, reason="Requires rasterio 1.4+")
@pytest.mark.parametrize(
"rename",
[
Expand Down
Loading