@@ -654,14 +654,15 @@ def forward_internal(
         return audio_signal, length
 
     def update_max_seq_length(self, seq_length: int, device):
+        # TODO: this sync seems unnecessary, remove or disable properly later
         # Find global max audio length across all nodes
-        if torch.distributed.is_initialized():
-            global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)
+        # if torch.distributed.is_initialized():
+        #     global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)
 
-            # Update across all ranks in the distributed system
-            torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX)
+        #     # Update across all ranks in the distributed system
+        #     torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX)
 
-            seq_length = global_max_len.int().item()
+        #     seq_length = global_max_len.int().item()
 
         if seq_length > self.max_audio_length:
             self.set_max_audio_length(seq_length)
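
For context, the block commented out above performed a cross-rank max reduction and then copied the result back to the host. Below is a minimal standalone sketch of that logic, reconstructed from the removed lines; the `.item()` call blocks the host until the reduced value arrives from the device, which is presumably the per-step sync the TODO refers to.

import torch
import torch.distributed as dist

def global_max_seq_length(seq_length: int, device: torch.device) -> int:
    # Sketch of the removed logic: take the max seq_length across all ranks.
    if dist.is_initialized():
        global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)
        dist.all_reduce(global_max_len, op=dist.ReduceOp.MAX)
        # .item() forces a blocking device-to-host copy; this is the stall
        # that the TODO flags as unnecessary.
        seq_length = int(global_max_len.int().item())
    return seq_length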
@@ -1053,7 +1054,7 @@ def change_attention_model(
 
     def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int):
         """
-        Update the conv_chunking_factor (int)
+        Update the conv_chunking_factor (int)
         Default is 1 (auto)
         Set it to -1 (disabled) or to a specific value (power of 2) if you OOM in the conv subsampling layers
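
The docstring constrains accepted values to 1 (auto), -1 (disabled), or a power of 2. The method body is not shown in this hunk, so the helper below is only a sketch of that constraint; the function name is hypothetical.

def _is_valid_chunking_factor(subsampling_conv_chunking_factor: int) -> bool:
    # Hypothetical helper mirroring the docstring's constraint:
    # -1 (disabled), 1 (auto), or any positive power of two.
    n = subsampling_conv_chunking_factor
    return n == -1 or (n >= 1 and n & (n - 1) == 0)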
@@ -1116,9 +1117,9 @@ def get_accepted_adapter_types(self,) -> Set[type]:
 
 class ConformerMultiLayerFeatureExtractor(NeuralModule, Exportable, AccessMixin):
     """
     A wrapper module that extracts features from multiple layers of a ConformerEncoder,
-    by reusing existing mechanisim for interctc loss.
+    by reusing the existing mechanism for interctc loss.
     To use it, set `layer_idx_list` to specify the indices of layers to extract from.
-    Also, you can specify an `aggretator` module to aggregate the features from different layers, default not aggregating.
+    Also, you can specify an `aggregator` module to aggregate the features from different layers; by default, no aggregation is applied.
     """
 
     def __init__(
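
Based only on the docstring, usage would look roughly like the sketch below; the `__init__` signature is truncated in this excerpt, so the keyword names other than `layer_idx_list` and `aggregator` are assumptions.

# Hypothetical usage; `conformer_encoder` stands in for an already-built
# ConformerEncoder instance, and the `encoder` kwarg name is assumed.
extractor = ConformerMultiLayerFeatureExtractor(
    encoder=conformer_encoder,
    layer_idx_list=[3, 7, 11],  # indices of encoder layers to extract features from
    aggregator=None,            # default: return per-layer features without aggregating
)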