
[ASR] Add optimization util for linear sum assignment algorithm #6349

Merged
tango4j merged 37 commits into NVIDIA:main from fix/clus_spk_util_jit on Apr 14, 2023

Conversation

@tango4j (Collaborator) commented on Apr 3, 2023

What does this PR do?

  1. Added an optimization util file with a linear sum assignment (LSA) solver for online diarization and multi-speaker ASR.
    An LSA problem solver is needed for the following tasks in NeMo:
    (1) Permutation Invariant Loss (PIL) for diarization model training
    (2) Label permutation matching for online speaker diarization
    (3) Concatenated minimum-permutation Word Error Rate (cp-WER) calculation

What is the LSA solver algorithm? See Google OR-tools: LSA Solver.

The NeMo linear_sum_assignment function is compared with scipy.optimize.linear_sum_assignment:
the unit test for the NeMo LSA solver checks its result against the scipy implementation.
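
For illustration, here is a toy sketch of task (2), label permutation matching: build an agreement-count matrix between two speaker label sequences, negate it (the same convention this PR's own code uses with -1 * linear_kernel), and let the LSA solver pick the best speaker mapping. The label sequences below are made up for this example.

import torch
from nemo.collections.asr.parts.utils.optimization_utils import linear_sum_assignment

# Hypothetical per-segment speaker labels from two diarization passes
old_labels = torch.tensor([0, 0, 1, 1, 2])
new_labels = torch.tensor([1, 1, 2, 2, 0])

num_spks = 3
agreement = torch.zeros(num_spks, num_spks)
for i in range(num_spks):
    for j in range(num_spks):
        # Count segments where old speaker i co-occurs with new speaker j
        agreement[i, j] = ((old_labels == i) & (new_labels == j)).sum()

# Negate so that minimizing the cost maximizes the agreement
row_ind, col_ind = linear_sum_assignment(-1 * agreement)
# Expected mapping: old 0 -> new 1, old 1 -> new 2, old 2 -> new 0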

  2. Removed the @torch.jit.script decorator in speaker_utils.py, since it creates type errors when the code is not used for production purposes.
    Instead, every class and function that must support torch.jit.script is tested in test_diar_utils.py.
    These tests cover jit_script = [True/False] and cuda = [True/False], four combinations in total (a sketch of this parameterization follows this list).

  3. Refactored some of the functions in online diarization:

    • replaced the scipy LSA solver with the NeMo LSA solver in online_clustering.py.
  4. Added a couple of functions in der.py for online diarization DER calculation:

    • replaced the scipy LSA solver with the NeMo LSA solver in der.py.
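
As referenced in item 2, here is a minimal sketch of how the four jit_script/cuda combinations can be parameterized with pytest. It mirrors the shape of the tests rather than their actual code, and it assumes (per the PR description) that linear_sum_assignment is scriptable as a free function:

import pytest
import torch

from nemo.collections.asr.parts.utils.optimization_utils import linear_sum_assignment

@pytest.mark.parametrize("jit_script", [True, False])
@pytest.mark.parametrize("cuda", [True, False])
def test_lsa_solver_modes(jit_script, cuda):
    if cuda and not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    solver = torch.jit.script(linear_sum_assignment) if jit_script else linear_sum_assignment
    device = torch.device("cuda" if cuda else "cpu")
    cost_matrix = torch.tensor([[7, 6, 2], [6, 2, 1], [5, 6, 8]], device=device)
    row_ind, col_ind = solver(cost_matrix)
    # Optimal assignment here is (0,2), (1,1), (2,0): total cost 2 + 2 + 5 = 9
    assert int(cost_matrix[row_ind, col_ind].sum().item()) == 9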

Collection: [ASR]

Changelog

  • nemo/collections/asr/metrics/der.py
    : Replaced the scipy LSA solver with the NeMo LSA solver in the calculate_session_cpWER function.
    : Added two functions for online diarization evaluation: get_partial_ref_labels and get_online_DER_stats.

  • nemo/collections/asr/models/online_diarizer.py
    : Simplified the _perform_online_clustering function by moving get_reduced_mat and match_labels into the online clustering function.

  • nemo/collections/asr/parts/utils/offline_clustering.py
    : Added laplacian = laplacian.float().to(torch.device('cpu')) so that the jit-scripted module does not use the GPU when CPU is specified (or vice versa). This behavior is tested in test_diar_utils.py.

  • nemo/collections/asr/parts/utils/online_clustering.py
    : Replaced the scipy LSA solver with the NeMo LSA solver in the get_lsa_speaker_mapping function.
    : Modified the docstrings of update_speaker_history_buffer to make the example easier to follow.

  • nemo/collections/asr/parts/utils/optimization_utils.py
    : Added a fully torch-jit-scriptable linear sum assignment solver class and function.

  • nemo/collections/asr/parts/utils/speaker_utils.py
    : Removed @torch.jit.script decorators, since they create unnecessary warning messages and type-related errors when the code is used without scripting.

  • tests/collections/asr/test_diar_metrics.py
    : Added unit tests for the newly added functions get_partial_ref_labels and get_online_DER_stats.

  • tests/collections/asr/test_diar_utils.py
    : Added tests for offline clustering and online clustering covering all four combinations:
    [jit-script=True, cuda=True],
    [jit-script=True, cuda=False],
    [jit-script=False, cuda=True],
    [jit-script=False, cuda=False],
    all exercising the NeMo linear_sum_assignment function (scripted and unscripted).

Usage

import torch
from nemo.collections.asr.parts.utils.optimization_utils import linear_sum_assignment

# An example cost matrix to be solved
cost_matrix = torch.tensor([[7, 6, 2], [6, 2, 1], [5, 6, 8]])
row_ind_nm, col_ind_nm = linear_sum_assignment(cost_matrix)
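
To mirror the unit test's cross-check (assuming scipy is installed, and that the NeMo solver returns index tensors analogous to scipy's index arrays):

from scipy.optimize import linear_sum_assignment as scipy_lsa

row_ind_sp, col_ind_sp = scipy_lsa(cost_matrix.numpy())
# Both solvers should reach the same minimal total cost (9 for this matrix)
assert int(cost_matrix[row_ind_nm, col_ind_nm].sum().item()) == int(cost_matrix.numpy()[row_ind_sp, col_ind_sp].sum())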

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.

@github-actions bot added the ASR label on Apr 3, 2023
@tango4j requested a review from nithinraok on Apr 3, 2023, 02:31
nemo/collections/asr/models/online_diarizer.py: review comment fixed and resolved
nemo/collections/asr/metrics/der.py: review comment fixed and resolved
@tango4j marked this pull request as ready for review on Apr 3, 2023, 23:30
@tango4j requested a review from fayejf on Apr 3, 2023, 23:36
@tango4j marked this pull request as draft on Apr 3, 2023, 23:37
@tango4j marked this pull request as ready for review on Apr 4, 2023, 18:05
@nithinraok (Collaborator) left a comment:

Minor review; will do a thorough review tomorrow. Very neat improvement. I need to understand it better from my end.

@@ -552,12 +550,13 @@ def eigDecompose(
        device = torch.cuda.current_device()
        laplacian = laplacian.float().to(device)
    else:
+       laplacian = laplacian.float().to(torch.device('cpu'))
        laplacian = laplacian.float()
nithinraok (Collaborator): Why the same operation twice?

tango4j (Collaborator, Author): Removed.

laplacian = laplacian.float()
lambdas, diffusion_map = eigh(laplacian)
return lambdas, diffusion_map


- def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device = torch.device('cpu')) -> torch.Tensor:
+ def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device) -> torch.Tensor:
nithinraok (Collaborator): Why both cuda and device? Isn't only one sufficient?

tango4j (Collaborator, Author): This was added long ago because some users set cuda=True but device=cpu; the flag adds some flexibility to avoid errors in such cases. If we need to remove it, that requires a separate PR, since it involves the whole diarization pipeline.
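
A hypothetical illustration of that flexibility (not the actual NeMo code): a guard like the one below keeps cuda=True combined with a CPU device from raising errors.

import torch

def resolve_device(cuda: bool, device: torch.device) -> torch.device:
    # Hypothetical helper: use CUDA only when the flag, the requested device,
    # and the available hardware all agree; otherwise fall back to CPU.
    if cuda and device.type == "cuda" and torch.cuda.is_available():
        return device
    return torch.device("cpu")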

Comment on lines 568 to 569:
    laplacian = laplacian.float().to(torch.device('cpu'))
    laplacian = laplacian.float()
nithinraok (Collaborator): Same here, laplacian.float() twice.

tango4j (Collaborator, Author): Removed.

stacked = np.hstack((enc_P, enc_Q))
cost = -1 * linear_kernel(stacked.T)[spk_count:, :spk_count]
row_ind, col_ind = linear_sum_assignment(cost)
PandQ_list: List[int] = [int(x.item()) for x in PandQ]
nithinraok (Collaborator): Minor: mentioning the dtype in a variable name should be avoided.

tango4j (Collaborator, Author): Makes sense, since types are strictly annotated for jit-scripted functions. Fixed.

marked (Tensor): 2D matrix containing the marked zeros.
"""

def __init__(self, cost_matrix):
nithinraok (Collaborator): Minor: mention the dtype of cost_matrix here. Isn't it necessary for jit scripting?

tango4j (Collaborator, Author) commented on Apr 13, 2023:

If there is no type annotation, the jit compiler treats the argument as torch.Tensor, so in general an annotation is needed whenever an argument is not a torch.Tensor. Added type annotations.
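
A small standalone example of that TorchScript rule (a made-up function, not from this PR):

import torch

@torch.jit.script
def scale(x, factor: float):
    # `x` has no annotation, so the TorchScript compiler treats it as
    # torch.Tensor; `factor` is not a Tensor, so it must be annotated.
    return x * factor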

nithinraok (Collaborator): Got it.

Comment on lines +312 to +316
if cost_matrix.shape[1] < cost_matrix.shape[0]:
    cost_matrix = cost_matrix.T
    transposed = True
else:
    transposed = False
nithinraok (Collaborator): Why the extra transposed variable? Could the same col < row condition be used below?

tango4j (Collaborator, Author): This follows the original implementation in scipy. If we dropped the transposed variable, we would still need another variable to record whether cost_matrix.shape[1] < cost_matrix.shape[0] held before the transpose.

# Copyright (c) 2008 Brian M. Clapper <[email protected]>, Gael Varoquaux
# Author: Brian M. Clapper, Gael Varoquaux
# License: 3-clause BSD

nithinraok (Collaborator): Do we have only one optimization algorithm so far? Wondering whether we should move other funcs to this file as well.

tango4j (Collaborator, Author) commented on Apr 13, 2023:

I think we can add other algorithms below this one (I mentioned the "Linear Sum Assignment solver" specifically). A copyright notice at the beginning of the code is the convention in most projects, so I followed it.

for label in ref_labels:
    start, end, speaker = label.split()
    start, end = float(start), float(end)
    # If the current [start, end] interval is latching the last prediction time
nithinraok (Collaborator): latching -> matching

tango4j (Collaborator, Author): Changed the expression (checked by Elena).

@@ -31,67 +31,67 @@
# https://arxiv.org/pdf/2003.02405.pdf and the implementation from
# https://github.com/tango4j/Auto-Tuning-Spectral-Clustering.

- from typing import List, Tuple
+ from typing import List, Set, Tuple

Check notice (Code scanning / CodeQL): Unused import. Import of 'Set' is not used.
@nithinraok (Collaborator) left a comment:

LGTM

@tango4j merged commit ae55b52 into NVIDIA:main on Apr 14, 2023
hsiehjackson pushed a commit to hsiehjackson/NeMo that referenced this pull request on Jun 2, 2023

[ASR] Add optimization util for linear sum assignment algorithm (NVIDIA#6349)

* [ASR] Add optimization utils for cpWER, diarization training, online diarization
* Fixed GPU/CPU issues for clustering
* Fixed unreachable state
* Resolved jit script compile error for the LSA algorithm
* Fixed errors and bugs, checked tests
* Fixed docstrings
* Updated changes on test files
* Refactored functions
* Added docstrings for the functions in der.py
* Fixed wrong docstrings in der.py
* Fixed a wrong docstring
* Changed np.array input to Tensor for the LSA solver in der.py
* Addressed CodeQL issues and added unit tests for der.py functions
* Removed a print line in der.py
* Fixed a CodeQL redundant comparison
* Fixed a CodeQL issue
* Added a license notice for the reference code
* Added the full license text of the original code
* Reflected comments
* Reflected review comments
* Interspersed [pre-commit.ci] auto-fix commits (for more information, see https://pre-commit.ci)

---------

Signed-off-by: Taejin Park <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: hsiehjackson <[email protected]>
@tango4j deleted the fix/clus_spk_util_jit branch on December 6, 2023, 21:26