diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index fd7c86da6..cb9428bbd 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -32,17 +32,17 @@ jobs: echo "IMAGE_SPEC=${IMAGE_SPEC}" >> $GITHUB_ENV echo "DATE_TAG=${DATE_TAG}" >> $GITHUB_ENV - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 + uses: docker/setup-buildx-action@v3 - name: Login to DockerHub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - name: Build and push id: docker_build_push - uses: docker/build-push-action@v4 + uses: docker/build-push-action@v5 with: context: ./ file: ./.ci_helpers/docker/${{ matrix.image_name }}.dockerfile diff --git a/echopype/mask/__init__.py b/echopype/mask/__init__.py index 5867b1656..a287efc6e 100644 --- a/echopype/mask/__init__.py +++ b/echopype/mask/__init__.py @@ -1,3 +1,4 @@ -from .api import apply_mask, frequency_differencing, get_seabed_mask, shoal_weill +from .api import apply_mask, frequency_differencing, get_seabed_mask, get_shoal_mask + +__all__ = ["frequency_differencing", "apply_mask", "get_seabed_mask", "get_shoal_mask"] -__all__ = ["frequency_differencing", "apply_mask", "get_seabed_mask", "shoal_weill"] diff --git a/echopype/mask/api.py b/echopype/mask/api.py index 26d2c2fc5..3ca274218 100644 --- a/echopype/mask/api.py +++ b/echopype/mask/api.py @@ -10,9 +10,10 @@ from ..utils.io import get_dataset, validate_source_ds_da from ..utils.misc import frequency_nominal_to_channel from ..utils.prov import add_processing_level, echopype_prov_attrs, insert_input_processing_level -from . import seabed + +from . 
import shoal, seabed from .freq_diff import _check_freq_diff_source_Sv, _parse_freq_diff_eq -from .shoal import _weill as shoal_weill + # lookup table with key string operator and value as corresponding Python operator str2ops = { @@ -552,7 +553,114 @@ def create_multichannel_mask(masks: [xr.Dataset], channels: [str]) -> xr.Dataset ) return result + +def get_shoal_mask( + source_Sv: Union[xr.Dataset, str, pathlib.Path], + parameters: dict, + desired_channel: str = None, + desired_frequency: int = None, + method: str = "will", + **kwargs, +): + """ + Wrapper function for (future) multiple shoal masking algorithms + (currently, only MOVIES-B (Weill) is implemented) + Args: + source_Sv: xr.Dataset or str or pathlib.Path + If a Dataset this value contains the Sv data to create a mask for, + else it specifies the path to a zarr or netcdf file containing + a Dataset. This input must correspond to a Dataset that has the + coordinate ``channel`` and variables ``frequency_nominal`` and ``Sv``. + parameters: dict of algorithm parameters, passed unchanged to the selected method + desired_channel: str specifying the channel to generate the mask on + desired_frequency: int nominal frequency, used only if desired_channel is None + method: string specifying the algorithm to use + currently, 'will' (the Weill algorithm) is the only one implemented + + Returns + ------- + mask: xr.DataArray + A DataArray containing the mask for the Sv data. Regions satisfying the thresholding + criteria are filled with ``True``, else the regions are filled with ``False``. + mask_: xr.DataArray + A DataArray containing the mask for areas in which shoals were searched. 
+ Edge regions are filled with 'False', whereas the portion + in which shoals could be detected is 'True' + + + Raises + ------ + ValueError + If ``method`` is unsupported, or if neither channel nor frequency is given + """ + source_Sv = get_dataset(source_Sv) + mask_map = { + "will": shoal._weill, + } + + if method not in mask_map: + raise ValueError(f"Unsupported method: {method}") + if desired_channel is None: + if desired_frequency is None: + raise ValueError("Must specify either desired channel or desired frequency") + else: + desired_channel = frequency_nominal_to_channel(source_Sv, desired_frequency) + mask, mask_ = mask_map[method](source_Sv, desired_channel, parameters) + return mask, mask_ + + +def get_shoal_mask_multichannel( + source_Sv: Union[xr.Dataset, str, pathlib.Path], + parameters: dict, + method: str = "will", +): + """ + Wrapper function for (future) multiple shoal masking algorithms + (currently, only MOVIES-B (Weill) is implemented) + + Args: + source_Sv: xr.Dataset or str or pathlib.Path + If a Dataset this value contains the Sv data to create a mask for, + else it specifies the path to a zarr or netcdf file containing + a Dataset. This input must correspond to a Dataset that has the + coordinate ``channel`` and variables ``frequency_nominal`` and ``Sv``. + parameters: dict of algorithm parameters, passed unchanged to get_shoal_mask + method: string specifying the algorithm to use + currently, 'will' (the Weill algorithm) is the only one implemented + + Returns + ------- + mask: xr.DataArray + A DataArray containing the multichannel mask for the Sv data. + Regions satisfying the thresholding criteria are filled with ``True``, + else the regions are filled with ``False``. + mask_: xr.DataArray + A DataArray containing the multichannel mask for areas in which shoals were searched. 
+ Edge regions are filled with 'False', whereas the portion + in which shoals could be detected is 'True' + + + Raises + ------ + ValueError + If ``method`` is not a supported algorithm name + """ + source_Sv = get_dataset(source_Sv) + channel_list = source_Sv["channel"].values + mask_list = [] + _mask_list = [] + for channel in channel_list: + mask, _mask = get_shoal_mask( + source_Sv, desired_channel=channel, method=method, parameters=parameters + ) + mask_list.append(mask) + _mask_list.append(_mask) + mask = create_multichannel_mask(mask_list, channel_list) + _mask = create_multichannel_mask(_mask_list, channel_list) + return mask, _mask + + def get_seabed_mask( source_Sv: Union[xr.Dataset, str, pathlib.Path], parameters: dict, @@ -666,108 +774,4 @@ def get_seabed_mask_multichannel( mask_list.append(mask) mask = create_multichannel_mask(mask_list, channel_list) return mask - - -def get_shoal_mask( - source_Sv: Union[xr.Dataset, str, pathlib.Path], - desired_channel: str, - mask_type: str = "will", - **kwargs, -): - """ - Wrapper function for (future) multiple shoal masking algorithms - (currently, only MOVIES-B (Will) is implemented) - - Args: - source_Sv: xr.Dataset or str or pathlib.Path - If a Dataset this value contains the Sv data to create a mask for, - else it specifies the path to a zarr or netcdf file containing - a Dataset. This input must correspond to a Dataset that has the - coordinate ``channel`` and variables ``frequency_nominal`` and ``Sv``. - desired_channel: str specifying the channel to generate the mask on - mask_type: string specifying the algorithm to use - currently, 'weill' is the only one implemented - - Returns - ------- - mask: xr.DataArray - A DataArray containing the mask for the Sv data. Regions satisfying the thresholding - criteria are filled with ``True``, else the regions are filled with ``False``. - mask_: xr.DataArray - A DataArray containing the mask for areas in which shoals were searched. 
- Edge regions are filled with 'False', whereas the portion - in which shoals could be detected is 'True' - - - Raises - ------ - ValueError - If 'weill' is not given - """ - assert mask_type in ["will"] - if mask_type == "will": - # Define a list of the keyword arguments your function can handle - valid_args = {"thr", "maxvgap", "maxhgap", "minvlen", "minhlen"} - # Filter out any kwargs not in your list - filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_args} - mask, mask_ = shoal_weill(source_Sv, desired_channel, **filtered_kwargs) - else: - raise ValueError("The provided mask type must be Will") - return_mask = xr.DataArray( - mask, - dims=("ping_time", "range_sample"), - coords={"ping_time": source_Sv.ping_time, "range_sample": source_Sv.range_sample}, - ) - return_mask_ = xr.DataArray( - mask_, - dims=("ping_time", "range_sample"), - coords={"ping_time": source_Sv.ping_time, "range_sample": source_Sv.range_sample}, - ) - return return_mask, return_mask_ - - -def get_shoal_mask_multichannel( - source_Sv: Union[xr.Dataset, str, pathlib.Path], - mask_type: str = "will", - **kwargs, -): - """ - Wrapper function for (future) multiple shoal masking algorithms - (currently, only MOVIES-B (Will) is implemented) - - Args: - source_Sv: xr.Dataset or str or pathlib.Path - If a Dataset this value contains the Sv data to create a mask for, - else it specifies the path to a zarr or netcdf file containing - a Dataset. This input must correspond to a Dataset that has the - coordinate ``channel`` and variables ``frequency_nominal`` and ``Sv``. - mask_type: string specifying the algorithm to use - currently, 'weill' is the only one implemented - - Returns - ------- - mask: xr.DataArray - A DataArray containing the multichannel mask for the Sv data. - Regions satisfying the thresholding criteria are filled with ``True``, - else the regions are filled with ``False``. 
- mask_: xr.DataArray - A DataArray containing the multichannel mask for areas in which shoals were searched. - Edge regions are filled with 'False', whereas the portion - in which shoals could be detected is 'True' - - - Raises - ------ - ValueError - If 'weill' is not given - """ - channel_list = source_Sv["channel"].values - mask_list = [] - _mask_list = [] - for channel in channel_list: - mask, _mask = get_shoal_mask(source_Sv, channel, mask_type, **kwargs) - mask_list.append(mask) - _mask_list.append(_mask) - mask = create_multichannel_mask(mask_list, channel_list) - _mask = create_multichannel_mask(_mask_list, channel_list) - return mask, _mask + diff --git a/echopype/mask/shoal.py b/echopype/mask/shoal.py index d51d58215..5a0028a6d 100644 --- a/echopype/mask/shoal.py +++ b/echopype/mask/shoal.py @@ -35,15 +35,13 @@ import scipy.ndimage as nd_img import xarray as xr +WEILL_DEFAULT_PARAMETERS = {"thr": -70, "maxvgap": 5, "maxhgap": 0, "minvlen": 0, "minhlen": 0} + def _weill( source_Sv: Union[xr.Dataset, str, pathlib.Path], desired_channel: str, - thr=-70, - maxvgap=5, - maxhgap=0, - minvlen=0, - minhlen=0, + parameters: dict = WEILL_DEFAULT_PARAMETERS, ): """ Detects and masks shoals following the algorithm described in: @@ -77,16 +75,17 @@ def _weill( a Dataset. This input must correspond to a Dataset that has the coordinate ``channel`` and variables ``frequency_nominal`` and ``Sv``. desired_channel (str): channel to generate the mask on - thr (int): Sv threshold (dB). - maxvgap (int): maximum vertical gap allowed (n samples). - maxhgap (int): maximum horizontal gap allowed (n pings). - minvlen (int): minimum vertical length for a shoal to be eligible - (n samples). - minhlen (int): minimum horizontal length for a shoal to be eligible - (n pings). - start (int): ping index to start processing. If greater than zero, it - means that Sv carries data from a preceding file and - the algorithm needs to know where to start processing. 
+ parameters (dict): containing the required parameters + thr (int): Sv threshold (dB). + maxvgap (int): maximum vertical gap allowed (n samples). + maxhgap (int): maximum horizontal gap allowed (n pings). + minvlen (int): minimum vertical length for a shoal to be eligible + (n samples). + minhlen (int): minimum horizontal length for a shoal to be eligible + (n pings). + start (int): ping index to start processing. If greater than zero, it + means that Sv carries data from a preceding file and + the algorithm needs to know where to start processing. Returns ------- @@ -98,7 +97,20 @@ def _weill( Edge regions are filled with 'False', whereas the portion in which shoals could be detected is 'True' """ - # Sv = source_Sv["Sv"].values[0] + parameter_names = ["thr", "maxvgap", "maxhgap", "minvlen", "minhlen"] + if not all(name in parameters.keys() for name in parameter_names): + raise ValueError( + "Missing parameters - should be: " + + str(parameter_names) + + ", are: " + + str(parameters.keys()) + ) + thr = parameters["thr"] + maxvgap = parameters["maxvgap"] + maxhgap = parameters["maxhgap"] + minvlen = parameters["minvlen"] + minhlen = parameters["minhlen"] + channel_Sv = source_Sv.sel(channel=desired_channel) Sv = channel_Sv["Sv"].values @@ -165,5 +177,14 @@ def _weill( mask_ = np.zeros_like(mask, dtype=bool) mask_[minvlen : len(mask_) - minvlen, minhlen : len(mask_[0]) - minhlen] = True - # return masks, from the start ping onwards - return mask, mask_ + return_mask = xr.DataArray( + mask, + dims=("ping_time", "range_sample"), + coords={"ping_time": source_Sv.ping_time, "range_sample": source_Sv.range_sample}, + ) + return_mask_ = xr.DataArray( + mask_, + dims=("ping_time", "range_sample"), + coords={"ping_time": source_Sv.ping_time, "range_sample": source_Sv.range_sample}, + ) + return return_mask, return_mask_ \ No newline at end of file diff --git a/echopype/tests/conftest.py b/echopype/tests/conftest.py index 740448a20..fdfce92c1 100644 --- 
a/echopype/tests/conftest.py +++ b/echopype/tests/conftest.py @@ -129,4 +129,3 @@ def complete_dataset_jr179(setup_test_data_jr179): def raw_dataset_jr179(setup_test_data_jr179): ed = _get_raw_dataset(setup_test_data_jr179) return ed - diff --git a/echopype/tests/mask/test_mask.py b/echopype/tests/mask/test_mask.py index 8c5552e39..3b3d92eff 100644 --- a/echopype/tests/mask/test_mask.py +++ b/echopype/tests/mask/test_mask.py @@ -1017,7 +1017,9 @@ def test_channel_mask(var_name="var2"): def test_shoal_mask_all(sv_dataset_jr161): source_Sv = sv_dataset_jr161 - ml, _ml = echopype.mask.api.get_shoal_mask_multichannel(source_Sv) + ml, _ml = echopype.mask.api.get_shoal_mask_multichannel( + source_Sv, method="will", parameters=ep.mask.shoal.WEILL_DEFAULT_PARAMETERS + ) assert np.all(ml["channel"] == source_Sv["channel"]) assert np.all(_ml["channel"] == source_Sv["channel"]) @@ -1028,3 +1030,4 @@ def test_seabed_mask_all(complete_dataset_jr179): source_Sv, method="ariza", parameters=ep.mask.seabed.ARIZA_DEFAULT_PARAMS ) assert np.all(ml["channel"] == source_Sv["channel"]) + diff --git a/echopype/tests/mask/test_mask_shoal.py b/echopype/tests/mask/test_mask_shoal.py index 83ea140a0..2eb45cc31 100644 --- a/echopype/tests/mask/test_mask_shoal.py +++ b/echopype/tests/mask/test_mask_shoal.py @@ -1,19 +1,26 @@ import numpy as np import pytest -from echopype.mask.api import shoal_weill + +from echopype.mask.api import get_shoal_mask +from echopype.mask.shoal import WEILL_DEFAULT_PARAMETERS DESIRED_CHANNEL = "GPT 38 kHz 009072033fa5 1 ES38" @pytest.mark.parametrize( - "desired_channel,expected_tf_counts,expected_tf_counts_", - [(DESIRED_CHANNEL, (186650, 1980281), (2166931, 0))], + "method, desired_channel,parameters,expected_tf_counts,expected_tf_counts_", + [("will", DESIRED_CHANNEL, WEILL_DEFAULT_PARAMETERS, (186650, 1980281), (2166931, 0))], ) def test_get_shoal_mask_weill( - sv_dataset_jr161, desired_channel, expected_tf_counts, expected_tf_counts_ + sv_dataset_jr161, method, 
desired_channel, parameters, expected_tf_counts, expected_tf_counts_ ): source_Sv = sv_dataset_jr161 - mask, mask_ = shoal_weill(source_Sv, desired_channel) + mask, mask_ = get_shoal_mask( + source_Sv, + method=method, + desired_channel=desired_channel, + parameters=parameters, + ) count_true = np.count_nonzero(mask) count_false = mask.size - count_true @@ -23,4 +30,4 @@ def test_get_shoal_mask_weill( count_true_ = np.count_nonzero(mask_) count_false_ = mask.size - count_true_ true_false_counts_ = (count_true_, count_false_) - assert true_false_counts_ == expected_tf_counts_ + assert true_false_counts_ == expected_tf_counts_ \ No newline at end of file