Skip to content

Commit

Permalink
Merge pull request #358 from nhatnm52/write-with-more-downsample-methods
Browse files Browse the repository at this point in the history
Allow to change downsample function
  • Loading branch information
joshmoore authored Apr 24, 2024
2 parents 028e986 + 93428aa commit e33f736
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 4 deletions.
12 changes: 9 additions & 3 deletions ome_zarr/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def methods() -> Iterator[str]:

def scale(self, input_array: str, output_directory: str) -> None:
"""Perform downsampling to disk."""
func = getattr(self, self.method, None)
if not func:
raise Exception
func = self.func

store = self.__check_store(output_directory)
base = zarr.open_array(input_array)
Expand All @@ -94,6 +92,14 @@ def scale(self, input_array: str, output_directory: str) -> None:
print(f"copying attribute keys: {list(base.attrs.keys())}")
grp.attrs.update(base.attrs)

@property
def func(self) -> Callable[[np.ndarray], List[np.ndarray]]:
"""Get downsample function."""
func = getattr(self, self.method, None)
if not func:
raise Exception
return func

def __check_store(self, output_directory: str) -> MutableMapping:
"""Return a Zarr store if it doesn't already exist."""
assert not os.path.exists(output_directory)
Expand Down
2 changes: 1 addition & 1 deletion ome_zarr/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def _create_mip(
"Can't downsample if size of x or y dimension is 1. "
"Shape: %s" % (image.shape,)
)
mip = scaler.nearest(image)
mip = scaler.func(image)
else:
LOGGER.debug("disabling pyramid")
mip = [image]
Expand Down
54 changes: 54 additions & 0 deletions tests/test_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,26 @@ def test_nearest(self, shape):
downscaled = scaler.nearest(data)
self.check_downscaled(downscaled, shape)

def test_nearest_via_method(self, shape):
data = self.create_data(shape)

scaler = Scaler()
expected_downscaled = scaler.nearest(data)

scaler.method = "nearest"
downscaled = scaler.func(data)
self.check_downscaled(downscaled, shape)

assert (
np.sum(
[
not np.array_equal(downscaled[i], expected_downscaled[i])
for i in range(len(downscaled))
]
)
== 0
)

# this fails because of wrong channel dimension; need to fix in follow-up PR
@pytest.mark.xfail
def test_gaussian(self, shape):
Expand All @@ -58,6 +78,26 @@ def test_local_mean(self, shape):
downscaled = scaler.local_mean(data)
self.check_downscaled(downscaled, shape)

def test_local_mean_via_method(self, shape):
data = self.create_data(shape)

scaler = Scaler()
expected_downscaled = scaler.local_mean(data)

scaler.method = "local_mean"
downscaled = scaler.func(data)
self.check_downscaled(downscaled, shape)

assert (
np.sum(
[
not np.array_equal(downscaled[i], expected_downscaled[i])
for i in range(len(downscaled))
]
)
== 0
)

@pytest.mark.skip(reason="This test does not terminate")
def test_zoom(self, shape):
data = self.create_data(shape)
Expand All @@ -80,6 +120,20 @@ def test_scale_dask(self, shape):

assert np.array_equal(resized_data, resized_dask)

def test_scale_dask_via_method(self, shape):
data = self.create_data(shape)

chunk_size = [100, 100]
chunk_2d = (*(1,) * (data.ndim - 2), *chunk_size)
data_delayed = da.from_array(data, chunks=chunk_2d)

scaler = Scaler()
expected_downscaled = scaler.resize_image(data)

scaler.method = "resize_image"
assert np.array_equal(expected_downscaled, scaler.func(data))
assert np.array_equal(expected_downscaled, scaler.func(data_delayed))

def test_big_dask_pyramid(self, tmpdir):
# from https://github.com/ome/omero-cli-zarr/pull/134
shape = (6675, 9560)
Expand Down

0 comments on commit e33f736

Please sign in to comment.