Commit fcd7278
Hainan Xu committed Sep 29, 2024
1 parent bd014d9 commit fcd7278
Showing 4 changed files with 36 additions and 17 deletions.
1 change: 1 addition & 0 deletions nemo/collections/asr/models/rnnt_models.py
@@ -667,6 +667,7 @@ def forward(
             processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)
 
         encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
+        encoded_len = -(-encoded_len // 8)
         return encoded, encoded_len
 
     # PTL-specific methods
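Note (not part of the commit): -(-encoded_len // 8) is the standard ceiling-division idiom, computing ceil(encoded_len / 8) elementwise on the length tensor; the hard-coded 8 matches the default sub_sampling_factor this commit introduces in rnnt.py below. A quick check of the idiom:

    import torch

    encoded_len = torch.tensor([100, 101, 104])
    print(-(-encoded_len // 8))  # tensor([13, 13, 13]), i.e. ceil(len / 8) elementwise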
12 changes: 11 additions & 1 deletion nemo/collections/asr/modules/rnnt.py
@@ -1306,6 +1306,7 @@ def __init__(
         self,
         jointnet: Dict[str, Any],
         num_classes: int,
+        sub_sampling_factor: int = 8,
         num_extra_outputs: int = 0,
         vocabulary: Optional[List] = None,
         log_softmax: Optional[bool] = None,
@@ -1351,6 +1352,7 @@ def __init__(
         self.pred_hidden = jointnet['pred_hidden']
         self.joint_hidden = jointnet['joint_hidden']
         self.activation = jointnet['activation']
+        self.sub_sampling_factor = sub_sampling_factor
 
         # Optional arguments
         dropout = jointnet.get('dropout', 0.0)
@@ -1383,6 +1385,14 @@ def forward(
         # encoder = (B, D, T)
         # decoder = (B, D, U) if passed, else None
         encoder_outputs = encoder_outputs.transpose(1, 2)  # (B, T, D)
+        B, T, D = encoder_outputs.shape
+        s = self.sub_sampling_factor
+        if T % s != 0:
+            t_to_add = s - T % s
+            encoder_outputs = torch.cat([encoder_outputs, torch.zeros([B, t_to_add, D]).to(encoder_outputs.device)], dim=1)
+            T = T + t_to_add
+
+        encoder_outputs = torch.reshape(encoder_outputs, [B, T // s, D * s])
 
         if decoder_outputs is not None:
             decoder_outputs = decoder_outputs.transpose(1, 2)  # (B, U, D)
@@ -1622,7 +1632,7 @@ def _joint_net_modules(self, num_classes, pred_n_hidden, enc_n_hidden, joint_n_hidden, activation, dropout):
             dropout: Dropout value to apply to joint.
         """
         pred = torch.nn.Linear(pred_n_hidden, joint_n_hidden)
-        enc = torch.nn.Linear(enc_n_hidden, joint_n_hidden)
+        enc = torch.nn.Linear(enc_n_hidden * self.sub_sampling_factor, joint_n_hidden)
 
         if activation not in ['relu', 'sigmoid', 'tanh']:
             raise ValueError("Unsupported activation for joint step - please pass one of " "[relu, sigmoid, tanh]")
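For reference, a minimal standalone sketch of the frame-stacking step added to forward above; the helper name stack_frames and the example shapes are illustrative, not part of the commit. The time axis is zero-padded up to a multiple of the sub-sampling factor, then every s consecutive frames are folded into a single vector of width D * s, which is why _joint_net_modules now builds the encoder projection from enc_n_hidden * self.sub_sampling_factor:

    import torch

    def stack_frames(encoder_outputs: torch.Tensor, s: int = 8) -> torch.Tensor:
        # Fold every s consecutive frames into one vector: (B, T, D) -> (B, ceil(T / s), D * s)
        B, T, D = encoder_outputs.shape
        if T % s != 0:
            t_to_add = s - T % s
            pad = torch.zeros([B, t_to_add, D], device=encoder_outputs.device, dtype=encoder_outputs.dtype)
            encoder_outputs = torch.cat([encoder_outputs, pad], dim=1)
            T = T + t_to_add
        return torch.reshape(encoder_outputs, [B, T // s, D * s])

    x = torch.randn(2, 101, 512)
    print(stack_frames(x).shape)  # torch.Size([2, 13, 4096])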
10 changes: 9 additions & 1 deletion nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -402,11 +402,19 @@ def forward(
 
     @torch.no_grad()
     def _greedy_decode(
-        self, x: torch.Tensor, out_len: torch.Tensor, partial_hypotheses: Optional[rnnt_utils.Hypothesis] = None
+        self, x: torch.Tensor, out_len: torch.Tensor, subsampling_factor: int = 8, partial_hypotheses: Optional[rnnt_utils.Hypothesis] = None
     ):
         # x: [T, 1, D]
         # out_len: [seq_len]
 
+        T, _, D = x.shape
+        if T % subsampling_factor != 0:
+            t_to_add = subsampling_factor - T % subsampling_factor
+            x = torch.cat([x, torch.zeros([t_to_add, 1, D]).to(x.device)], dim=0)
+
+        x = torch.reshape(x, [-1, 1, subsampling_factor * D])
+        out_len = x.shape[0]
+
         # Initialize blank state and empty label set in Hypothesis
         hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
 
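A small sanity check (illustrative, not from the commit) that the decode-side pad-and-reshape produces ceil(T / subsampling_factor) time steps, consistent with the encoded_len = -(-encoded_len // 8) adjustment in rnnt_models.py:

    import torch

    T, D, s = 101, 512, 8      # an odd frame count forces padding
    x = torch.randn(T, 1, D)   # [T, 1, D], as received by _greedy_decode
    if T % s != 0:
        x = torch.cat([x, torch.zeros([s - T % s, 1, D])], dim=0)
    x = torch.reshape(x, [-1, 1, s * D])
    assert x.shape[0] == -(-T // s)  # 13 == ceil(101 / 8)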
30 changes: 15 additions & 15 deletions nemo/collections/nlp/models/__init__.py
@@ -25,18 +25,18 @@
     IntentSlotClassificationModel,
     MultiLabelIntentSlotClassificationModel,
 )
-from nemo.collections.nlp.models.language_modeling import MegatronGPTPromptLearningModel
-from nemo.collections.nlp.models.language_modeling.bert_lm_model import BERTLMModel
-from nemo.collections.nlp.models.language_modeling.transformer_lm_model import TransformerLMModel
-from nemo.collections.nlp.models.machine_translation import MTEncDecModel
-from nemo.collections.nlp.models.question_answering.qa_model import QAModel
-from nemo.collections.nlp.models.spellchecking_asr_customization import SpellcheckingAsrCustomizationModel
-from nemo.collections.nlp.models.text2sparql.text2sparql_model import Text2SparqlModel
-from nemo.collections.nlp.models.text_classification import TextClassificationModel
-from nemo.collections.nlp.models.text_normalization_as_tagging import ThutmoseTaggerModel
-from nemo.collections.nlp.models.token_classification import (
-    PunctuationCapitalizationLexicalAudioModel,
-    PunctuationCapitalizationModel,
-    TokenClassificationModel,
-)
-from nemo.collections.nlp.models.zero_shot_intent_recognition import ZeroShotIntentModel
+#from nemo.collections.nlp.models.language_modeling import MegatronGPTPromptLearningModel
+#from nemo.collections.nlp.models.language_modeling.bert_lm_model import BERTLMModel
+#from nemo.collections.nlp.models.language_modeling.transformer_lm_model import TransformerLMModel
+#from nemo.collections.nlp.models.machine_translation import MTEncDecModel
+#from nemo.collections.nlp.models.question_answering.qa_model import QAModel
+#from nemo.collections.nlp.models.spellchecking_asr_customization import SpellcheckingAsrCustomizationModel
+#from nemo.collections.nlp.models.text2sparql.text2sparql_model import Text2SparqlModel
+#from nemo.collections.nlp.models.text_classification import TextClassificationModel
+#from nemo.collections.nlp.models.text_normalization_as_tagging import ThutmoseTaggerModel
+#from nemo.collections.nlp.models.token_classification import (
+#    PunctuationCapitalizationLexicalAudioModel,
+#    PunctuationCapitalizationModel,
+#    TokenClassificationModel,
+#)
+#from nemo.collections.nlp.models.zero_shot_intent_recognition import ZeroShotIntentModel
