fix att_context_size bug for older models. (#6635)

Signed-off-by: Vahid <[email protected]>
VahidooX committed May 11, 2023
1 parent facd003 commit 50f895a
Showing 2 changed files with 13 additions and 6 deletions.
11 changes: 7 additions & 4 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -20,7 +20,7 @@
 import torch
 import torch.distributed
 import torch.nn as nn
-from omegaconf import DictConfig, ListConfig
+from omegaconf import DictConfig, ListConfig, open_dict
 
 from nemo.collections.asr.models.configs import CacheAwareStreamingConfig
 from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder
@@ -884,8 +884,10 @@ def change_attention_model(
 
         if att_context_size:
             att_context_size = list(att_context_size)
-        else:
+        elif hasattr(self._cfg, "att_context_size"):
             att_context_size = self._cfg.att_context_size
+        else:
+            att_context_size = self.att_context_size
 
         if self_attention_model is None:
             self_attention_model = self._cfg.self_attention_model
@@ -971,8 +973,9 @@ def change_attention_model(
                m.self_attention_model = self_attention_model
 
        if update_config:
-            self._cfg.self_attention_model = self_attention_model
-            self._cfg.att_context_size = att_context_size
+            with open_dict(self._cfg):
+                self._cfg.self_attention_model = self_attention_model
+                self._cfg.att_context_size = att_context_size
 
 
 class ConformerEncoderAdapter(ConformerEncoder, adapter_mixins.AdapterModuleMixin):
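Note for context (not part of the commit): OmegaConf configs attached to NeMo models are normally in struct mode, which rejects assignments to keys that do not already exist. Older checkpoints were saved before att_context_size was part of the encoder config, so reading it from self._cfg or writing it back directly could fail; the new elif hasattr(self._cfg, "att_context_size") branch falls back to the encoder's current att_context_size attribute when the key is missing, and open_dict temporarily lifts the struct restriction when updating the config. A minimal standalone sketch of the open_dict behavior (OmegaConf only; the key names are illustrative):

from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"self_attention_model": "rel_pos"})
OmegaConf.set_struct(cfg, True)  # emulate a struct-mode model config

# cfg.att_context_size = [128, 128]   # would fail: key not present in a struct config

with open_dict(cfg):
    # inside open_dict, new keys can be added even to a struct config
    cfg.att_context_size = [128, 128]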
8 changes: 6 additions & 2 deletions nemo/collections/asr/parts/mixins/mixins.py
@@ -412,6 +412,9 @@ def change_attention_model(
             update_config (bool): Whether to update the config or not with the new attention model.
                 Defaults to True.
         """
+        if self_attention_model is None and att_context_size is None:
+            return
+
         if not hasattr(self, 'encoder'):
             logging.info(
                 "Could not change the self_attention_model in encoder "
@@ -425,8 +428,9 @@
 
         self.encoder.change_attention_model(self_attention_model, att_context_size, update_config, self.device)
         if update_config:
-            self.cfg.encoder.self_attention_model = self_attention_model
-            self.cfg.encoder.att_context_size = att_context_size
+            with open_dict(self.cfg):
+                self.cfg.encoder.self_attention_model = self_attention_model
+                self.cfg.encoder.att_context_size = att_context_size
 
     def conformer_stream_step(
         self,
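Rough usage sketch (not from the commit; the checkpoint name and argument values are illustrative assumptions): with this fix, change_attention_model works on models restored from older checkpoints whose config has no att_context_size entry, and calling it with both arguments left as None is now an explicit no-op.

import nemo.collections.asr as nemo_asr

# illustrative checkpoint; any Conformer-based ASR model would do
model = nemo_asr.models.ASRModel.from_pretrained("stt_en_conformer_ctc_large")

# switch to limited-context local attention; update_config=True writes the
# new values back into model.cfg, now guarded by open_dict
model.change_attention_model(
    self_attention_model="rel_pos_local_attn",
    att_context_size=[128, 128],
    update_config=True,
)

# with both arguments None, the patched mixin returns without changing anything
model.change_attention_model()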
