Disable conformer seq len NCCL sync
Signed-off-by: Piotr Żelasko <[email protected]>
pzelasko committed May 17, 2024
1 parent defe205 commit c7e06a9
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -654,14 +654,15 @@ def forward_internal(
         return audio_signal, length
 
     def update_max_seq_length(self, seq_length: int, device):
+        # TODO: this sync seems unnecessary, remove or disable properly later
         # Find global max audio length across all nodes
-        if torch.distributed.is_initialized():
-            global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)
+        #if torch.distributed.is_initialized():
+        #    global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)
 
-            # Update across all ranks in the distributed system
-            torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX)
+        #    # Update across all ranks in the distributed system
+        #    torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX)
 
-            seq_length = global_max_len.int().item()
+        #    seq_length = global_max_len.int().item()
 
         if seq_length > self.max_audio_length:
             self.set_max_audio_length(seq_length)
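For context, the pattern being commented out is a cross-rank max reduction: every rank reports its local sequence length and all ranks adopt the global maximum, so that set_max_audio_length sizes internal buffers identically everywhere. Below is a minimal, self-contained sketch of that pattern (not NeMo code; the function name is illustrative):

import torch
import torch.distributed as dist

def global_max_length(seq_length: int, device: torch.device) -> int:
    """Max of seq_length across all ranks; the pattern disabled above."""
    if not dist.is_initialized():
        return seq_length
    t = torch.tensor([seq_length], dtype=torch.float32, device=device)
    # Blocking NCCL collective: every rank must reach this call on every
    # forward pass, otherwise the job stalls.
    dist.all_reduce(t, op=dist.ReduceOp.MAX)
    return int(t.item())

Disabling the collective removes a per-forward synchronization point; the remaining lines then grow max_audio_length from each rank's local batches only.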
@@ -1053,7 +1054,7 @@ def change_attention_model(
 
     def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int):
         """
-        Update the conv_chunking_factor (int)
+        Update the conv_chunking_factor (int)
         Default is 1 (auto)
         Set it to -1 (disabled) or to a specific value (power of 2) if you OOM in the conv subsampling layers
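The docstring above covers the accepted values; a hypothetical call site, assuming encoder is an already-constructed ConformerEncoder:

# Illustrative only; `encoder` is assumed to be an instantiated ConformerEncoder.
encoder.change_subsampling_conv_chunking_factor(subsampling_conv_chunking_factor=4)   # chunk by a power of 2 if the conv subsampling OOMs
encoder.change_subsampling_conv_chunking_factor(subsampling_conv_chunking_factor=-1)  # or disable chunking entirely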
@@ -1116,9 +1117,9 @@ def get_accepted_adapter_types(self,) -> Set[type]:
 class ConformerMultiLayerFeatureExtractor(NeuralModule, Exportable, AccessMixin):
     """
     A wrapper module that extracts features from multiple layers of a ConformerEncoder,
-    by reusing existing mechanisim for interctc loss.
+    by reusing existing mechanisim for interctc loss.
     To use it, set `layer_idx_list` to specify the indices of layers to extract from.
-    Also, you can specify an `aggretator` module to aggregate the features from different layers, default not aggregating.
+    Also, you can specify an `aggretator` module to aggregate the features from different layers, default not aggregating.
     """
 
     def __init__(
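Based on the docstring, a hypothetical instantiation follows (the exact constructor signature is assumed here, not shown in the diff):

# Illustrative only; argument names are assumptions based on the docstring.
extractor = ConformerMultiLayerFeatureExtractor(
    encoder=encoder,            # an existing ConformerEncoder
    layer_idx_list=[3, 7, 11],  # indices of the layers to extract features from
    aggregator=None,            # None = return per-layer features without aggregating
)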
