Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 19 additions & 35 deletions rioxarray/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy
from rasterio.crs import CRS
from rasterio.io import MemoryFile
from rasterio.merge import merge as _rio_merge
from xarray import DataArray, Dataset

Expand Down Expand Up @@ -53,49 +54,32 @@ def colormap(self, *args, **kwargs):
See: https://github.com/corteva/rioxarray/issues/479
"""
try:
rio_file = self.xds.rio._manager.acquire()
rio_file = self._xds.rio._manager.acquire()
return rio_file.colormap(*args, **kwargs)
except AttributeError:
return None

def read(self, window, out_shape, *args, **kwargs) -> numpy.ma.MaskedArray:
# pylint: disable=unused-argument
def read(self, *args, **kwargs) -> numpy.ma.MaskedArray:
"""
This method is meant to be used by the rasterio.merge.merge function.
"""
data_window = self._xds.rio.isel_window(window)
if data_window.shape != out_shape:
# in this section, the data is geographically the same
# however it is not the same dimensions as requested
# so need to resample to the requested shape
if len(out_shape) == 3:
_, out_height, out_width = out_shape
else:
out_height, out_width = out_shape
data_window = self._xds.rio.reproject(
self._xds.rio.crs,
with MemoryFile() as memfile:
with memfile.open(
driver="GTiff",
height=int(self._xds.rio.height),
width=int(self._xds.rio.width),
count=self.count,
dtype=self.dtypes[0],
crs=self.crs,
transform=self.transform,
shape=(out_height, out_width),
)

nodata = self.nodatavals[0]
mask = False
fill_value = None
if nodata is not None and numpy.isnan(nodata):
mask = numpy.isnan(data_window)
elif nodata is not None:
mask = data_window == nodata
fill_value = nodata

# make sure the returned shape matches
# the expected shape. This can be the case
# when the xarray dataset was squeezed to 2D beforehand
if len(out_shape) == 3 and len(data_window.shape) == 2:
data_window = data_window.values.reshape((1, out_height, out_width))

return numpy.ma.array(
data_window, mask=mask, fill_value=fill_value, dtype=self.dtypes[0]
)
nodata=self.nodatavals[0],
) as dataset:
data = self._xds.values
if data.ndim == 2:
dataset.write(data, 1)
else:
dataset.write(data)
return dataset.read(*args, **kwargs)


def merge_arrays(
Expand Down