[ASR] GSS-based mask estimator #7849

Merged · 3 commits · Nov 16, 2023
252 changes: 252 additions & 0 deletions nemo/collections/asr/modules/audio_modules.py
@@ -718,6 +718,258 @@ def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return masks, mask_length


class MaskEstimatorGSS(NeuralModule):
"""Estimate masks using guided source separation with a complex
angular Central Gaussian Mixture Model (cACGMM) [1].

This module corresponds to `GSS` in Fig. 2 in [2].

Notation approximately follows [1]: `gamma` denotes
the time-frequency mask, `alpha` denotes the mixture weights,
and `BM` denotes the shape matrix. Additionally, the provided
source activity is denoted as `activity`.

Args:
num_iterations: Number of iterations for the EM algorithm
eps: Small value for regularization
dtype: Data type for internal computations (default `torch.cdouble`)

References:
[1] Ito et al., Complex Angular Central Gaussian Mixture Model for Directional Statistics in Mask-Based Microphone Array Signal Processing, 2016
[2] Boeddeker et al., Front-End Processing for the CHiME-5 Dinner Party Scenario, 2018
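
Example:
A minimal usage sketch (shapes follow `forward` below; the values
are random placeholders):

gss = MaskEstimatorGSS(num_iterations=3)
# complex-valued spectrogram, shape (batch, num_inputs, F, T)
mixture = torch.randn(1, 4, 65, 50, dtype=torch.cfloat)
# frame-wise source activity, shape (batch, num_outputs, T)
activity = torch.randn(1, 2, 50) > 0
# masks for each output source, shape (batch, num_outputs, F, T)
masks = gss(input=mixture, activity=activity)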
"""

def __init__(self, num_iterations: int = 3, eps: float = 1e-8, dtype: torch.dtype = torch.cdouble):
super().__init__()

if num_iterations <= 0:
raise ValueError(f'Number of iterations must be positive, got {num_iterations}')

# number of iterations for the EM algorithm
self.num_iterations = num_iterations

if eps <= 0:
raise ValueError(f'eps must be positive, got {eps}')

# small regularization constant
self.eps = eps

# internal calculations
if dtype not in [torch.cfloat, torch.cdouble]:
raise ValueError(f'Unsupported dtype {dtype}, expecting cfloat or cdouble')
self.dtype = dtype

logging.debug('Initialized %s', self.__class__.__name__)
logging.debug('\tnum_iterations: %s', self.num_iterations)
logging.debug('\teps: %g', self.eps)
logging.debug('\tdtype: %s', self.dtype)

def normalize(self, x: torch.Tensor, dim: int = 1) -> torch.Tensor:
"""Normalize input to have a unit L2-norm across `dim`.
By default, normalizes across the input channels.

Args:
x: C-channel input signal, shape (B, C, F, T)
dim: Dimension for normalization, defaults to 1 to normalize over the channel dimension

Returns:
Normalized signal, shape (B, C, F, T)
"""
norm_x = torch.linalg.vector_norm(x, ord=2, dim=dim, keepdim=True)
x = x / (norm_x + self.eps)
return x

@typecheck(
input_types={
'alpha': NeuralType(('B', 'C', 'D')),
'activity': NeuralType(('B', 'C', 'T')),
'log_pdf': NeuralType(('B', 'C', 'D', 'T')),
},
output_types={'gamma': NeuralType(('B', 'C', 'D', 'T')),},
)
def update_masks(self, alpha: torch.Tensor, activity: torch.Tensor, log_pdf: torch.Tensor) -> torch.Tensor:
"""Update masks for the cACGMM.

Args:
alpha: component weights, shape (B, num_outputs, F)
activity: temporal activity for the components, shape (B, num_outputs, T)
log_pdf: logarithm of the PDF, shape (B, num_outputs, F, T)

Returns:
Masks for the components of the model, shape (B, num_outputs, F, T)
"""
# shape (B, num_outputs, F, T)
# subtract the maximum across outputs for numerical stability before exponentiation
log_gamma = log_pdf - torch.max(log_pdf, dim=-3, keepdim=True)[0]

gamma = torch.exp(log_gamma)

# calculate the mask using weight, pdf and source activity
gamma = alpha[..., None] * gamma * activity[..., None, :]

# normalize across components/output channels
gamma = gamma / (torch.sum(gamma, dim=-3, keepdim=True) + self.eps)

return gamma

@typecheck(
input_types={'gamma': NeuralType(('B', 'C', 'D', 'T')),}, output_types={'alpha': NeuralType(('B', 'C', 'D')),},
)
def update_weights(self, gamma: torch.Tensor) -> torch.Tensor:
"""Update weights for the individual components
in the mixture model.

Args:
gamma: masks, shape (B, num_outputs, F, T)

Returns:
Component weights, shape (B, num_outputs, F)
"""
alpha = torch.mean(gamma, dim=-1)
return alpha

@typecheck(
input_types={
'z': NeuralType(('B', 'C', 'D', 'T')),
'gamma': NeuralType(('B', 'C', 'D', 'T')),
'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')),
},
output_types={'log_pdf': NeuralType(('B', 'C', 'D', 'T')), 'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')),},
)
def update_pdf(
self, z: torch.Tensor, gamma: torch.Tensor, zH_invBM_z: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update PDF of the cACGMM.

Args:
z: directional statistics, shape (B, num_inputs, F, T)
gamma: masks, shape (B, num_outputs, F, T)
zH_invBM_z: energy weighted by shape matrices, shape (B, num_outputs, F, T)

Returns:
Logarithm of the PDF, shape (B, num_outputs, F, T), the energy term, shape (B, num_outputs, F, T)
"""
num_inputs = z.size(-3)

# shape (B, num_outputs, F, T)
scale = gamma / (zH_invBM_z + self.eps)

# scale outer product and sum over time
# shape (B, num_outputs, F, num_inputs, num_inputs)
BM = num_inputs * torch.einsum('bmft,bift,bjft->bmfij', scale.to(z.dtype), z, z.conj())

# normalize across time
denom = torch.sum(gamma, dim=-1)
BM = BM / (denom[..., None, None] + self.eps)

# make sure the matrix is Hermitian
BM = (BM + BM.conj().transpose(-1, -2)) / 2

# use eigenvalue decomposition to calculate the log determinant
# and the inverse-weighted energy term
L, Q = torch.linalg.eigh(BM)

# BM is positive definite, so all eigenvalues should be positive
# However, small negative values may occur due to limited precision
L = torch.clamp(L.real, min=self.eps)

# PDF is invariant to scaling of the shape matrix [1], so
# eigenvalues can be normalized (across num_inputs)
L = L / (torch.max(L, dim=-1, keepdim=True)[0] + self.eps)

# small regularization to avoid numerical issues
L = L + self.eps

# calculate the log determinant using the eigenvalues
log_detBM = torch.sum(torch.log(L), dim=-1)

# calculate the energy term using the inverse eigenvalues
# NOTE: keeping an alternative implementation for reference (slower)
# zH_invBM_z = torch.einsum('bift,bmfij,bmfj,bmfkj,bkft->bmft', z.conj(), Q, (1 / L).to(Q.dtype), Q.conj(), z)
# zH_invBM_z = zH_invBM_z.abs() + self.eps # small regularization

# compute (1 / sqrt(L)) * Q^H * z
zH_invBM_z = torch.einsum('bmfj,bmfkj,bkft->bmftj', (1 / L.sqrt()).to(Q.dtype), Q.conj(), z)
# calc squared norm
zH_invBM_z = zH_invBM_z.abs().pow(2).sum(-1)
# small regularization
zH_invBM_z = zH_invBM_z + self.eps

# final log PDF
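# log p(z; BM) = -num_inputs * log(z^H BM^{-1} z) - log det(BM), up to an
# additive constant that cancels in the normalization in update_masks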
log_pdf = -num_inputs * torch.log(zH_invBM_z) - log_detBM[..., None]

return log_pdf, zH_invBM_z

@property
def input_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports.
"""
return {
"input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"activity": NeuralType(('B', 'C', 'T')),
}

@property
def output_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports.
"""
return {
"gamma": NeuralType(('B', 'C', 'D', 'T')),
}

@typecheck()
def forward(self, input: torch.Tensor, activity: torch.Tensor) -> torch.Tensor:
"""Apply GSS to estimate the time-frequency masks for each output source.

Args:
input: batched C-channel input signal, shape (B, num_inputs, F, T)
activity: batched frame-wise activity for each output source, shape (B, num_outputs, T)

Returns:
Masks for the components of the model, shape (B, num_outputs, F, T)
"""
B, num_inputs, F, T = input.shape
num_outputs = activity.size(1)

if activity.size(0) != B:
raise ValueError(f'Batch dimension mismatch: activity {activity.shape} vs input {input.shape}')

if activity.size(-1) != T:
raise ValueError(f'Time dimension mismatch: activity {activity.shape} vs input {input.shape}')

if num_outputs == 1:
raise ValueError(f'Expecting multiple outputs, got {num_outputs}')

with torch.cuda.amp.autocast(enabled=False):
input = input.to(dtype=self.dtype)

assert input.is_complex(), f'Expecting complex input, got {input.dtype}'

# convert input to directional statistics by normalizing across channels
z = self.normalize(input, dim=-3)

# initialize masks
gamma = torch.clamp(activity, min=self.eps)
# normalize across output sources
gamma = gamma / torch.sum(gamma, dim=-2, keepdim=True)
# expand along the frequency dimension to shape (B, num_outputs, F, T)
gamma = gamma.unsqueeze(2).expand(-1, -1, F, -1)

# initialize the energy term
zH_invBM_z = torch.ones(B, num_outputs, F, T, dtype=input.dtype, device=input.device)

# EM iterations
for it in range(self.num_iterations):
alpha = self.update_weights(gamma=gamma)
log_pdf, zH_invBM_z = self.update_pdf(z=z, gamma=gamma, zH_invBM_z=zH_invBM_z)
gamma = self.update_masks(alpha=alpha, activity=activity, log_pdf=log_pdf)

if torch.any(torch.isnan(gamma)):
raise RuntimeError(f'gamma contains NaNs: {gamma}')

return gamma


class MaskReferenceChannel(NeuralModule):
"""A simple mask processor which applies mask
on ref_channel of the input signal.
Expand Down
34 changes: 34 additions & 0 deletions tests/collections/asr/test_audio_modules.py
@@ -22,6 +22,7 @@
from nemo.collections.asr.modules.audio_modules import (
MaskBasedDereverbWPE,
MaskEstimatorFlexChannels,
MaskEstimatorGSS,
MaskReferenceChannel,
SpectrogramToMultichannelFeatures,
WPEFilter,
@@ -414,3 +415,36 @@ def test_flex_channels(
assert torch.all(
mask_length == spec_length
), f'Output length mismatch: expected {spec_length}, got {mask_length}'

@pytest.mark.unit
@pytest.mark.parametrize('num_channels', [1, 4])
@pytest.mark.parametrize('num_subbands', [32, 65])
@pytest.mark.parametrize('num_outputs', [2, 3])
@pytest.mark.parametrize('batch_size', [1, 4])
def test_gss(self, num_channels: int, num_subbands: int, num_outputs: int, batch_size: int):
"""Test initialization of the GSS mask estimator and make sure it can process an input tensor.
This tests initialization and the output shape. It does not test correctness of the output.
"""
# Test vector length
num_frames = 50

# Instantiate UUT
uut = MaskEstimatorGSS()

# Process the current configuration
logging.debug('Process num_channels=%d', num_channels)
input_size = (batch_size, num_channels, num_subbands, num_frames)
logging.debug('Input size: %s', input_size)

# multi-channel input
mixture_spec = torch.randn(input_size, dtype=torch.cfloat)
source_activity = torch.randn(batch_size, num_outputs, num_frames) > 0

# UUT
mask = uut(input=mixture_spec, activity=source_activity)

# Check output dimensions match
expected_mask_shape = (batch_size, num_outputs, num_subbands, num_frames)
assert (
mask.shape == expected_mask_shape
), f'Output shape mismatch: expected {expected_mask_shape}, got {mask.shape}'
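
# Sketch of an additional sanity check (an assumption, not part of this PR):
# update_masks normalizes the masks across sources, so values should lie in [0, 1]
assert torch.all(mask >= 0.0), 'Masks expected to be non-negative'
assert torch.all(mask <= 1.0), 'Masks expected to be at most 1'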