Skip to content

Commit

Permalink
wip: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
anteju committed Sep 29, 2023
1 parent 1fae201 commit 9f93703
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 3 deletions.
4 changes: 1 addition & 3 deletions nemo/collections/asr/modules/audio_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,15 +501,14 @@ class MaskEstimatorFlexChannels(NeuralModule):
use_ipd: Use inter-channel phase difference (IPD) features
mag_normalization: Normalize using mean ('mean') or mean and variance ('mean_var')
ipd_normalization: Normalize using mean ('mean') or mean and variance ('mean_var')
estimate_ref_channel: Estimate the output reference channel automatically
"""

def __init__(
self,
num_outputs: int,
num_subbands: int,
num_blocks: int,
channel_reduction_position: int = -1, # if 0, apply before layer 0, if -1 apply at the end
channel_reduction_position: int = -1, # if 0, apply before block 0, if -1 apply at the end
channel_reduction_type: str = 'attention',
channel_block_type: str = 'transform_attend_concatenate',
temporal_block_type: str = 'conformer_encoder',
Expand All @@ -524,7 +523,6 @@ def __init__(
use_ipd: bool = True,
mag_normalization: Optional[str] = None,
ipd_normalization: Optional[str] = None,
estimate_ref_channel: Optional[bool] = False,
):
super().__init__()

Expand Down
59 changes: 59 additions & 0 deletions tests/collections/asr/test_audio_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@

from nemo.collections.asr.modules.audio_modules import (
MaskBasedDereverbWPE,
MaskEstimatorFlexChannels,
MaskReferenceChannel,
SpectrogramToMultichannelFeatures,
WPEFilter,
)
from nemo.collections.asr.modules.audio_preprocessing import AudioToSpectrogram
from nemo.collections.asr.parts.utils.audio_utils import convmtx_mc_numpy
from nemo.utils import logging

try:
importlib.import_module('torchaudio')
Expand Down Expand Up @@ -347,3 +349,60 @@ def test_mask_based_dereverb_init(self, num_channels: int, filter_length: int, d

assert y.shape == x.shape, 'Output shape not matching, example {n}'
assert torch.equal(y_length, x_length), 'Length not matching, example {n}'


class TestMaskEstimator:
@pytest.mark.unit
@pytest.mark.skipif(not HAVE_TORCHAUDIO, reason="Modules in this test require torchaudio")
@pytest.mark.parametrize('channel_reduction_position', [0, 1, -1])
@pytest.mark.parametrize('channel_reduction_type', ['average', 'attention'])
@pytest.mark.parametrize('channel_block_type', ['transform_average_concatenate', 'transform_attend_concatenate'])
def test_flex_channels(
self, channel_reduction_position: int, channel_reduction_type: str, channel_block_type: str
):
"""Test initialization of the mask estimator and make sure it can process input tensor.
"""
# model parameters
num_subbands_tests = [32, 65]
num_outputs_tests = [1, 2]
num_blocks_tests = [1, 5]

# input channels
num_channels_tests = [1, 4]
batch_size = 4
num_frames = 50

for num_subbands in num_subbands_tests:
for num_outputs in num_outputs_tests:
for num_blocks in num_blocks_tests:
logging.info(
'Initialized with num_subbands=%d, num_outputs=%d, num_blocks=%d',
num_subbands,
num_outputs,
num_blocks,
)

uut = MaskEstimatorFlexChannels(
num_outputs=num_outputs,
num_subbands=num_subbands,
num_blocks=num_blocks,
channel_reduction_position=channel_reduction_position,
channel_reduction_type=channel_reduction_type,
channel_block_type=channel_block_type,
)

for num_channels in num_channels_tests:
logging.info('Process num_channels=%d', num_channels)
input_size = (batch_size, num_channels, num_subbands, num_frames)

# multi-channel input
spec = torch.randn(input_size, dtype=torch.cfloat)
spec_length = torch.randint(1, num_frames, (batch_size,))

# UUT
mask, mask_length = uut(input=spec, input_length=spec_length)

# TODO: check dimensions and lengths match
import pdb

pdb.set_trace()

0 comments on commit 9f93703

Please sign in to comment.