
Commit e138cfb

Merge branch 'canary_speechllm1_cross_t5_pr3' of https://github.com/pzelasko/NeMo into canary_speechllm1_cross_t5_pr3
2 parents: c5c8c42 + c7e06a9

1 file changed (+9, -8)

nemo/collections/asr/modules/conformer_encoder.py (+9, -8)
@@ -654,14 +654,15 @@ def forward_internal(
         return audio_signal, length
 
     def update_max_seq_length(self, seq_length: int, device):
+        # TODO: this sync seems unnecessary, remove or disable properly later
         # Find global max audio length across all nodes
-        if torch.distributed.is_initialized():
-            global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)
+        # if torch.distributed.is_initialized():
+        #     global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)
 
-            # Update across all ranks in the distributed system
-            torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX)
+        # # Update across all ranks in the distributed system
+        # torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX)
 
-            seq_length = global_max_len.int().item()
+        # seq_length = global_max_len.int().item()
 
         if seq_length > self.max_audio_length:
             self.set_max_audio_length(seq_length)
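
For reference, the code being commented out here computes the maximum sequence length across all distributed ranks. Below is a minimal standalone sketch of that pattern; the opt-in `sync` flag is a hypothetical alternative to commenting the collective out, not NeMo API:

# Minimal sketch; the `sync` flag is an assumption, not part of NeMo.
import torch
import torch.distributed as dist

def global_max_seq_length(seq_length: int, device: torch.device, sync: bool = False) -> int:
    """Return the max seq_length across ranks, or the local value when not syncing."""
    if sync and dist.is_available() and dist.is_initialized():
        # all_reduce with MAX is a blocking collective: every rank must reach
        # this call, which is why an unconditional sync can stall ranks that
        # take different code paths (the likely motivation for the TODO above).
        global_max = torch.tensor([seq_length], dtype=torch.float32, device=device)
        dist.all_reduce(global_max, op=dist.ReduceOp.MAX)
        return int(global_max.item())
    return seq_length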
@@ -1053,7 +1054,7 @@ def change_attention_model(
 
     def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int):
         """
-        Update the conv_chunking_factor (int)
+        Update the conv_chunking_factor (int)
         Default is 1 (auto)
         Set it to -1 (disabled) or to a specific value (power of 2) if you OOM in the conv subsampling layers
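
A hedged usage sketch of the method documented above (the checkpoint name is illustrative only, not implied by this commit):

# Usage sketch for change_subsampling_conv_chunking_factor.
import nemo.collections.asr as nemo_asr

model = nemo_asr.models.ASRModel.from_pretrained("stt_en_conformer_ctc_large")
# -1 disables chunking; a power of 2 (e.g. 4) lowers peak memory in the conv
# subsampling layers; 1 (the default) chooses the factor automatically.
model.encoder.change_subsampling_conv_chunking_factor(subsampling_conv_chunking_factor=4)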
@@ -1116,9 +1117,9 @@ def get_accepted_adapter_types(self,) -> Set[type]:
 class ConformerMultiLayerFeatureExtractor(NeuralModule, Exportable, AccessMixin):
     """
     A wrapper module that extracts features from multiple layers of a ConformerEncoder,
-    by reusing existing mechanisim for interctc loss.
+    by reusing existing mechanisim for interctc loss.
     To use it, set `layer_idx_list` to specify the indices of layers to extract from.
-    Also, you can specify an `aggretator` module to aggregate the features from different layers, default not aggregating.
+    Also, you can specify an `aggretator` module to aggregate the features from different layers, default not aggregating.
     """
 
     def __init__(
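
Based on the docstring above, a sketch of wrapping an encoder; the constructor arguments beyond `layer_idx_list` and the optional aggregator are assumptions and may differ in detail:

# Sketch assuming the constructor takes the encoder, layer_idx_list, and an
# optional aggregator, as the docstring suggests; exact signature may differ.
from nemo.collections.asr.modules.conformer_encoder import (
    ConformerEncoder,
    ConformerMultiLayerFeatureExtractor,
)

encoder = ConformerEncoder(feat_in=80, n_layers=17, d_model=512)
# Extract hidden states after layers 5, 11, and 16; with no aggregator,
# the features from each layer are returned separately.
extractor = ConformerMultiLayerFeatureExtractor(encoder=encoder, layer_idx_list=[5, 11, 16])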
