diff --git a/echopype/mask/api.py b/echopype/mask/api.py index cd385fff9..8440603d2 100644 --- a/echopype/mask/api.py +++ b/echopype/mask/api.py @@ -8,6 +8,7 @@ from ..utils.io import validate_source_ds_da from ..utils.prov import add_processing_level, echopype_prov_attrs, insert_input_processing_level +from .freq_diff import _check_freq_diff_source_Sv, _parse_freq_diff_eq # lookup table with key string operator and value as corresponding Python operator str2ops = { @@ -366,128 +367,11 @@ def get_ch_shape(da): return output_ds -def _check_freq_diff_non_data_inputs( - freqAB: Optional[List[float]] = None, - chanAB: Optional[List[str]] = None, - operator: str = ">", - diff: Union[float, int] = None, -) -> None: - """ - Checks that the non-data related inputs of ``frequency_differencing`` (i.e. ``freqAB``, - ``chanAB``, ``operator``, ``diff``) were correctly provided. - - Parameters - ---------- - freqAB: list of float, optional - The pair of nominal frequencies to be used for frequency-differencing, where - the first element corresponds to ``freqA`` and the second element corresponds - to ``freqB`` - chanAB: list of float, optional - The pair of channels that will be used to select the nominal frequencies to be - used for frequency-differencing, where the first element corresponds to ``freqA`` - and the second element corresponds to ``freqB`` - operator: {">", "<", "<=", ">=", "=="} - The operator for the frequency-differencing - diff: float or int - The threshold of Sv difference between frequencies - """ - - # check that either freqAB or chanAB are provided and they are a list of length 2 - if (freqAB is None) and (chanAB is None): - raise ValueError("Either freqAB or chanAB must be given!") - elif (freqAB is not None) and (chanAB is not None): - raise ValueError("Only freqAB or chanAB must be given, but not both!") - elif freqAB is not None: - if not isinstance(freqAB, list): - raise TypeError("freqAB must be a list!") - elif len(set(freqAB)) != 2: - raise ValueError("freqAB must be a list of length 2 with unique elements!") - else: - if not isinstance(chanAB, list): - raise TypeError("chanAB must be a list!") - elif len(set(chanAB)) != 2: - raise ValueError("chanAB must be a list of length 2 with unique elements!") - - # check that operator is a string and a valid operator - if not isinstance(operator, str): - raise TypeError("operator must be a string!") - else: - if operator not in [">", "<", "<=", ">=", "=="]: - raise ValueError("Invalid operator!") - - # ensure that diff is a float or an int - if not isinstance(diff, (float, int)): - raise TypeError("diff must be a float or int!") - - -def _check_source_Sv_freq_diff( - source_Sv: xr.Dataset, - freqAB: Optional[List[float]] = None, - chanAB: Optional[List[str]] = None, -) -> None: - """ - Ensures that ``source_Sv`` contains ``channel`` as a coordinate and - ``frequency_nominal`` as a variable, the provided list input - (``freqAB`` or ``chanAB``) are contained in the coordinate ``channel`` - or variable ``frequency_nominal``, and ``source_Sv`` does not have - repeated values for ``channel`` and ``frequency_nominal``. - - Parameters - ---------- - source_Sv: xr.Dataset - A Dataset that contains the Sv data to create a mask for - freqAB: list of float, optional - The pair of nominal frequencies to be used for frequency-differencing, where - the first element corresponds to ``freqA`` and the second element corresponds - to ``freqB`` - chanAB: list of float, optional - The pair of channels that will be used to select the nominal frequencies to be - used for frequency-differencing, where the first element corresponds to ``freqA`` - and the second element corresponds to ``freqB`` - """ - - # check that channel and frequency nominal are in source_Sv - if "channel" not in source_Sv.coords: - raise ValueError("The Dataset defined by source_Sv must have channel as a coordinate!") - elif "frequency_nominal" not in source_Sv.variables: - raise ValueError( - "The Dataset defined by source_Sv must have frequency_nominal as a variable!" - ) - - # make sure that the channel and frequency_nominal values are not repeated in source_Sv - if len(set(source_Sv.channel.values)) < source_Sv.channel.size: - raise ValueError( - "The provided source_Sv contains repeated channel values, this is not allowed!" - ) - - if len(set(source_Sv.frequency_nominal.values)) < source_Sv.frequency_nominal.size: - raise ValueError( - "The provided source_Sv contains repeated frequency_nominal " - "values, this is not allowed!" - ) - - # check that the elements of freqAB are in frequency_nominal - if (freqAB is not None) and (not all([freq in source_Sv.frequency_nominal for freq in freqAB])): - raise ValueError( - "The provided list input freqAB contains values that " - "are not in the frequency_nominal variable!" - ) - - # check that the elements of chanAB are in channel - if (chanAB is not None) and (not all([chan in source_Sv.channel for chan in chanAB])): - raise ValueError( - "The provided list input chanAB contains values that are " - "not in the channel coordinate!" - ) - - def frequency_differencing( source_Sv: Union[xr.Dataset, str, pathlib.Path], storage_options: Optional[dict] = {}, - freqAB: Optional[List[float]] = None, - chanAB: Optional[List[str]] = None, - operator: str = ">", - diff: Union[float, int] = None, + freqABEq: Optional[str] = None, + chanABEq: Optional[str] = None, ) -> xr.DataArray: """ Create a mask based on the differences of Sv values using a pair of @@ -504,19 +388,13 @@ def frequency_differencing( storage_options: dict, optional Any additional parameters for the storage backend, corresponding to the path provided for ``source_Sv`` - freqAB: list of float, optional - The pair of nominal frequencies to be used for frequency-differencing, where - the first element corresponds to ``freqA`` and the second element corresponds - to ``freqB``. Only one of ``freqAB`` and ``chanAB`` should be provided, and not both. - chanAB: list of strings, optional - The pair of channels that will be used to select the nominal frequencies to be - used for frequency-differencing, where the first element corresponds to ``freqA`` - and the second element corresponds to ``freqB``. Only one of ``freqAB`` and ``chanAB`` + freqABEq: string, optional + The frequency differencing criteria. + Only one of ``freqAB`` and ``chanAB`` should be provided, and not both. + chanAB: string, optional + The frequency differencing criteria in terms of channel names where channel names + in the criteria are enclosed in double quotes. Only one of ``freqAB`` and ``chanAB`` should be provided, and not both. - operator: {">", "<", "<=", ">=", "=="} - The operator for the frequency-differencing - diff: float or int - The threshold of Sv difference between frequencies Returns ------- @@ -527,24 +405,24 @@ def frequency_differencing( Raises ------ ValueError - If neither ``freqAB`` or ``chanAB`` are given + If neither ``freqABEq`` or ``chanABEq`` are given ValueError - If both ``freqAB`` and ``chanAB`` are given + If both ``freqABEq`` and ``chanABEq`` are given TypeError If any input is not of the correct type ValueError - If either ``freqAB`` or ``chanAB`` are provided and the list - does not contain 2 distinct elements + If either ``freqABEq`` or ``chanABEq`` are provided and the extracted + ``freqAB`` or ``chanAB`` does not contain 2 distinct elements ValueError - If ``freqAB`` contains values that are not contained in ``frequency_nominal`` + If ``freqABEq`` contains values that are not contained in ``frequency_nominal`` ValueError - If ``chanAB`` contains values that not contained in ``channel`` + If ``chanABEq`` contains values that not contained in ``channel`` ValueError If ``operator`` is not one of the following: ``">", "<", "<=", ">=", "=="`` ValueError If the path provided for ``source_Sv`` is not a valid path ValueError - If ``freqAB`` or ``chanAB`` is provided and the Dataset produced by ``source_Sv`` + If ``freqABEq`` or ``chanABEq`` is provided and the Dataset produced by ``source_Sv`` does not contain the coordinate ``channel`` and variable ``frequency_nominal`` Notes @@ -573,9 +451,8 @@ def frequency_differencing( >>> Sv_ds = xr.Dataset(data_vars={"Sv": Sv_da, "frequency_nominal": freq_nom}) ... >>> # compute frequency-differencing mask using channel names - >>> echopype.mask.frequency_differencing(source_Sv=mock_Sv_ds, storage_options={}, freqAB=None, - ... chanAB = ['chan1', 'chan2'], - ... operator = ">=", diff=10.0) + >>> echopype.mask.frequency_differencing(source_Sv=mock_Sv_ds, storage_options={}, + ... freqABEq=None, chanABEq = '"chan1" - "chan2">=10.0') array([[False, False, False, False, False], [False, False, False, False, False], @@ -588,7 +465,8 @@ def frequency_differencing( """ # check that non-data related inputs were correctly provided - _check_freq_diff_non_data_inputs(freqAB, chanAB, operator, diff) + # _check_freq_diff_non_data_inputs(freqAB, chanAB, operator, diff) + freqAB, chanAB, operator, diff = _parse_freq_diff_eq(freqABEq, chanABEq) # validate the source_Sv type or path (if it is provided) source_Sv, file_type = validate_source_ds_da(source_Sv, storage_options) @@ -598,7 +476,7 @@ def frequency_differencing( source_Sv = xr.open_dataset(source_Sv, engine=file_type, chunks={}, **storage_options) # check the source_Sv with respect to channel and frequency_nominal - _check_source_Sv_freq_diff(source_Sv, freqAB, chanAB) + _check_freq_diff_source_Sv(source_Sv, freqAB, chanAB) # determine chanA and chanB if freqAB is not None: diff --git a/echopype/mask/freq_diff.py b/echopype/mask/freq_diff.py new file mode 100644 index 000000000..002467e72 --- /dev/null +++ b/echopype/mask/freq_diff.py @@ -0,0 +1,149 @@ +import re +from typing import List, Optional, Union + +import xarray as xr + + +def _parse_freq_diff_eq( + freqABEq: Optional[str] = None, + chanABEq: Optional[str] = None, +) -> List[Union[List[float], List[str], str, Union[float, int]]]: + """ + Checks if either `freqABEq` or `chanABEq` is provided and parse the arguments accordingly + from the frequency diffrencing criteria. + + Parameters + ---------- + freqABEq : str, optional + The equation for frequency-differencing using frequency values. + chanABEq : str, optional + The equation for frequency-differencing using channel names. + + Returns + ------- + List[Union[List[float], List[str], str, Union[float, int]]] + A list containing the parsed arguments for frequency-differencing, where the first element + corresponds to `freqAB`, the second element corresponds to `chanAB`, the third element + corresponds to `operator`, the fourth element corresponds to `diff`. + + Raises + ------ + ValueError + If `operator` is not a valid operator. + If both `freqABEq` and `chanABEq` are provided. + If neither `freqABEq` nor `chanABEq` is provided. + If `freqAB` or `chanAB` is not a list of length 2 with unique elements. + TypeError + If `diff` is not a float or an int. + If `freqABEq` or `chanABEq` is not a valid equation. + """ + + if (freqABEq is None) and (chanABEq is None): + raise ValueError("Either freqAB or chanAB must be given!") + elif (freqABEq is not None) and (chanABEq is not None): + raise ValueError("Only one of freqAB or chanAB should be given, but not both!") + elif freqABEq is not None: + freqAPattern = r"(?P\d*\.\d+)\s*(?P\w?)Hz" + freqBPattern = r"(?P\d*\.\d+)\s*(?P\w?)Hz" + operatorPattern = r"\s*(?P\S*?)\s*" + rhsPattern = r"(?P\d*\.?\d+)\s*dB" + diffMatcher = re.compile( + freqAPattern + r"\s*-\s*" + freqBPattern + operatorPattern + rhsPattern + ) + eqMatched = diffMatcher.match(freqABEq) + if eqMatched is None: + raise TypeError("Invalid freqAB Equation!") + operator = eqMatched["cmp"] + if operator not in [">", "<", "<=", ">=", "=="]: + raise ValueError("Invalid operator!") + freqMultiplier = {"": 1, "k": 1e3, "M": 1e6, "G": 1e9} + freqA = float(eqMatched["freqA"]) * freqMultiplier[eqMatched["unitA"]] + freqB = float(eqMatched["freqB"]) * freqMultiplier[eqMatched["unitB"]] + freqAB = [freqA, freqB] + if len(set(freqAB)) != 2: + raise ValueError("freqAB must be a list of length 2 with unique elements!") + diff = float(eqMatched["db"]) + return [freqAB, None, operator, diff] + elif chanABEq is not None: + chanAPattern = r"(?P\".+\")\s*" + chanBPattern = r"(?P\".+\")\s*" + operatorPattern = r"\s*(?P\S*?)\s*" + rhsPattern = r"(?P\d*\.?\d+)\s*dB" + diffMatcher = re.compile( + chanAPattern + r"\s*-\s*" + chanBPattern + operatorPattern + rhsPattern + ) + eqMatched = diffMatcher.match(chanABEq) + if eqMatched is None: + raise TypeError("Invalid chanAB Equation!") + operator = eqMatched["cmp"] + if operator not in [">", "<", "<=", ">=", "=="]: + raise ValueError("Invalid operator!") + chanAB = [eqMatched["chanA"][1:-1], eqMatched["chanB"][1:-1]] + if len(set(chanAB)) != 2: + raise ValueError("chanAB must be a list of length 2 with unique elements!") + diff = float(eqMatched["db"]) + return [None, chanAB, operator, diff] + + +def _check_freq_diff_source_Sv( + source_Sv: xr.Dataset, + freqAB: Optional[List[float]] = None, + chanAB: Optional[List[str]] = None, +) -> None: + """ + Ensures that ``source_Sv`` contains ``channel`` as a coordinate and + ``frequency_nominal`` as a variable, the provided list input + (``freqAB`` or ``chanAB``) are contained in the coordinate ``channel`` + or variable ``frequency_nominal``, and ``source_Sv`` does not have + repeated values for ``channel`` and ``frequency_nominal``. + + Parameters + ---------- + source_Sv: xr.Dataset + A Dataset that contains the Sv data to create a mask for + freqAB: list of float, optional + The pair of nominal frequencies to be used for frequency-differencing, where + the first element corresponds to ``freqA`` and the second element corresponds + to ``freqB`` + chanAB: list of float, optional + The pair of channels that will be used to select the nominal frequencies to be + used for frequency-differencing, where the first element corresponds to ``freqA`` + and the second element corresponds to ``freqB`` + """ + + # check that channel and frequency nominal are in source_Sv + if "channel" not in source_Sv.coords: + raise ValueError("The Dataset defined by source_Sv must have channel as a coordinate!") + elif "frequency_nominal" not in source_Sv.variables: + raise ValueError( + "The Dataset defined by source_Sv must have frequency_nominal as a variable!" + ) + + # make sure that the channel values are not repeated in source_Sv and + # elements of chanAB are in channel + if chanAB is not None: + if len(set(source_Sv.channel.values)) < source_Sv.channel.size: + raise ValueError( + "The provided source_Sv contains repeated channel values, this is not allowed!" + ) + if not all([chan in source_Sv.channel for chan in chanAB]): + raise ValueError( + "The provided list input chanAB contains values that are " + "not in the channel coordinate!" + ) + + # make sure that the frequency_nominal values are not repeated in source_Sv and + # elements of freqAB are in frequency_nominal + if freqAB is not None: + print(source_Sv.frequency_nominal.values) + if len(set(source_Sv.frequency_nominal.values)) < source_Sv.frequency_nominal.size: + raise ValueError( + "The provided source_Sv contains repeated " + "frequency_nominal values, this is not allowed!" + ) + + if not all([freq in source_Sv.frequency_nominal for freq in freqAB]): + raise ValueError( + "The provided list input freqAB contains values that " + "are not in the frequency_nominal variable!" + ) diff --git a/echopype/tests/mask/test_mask.py b/echopype/tests/mask/test_mask.py index 15d87f242..7030f16db 100644 --- a/echopype/tests/mask/test_mask.py +++ b/echopype/tests/mask/test_mask.py @@ -11,10 +11,13 @@ import echopype as ep import echopype.mask from echopype.mask.api import ( - _check_source_Sv_freq_diff, _validate_and_collect_mask_input, _check_var_name_fill_value ) +from echopype.mask.freq_diff import ( + _parse_freq_diff_eq, + _check_freq_diff_source_Sv, +) from typing import List, Union, Optional @@ -84,7 +87,8 @@ def get_mock_freq_diff_data(n: int, n_chan_freq: int, add_chan: bool, if add_freq_nom: # construct frequency_values - freq_vals = [float(i) for i in range(1, n_chan_freq + 1)] + freqs = [1, 1e3, 2, 2e3] + freq_vals = [float(i) for i in freqs] # create mock frequency_nominal and add it to the Dataset variables mock_freq_nom = xr.DataArray(data=freq_vals, coords={channel_coord_name: chan_vals}) @@ -232,34 +236,34 @@ def create_input_mask( @pytest.mark.parametrize( ("n", "n_chan_freq", "add_chan", "add_freq_nom", "freqAB", "chanAB"), [ - (5, 3, True, True, [1.0, 3.0], None), - (5, 3, True, True, None, ['chan1', 'chan3']), - pytest.param(5, 3, False, True, [1.0, 3.0], None, + (5, 4, True, True, [1000.0, 2.0], None), + (5, 4, True, True, None, ['chan1', 'chan3']), + pytest.param(5, 4, False, True, [1.0, 2000000.0], None, marks=pytest.mark.xfail(strict=True, reason="This should fail because the Dataset " "will not have the channel coordinate.")), - pytest.param(5, 3, True, False, [1.0, 3.0], None, + pytest.param(5, 4, True, False, [1.0, 2.0], None, marks=pytest.mark.xfail(strict=True, reason="This should fail because the Dataset " "will not have the frequency_nominal variable.")), - pytest.param(5, 3, True, True, [1.0, 4.0], None, + pytest.param(5, 4, True, True, [1.0, 4.0], None, marks=pytest.mark.xfail(strict=True, reason="This should fail because not all selected frequencies" "are in the frequency_nominal variable.")), - pytest.param(5, 3, True, True, None, ['chan1', 'chan4'], + pytest.param(5, 4, True, True, None, ['chan1', 'chan9'], marks=pytest.mark.xfail(strict=True, reason="This should fail because not all selected channels" - "are in the channel coordinate.")), + "are in the channel coordinate.")) ], ids=["dataset_input_freqAB_provided", "dataset_input_chanAB_provided", "dataset_no_channel", "dataset_no_frequency_nominal", "dataset_missing_freqAB_in_freq_nom", "dataset_missing_chanAB_in_channel"] ) -def test_check_source_Sv_freq_diff(n: int, n_chan_freq: int, add_chan: bool, add_freq_nom: bool, +def test_check_freq_diff_source_Sv(n: int, n_chan_freq: int, add_chan: bool, add_freq_nom: bool, freqAB: List[float], chanAB: List[str]): """ - Test the inputs ``source_Sv, freqAB, chanAB`` for ``_check_source_Sv_freq_diff``. + Test the inputs ``source_Sv, freqAB, chanAB`` for ``_check_freq_diff_source_Sv``. Parameters ---------- @@ -287,24 +291,87 @@ def test_check_source_Sv_freq_diff(n: int, n_chan_freq: int, add_chan: bool, add source_Sv = get_mock_freq_diff_data(n, n_chan_freq, add_chan, add_freq_nom) - _check_source_Sv_freq_diff(source_Sv, freqAB=freqAB, chanAB=chanAB) + _check_freq_diff_source_Sv(source_Sv, freqAB=freqAB, chanAB=chanAB) + + +@pytest.mark.parametrize( + ("freqABEq", "chanABEq"), + [ + ("1.0Hz-2.0Hz==1.0dB", None), + (None, '"chan1"-"chan3"==1.0dB'), + ("1.0 kHz - 2.0 MHz>=1.0 dB", None), + (None, '"chan2-12 89" - "chan4 89-12" >= 1.0 dB'), + pytest.param("1.0kHz-2.0 kHz===1.0dB", None, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because " + "the operator is incorrect.")), + pytest.param(None, '"chan1"-"chan3"===1.0 dB', + marks=pytest.mark.xfail(strict=True, + reason="This should fail because " + "the operator is incorrect.")), + pytest.param("1.0 MHz-1.0MHz==1.0dB", None, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because the " + "frequencies are the same.")), + pytest.param(None, '"chan1"-"chan1"==1.0 dB', + marks=pytest.mark.xfail(strict=True, + reason="This should fail because the " + "channels are the same.")), + pytest.param("1.0 Hz-2.0==1.0dB", None, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because unit of one of " + "the frequency is missing.")), + pytest.param(None, '"chan1"-"chan3"==1.0', + marks=pytest.mark.xfail(strict=True, + reason="This should fail because unit of the " + "difference is missing.")), + ], + ids=["input_freqABEq_provided", "input_chanABEq_provided", "input_freqABEq_different_units", + "input_chanABEq_provided", "input_freqABEq_wrong_operator", "input_chanABEq_wrong_operator", + "input_freqABEq_duplicate_frequencies", "input_chanABEq_duplicate_channels", + "input_freqABEq_missing_unit", "input_chanABEq_missing_unit"] +) +def test_parse_freq_diff_eq(freqABEq: str, chanABEq: str): + """ + Tests the inputs ``freqABEq, chanABEq`` for ``_parse_freq_diff_eq``. + Parameters + ---------- + freqABEq: string, optional + The frequency differencing criteria. + chanABEq: string, optional + The frequency differencing criteria in terms of channel names where channel names + in the criteria are enclosed in double quotes. + """ + freq_vals = [1.0, 2.0, 1e3, 2e3, 1e6, 2e6] + chan_vals = ['chan1', 'chan3', "chan2-12 89", "chan4 89-12"] + operator_vals = [">=", "=="] + diff_val = 1.0 + freqAB, chanAB, operator, diff = _parse_freq_diff_eq(freqABEq=freqABEq, chanABEq=chanABEq) + if freqAB is not None: + for freq in freqAB: + assert freq in freq_vals + if chanAB is not None: + for chan in chanAB: + assert chan in chan_vals + assert operator in operator_vals + assert diff == diff_val @pytest.mark.parametrize( - ("n", "n_chan_freq", "freqAB", "chanAB", "diff", "operator", "mask_truth"), + ("n", "n_chan_freq", "freqABEq", "chanABEq", "mask_truth"), [ - (5, 4, [1.0, 3.0], None, 1.0, "==", np.identity(5)), - (5, 4, None, ['chan1', 'chan3'], 1.0, "==", np.identity(5)), - (5, 4, [3.0, 1.0], None, 1.0, "==", np.zeros((5, 5))), - (5, 4, None, ['chan3', 'chan1'], 1.0, "==", np.zeros((5, 5))), - (5, 4, [1.0, 3.0], None, 1.0, ">=", np.identity(5)), - (5, 4, None, ['chan1', 'chan3'], 1.0, ">=", np.identity(5)), - (5, 4, [1.0, 3.0], None, 1.0, ">", np.zeros((5, 5))), - (5, 4, None, ['chan1', 'chan3'], 1.0, ">", np.zeros((5, 5))), - (5, 4, [1.0, 3.0], None, 1.0, "<=", np.ones((5, 5))), - (5, 4, None, ['chan1', 'chan3'], 1.0, "<=", np.ones((5, 5))), - (5, 4, [1.0, 3.0], None, 1.0, "<", np.ones((5, 5)) - np.identity(5)), - (5, 4, None, ['chan1', 'chan3'], 1.0, "<", np.ones((5, 5)) - np.identity(5)), + (5, 4, "1.0Hz-2.0Hz== 1.0dB", None, np.identity(5)), + (5, 4, None, '"chan1"-"chan3" == 1.0 dB', np.identity(5)), + (5, 4, "2.0 Hz - 1.0 Hz==1.0 dB", None, np.zeros((5, 5))), + (5, 4, None, '"chan3" - "chan1"==1.0 dB', np.zeros((5, 5))), + (5, 4, "1.0 Hz-2.0Hz>=1.0dB", None, np.identity(5)), + (5, 4, None, '"chan1" - "chan3" >= 1.0 dB', np.identity(5)), + (5, 4, "1.0 kHz - 2.0 kHz > 1.0dB", None, np.zeros((5, 5))), + (5, 4, None, '"chan1"-"chan3">1.0 dB', np.zeros((5, 5))), + (5, 4, "1.0kHz-2.0 kHz<=1.0dB", None, np.ones((5, 5))), + (5, 4, None, '"chan1" - "chan3" <= 1.0 dB', np.ones((5, 5))), + (5, 4, "1.0 Hz-2.0Hz<1.0dB", None, np.ones((5, 5)) - np.identity(5)), + (5, 4, None, '"chan1"-"chan3"< 1.0 dB', np.ones((5, 5)) - np.identity(5)) ], ids=["freqAB_sel_op_equals", "chanAB_sel_op_equals", "reverse_freqAB_sel_op_equals", "reverse_chanAB_sel_op_equals", "freqAB_sel_op_ge", "chanAB_sel_op_ge", @@ -312,8 +379,7 @@ def test_check_source_Sv_freq_diff(n: int, n_chan_freq: int, add_chan: bool, add "chanAB_sel_op_le", "freqAB_sel_op_less", "chanAB_sel_op_less"] ) def test_frequency_differencing(n: int, n_chan_freq: int, - freqAB: List[float], chanAB: List[str], - diff: Union[float, int], operator: str, + freqABEq: str, chanABEq: str, mask_truth: np.ndarray): """ Tests that the output values of ``frequency_differencing`` are what we @@ -328,18 +394,11 @@ def test_frequency_differencing(n: int, n_chan_freq: int, Determines the size of the ``channel`` coordinate and ``frequency_nominal`` variable. To create mock data with known outcomes for ``frequency_differencing``, this value must be greater than or equal to 3. - freqAB: list of float, optional - The pair of nominal frequencies to be used for frequency-differencing, where - the first element corresponds to ``freqA`` and the second element corresponds - to ``freqB`` - chanAB: list of float, optional - The pair of channels that will be used to select the nominal frequencies to be - used for frequency-differencing, where the first element corresponds to ``freqA`` - and the second element corresponds to ``freqB`` - diff: float or int - The threshold of Sv difference between frequencies - operator: {">", "<", "<=", ">=", "=="} - The operator for the frequency-differencing + freqABEq: string, optional + The frequency differencing criteria. + chanABEq: string, optional + The frequency differencing criteria in terms of channel names where channel names + in the criteria are enclosed in double quotes. mask_truth: np.ndarray The truth value for the output mask, provided the given inputs """ @@ -348,9 +407,8 @@ def test_frequency_differencing(n: int, n_chan_freq: int, mock_Sv_ds = get_mock_freq_diff_data(n, n_chan_freq, add_chan=True, add_freq_nom=True) # obtain the frequency-difference mask for mock_Sv_ds - out = ep.mask.frequency_differencing(source_Sv=mock_Sv_ds, storage_options={}, freqAB=freqAB, - chanAB=chanAB, - operator=operator, diff=diff) + out = ep.mask.frequency_differencing(source_Sv=mock_Sv_ds, storage_options={}, freqABEq=freqABEq, + chanABEq=chanABEq) # ensure that the output values are correct assert np.all(out == mask_truth) @@ -656,6 +714,6 @@ def test_apply_mask_channel_variation(source_has_ch, mask_has_ch): [[1, np.nan], [np.nan, 1]], coords={"ping_time": np.arange(2), "range_sample": np.arange(2)}, attrs=source_ds[var_name].attrs - ) + ) - assert masked_ds[var_name].equals(truth_da) \ No newline at end of file + assert masked_ds[var_name].equals(truth_da) diff --git a/echopype/tests/utils/test_processinglevels_integration.py b/echopype/tests/utils/test_processinglevels_integration.py index b04af8b1e..0dadbfc87 100644 --- a/echopype/tests/utils/test_processinglevels_integration.py +++ b/echopype/tests/utils/test_processinglevels_integration.py @@ -111,7 +111,8 @@ def _freqdiff_applymask(test_ds): else: out_ds = test_ds freqAB = list(out_ds.frequency_nominal.values[:2]) - freqdiff_da = ep.mask.frequency_differencing(source_Sv=out_ds, freqAB=freqAB, operator=">", diff=5) + freqABEq = str(freqAB[0]) + "Hz" + "-" + str(freqAB[1]) + "Hz" + ">" + str(5) + "dB" + freqdiff_da = ep.mask.frequency_differencing(source_Sv=out_ds, freqABEq=freqABEq) # Apply mask to multi-channel Sv return ep.mask.apply_mask(source_ds=out_ds, var_name="Sv", mask=freqdiff_da)