diff --git a/rioxarray/merge.py b/rioxarray/merge.py index 13d8899b..0f18d079 100644 --- a/rioxarray/merge.py +++ b/rioxarray/merge.py @@ -8,6 +8,7 @@ import numpy from rasterio.crs import CRS +from rasterio.io import MemoryFile from rasterio.merge import merge as _rio_merge from xarray import DataArray, Dataset @@ -53,49 +54,32 @@ def colormap(self, *args, **kwargs): See: https://github.com/corteva/rioxarray/issues/479 """ try: - rio_file = self.xds.rio._manager.acquire() + rio_file = self._xds.rio._manager.acquire() return rio_file.colormap(*args, **kwargs) except AttributeError: return None - def read(self, window, out_shape, *args, **kwargs) -> numpy.ma.MaskedArray: - # pylint: disable=unused-argument + def read(self, *args, **kwargs) -> numpy.ma.MaskedArray: """ This method is meant to be used by the rasterio.merge.merge function. """ - data_window = self._xds.rio.isel_window(window) - if data_window.shape != out_shape: - # in this section, the data is geographically the same - # however it is not the same dimensions as requested - # so need to resample to the requested shape - if len(out_shape) == 3: - _, out_height, out_width = out_shape - else: - out_height, out_width = out_shape - data_window = self._xds.rio.reproject( - self._xds.rio.crs, + with MemoryFile() as memfile: + with memfile.open( + driver="GTiff", + height=int(self._xds.rio.height), + width=int(self._xds.rio.width), + count=self.count, + dtype=self.dtypes[0], + crs=self.crs, transform=self.transform, - shape=(out_height, out_width), - ) - - nodata = self.nodatavals[0] - mask = False - fill_value = None - if nodata is not None and numpy.isnan(nodata): - mask = numpy.isnan(data_window) - elif nodata is not None: - mask = data_window == nodata - fill_value = nodata - - # make sure the returned shape matches - # the expected shape. This can be the case - # when the xarray dataset was squeezed to 2D beforehand - if len(out_shape) == 3 and len(data_window.shape) == 2: - data_window = data_window.values.reshape((1, out_height, out_width)) - - return numpy.ma.array( - data_window, mask=mask, fill_value=fill_value, dtype=self.dtypes[0] - ) + nodata=self.nodatavals[0], + ) as dataset: + data = self._xds.values + if data.ndim == 2: + dataset.write(data, 1) + else: + dataset.write(data) + return dataset.read(*args, **kwargs) def merge_arrays(