Skip to content

Commit

Permalink
Fix py3.11 dataclasses issue (#7582)
Browse files Browse the repository at this point in the history
* Update ASR configs to support Python 3.11

Signed-off-by: smajumdar <[email protected]>

* Update TTS configs to support Python 3.11

Signed-off-by: smajumdar <[email protected]>

* Guard MeCab and Ipadic

Signed-off-by: smajumdar <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix remaining ASR dataclasses

Signed-off-by: smajumdar <[email protected]>

* Fix remaining ASR dataclasses

Signed-off-by: smajumdar <[email protected]>

* Fix scripts

Signed-off-by: smajumdar <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and web-flow committed Oct 3, 2023
1 parent 0a4fba3 commit 6af30c9
Show file tree
Hide file tree
Showing 34 changed files with 352 additions and 235 deletions.
10 changes: 6 additions & 4 deletions examples/asr/experimental/k2/align_speech_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@


import os
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass, field, is_dataclass
from typing import Optional

import pytorch_lightning as ptl
Expand All @@ -94,12 +94,14 @@
@dataclass
class ParallelAlignmentConfig:
model: Optional[str] = None # name
predict_ds: ASRDatasetConfig = ASRDatasetConfig(return_sample_id=True, num_workers=4)
aligner_args: K2AlignerWrapperModelConfig = K2AlignerWrapperModelConfig()
predict_ds: ASRDatasetConfig = field(
default_factory=lambda: ASRDatasetConfig(return_sample_id=True, num_workers=4)
)
aligner_args: K2AlignerWrapperModelConfig = field(default_factory=lambda: K2AlignerWrapperModelConfig())
output_path: str = MISSING
model_stride: int = 8

trainer: TrainerConfig = TrainerConfig(gpus=-1, accelerator="ddp")
trainer: TrainerConfig = field(default_factory=lambda: TrainerConfig(gpus=-1, accelerator="ddp"))

# these arguments will be ignored
return_predictions: bool = False
Expand Down
8 changes: 4 additions & 4 deletions nemo/collections/asr/metrics/rnnt_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import copy
import re
from abc import abstractmethod
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass, field, is_dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

import editdistance
Expand Down Expand Up @@ -1299,7 +1299,7 @@ class RNNTDecodingConfig:
preserve_alignments: Optional[bool] = None

# confidence config
confidence_cfg: ConfidenceConfig = ConfidenceConfig()
confidence_cfg: ConfidenceConfig = field(default_factory=lambda: ConfidenceConfig())

# RNNT Joint fused batch size
fused_batch_size: Optional[int] = None
Expand All @@ -1317,10 +1317,10 @@ class RNNTDecodingConfig:
rnnt_timestamp_type: str = "all" # can be char, word or all for both

# greedy decoding config
greedy: greedy_decode.GreedyRNNTInferConfig = greedy_decode.GreedyRNNTInferConfig()
greedy: greedy_decode.GreedyRNNTInferConfig = field(default_factory=lambda: greedy_decode.GreedyRNNTInferConfig())

# beam decoding config
beam: beam_decode.BeamRNNTInferConfig = beam_decode.BeamRNNTInferConfig(beam_size=4)
beam: beam_decode.BeamRNNTInferConfig = field(default_factory=lambda: beam_decode.BeamRNNTInferConfig(beam_size=4))

# can be used to change temperature for decoding
temperature: float = 1.0
12 changes: 8 additions & 4 deletions nemo/collections/asr/metrics/wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import re
from abc import abstractmethod
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass, field, is_dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

import editdistance
Expand Down Expand Up @@ -1297,13 +1297,17 @@ class CTCDecodingConfig:
batch_dim_index: int = 0

# greedy decoding config
greedy: ctc_greedy_decoding.GreedyCTCInferConfig = ctc_greedy_decoding.GreedyCTCInferConfig()
greedy: ctc_greedy_decoding.GreedyCTCInferConfig = field(
default_factory=lambda: ctc_greedy_decoding.GreedyCTCInferConfig()
)

# beam decoding config
beam: ctc_beam_decoding.BeamCTCInferConfig = ctc_beam_decoding.BeamCTCInferConfig(beam_size=4)
beam: ctc_beam_decoding.BeamCTCInferConfig = field(
default_factory=lambda: ctc_beam_decoding.BeamCTCInferConfig(beam_size=4)
)

# confidence config
confidence_cfg: ConfidenceConfig = ConfidenceConfig()
confidence_cfg: ConfidenceConfig = field(default_factory=lambda: ConfidenceConfig())

# can be used to change temperature for decoding
temperature: float = 1.0
8 changes: 4 additions & 4 deletions nemo/collections/asr/models/configs/aligner_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field

from nemo.collections.asr.parts.k2.classes import GraphModuleConfig

Expand All @@ -35,10 +35,10 @@ class AlignerWrapperModelConfig:
word_output: bool = True
cpu_decoding: bool = False
decode_batch_size: int = 0
ctc_cfg: AlignerCTCConfig = AlignerCTCConfig()
rnnt_cfg: AlignerRNNTConfig = AlignerRNNTConfig()
ctc_cfg: AlignerCTCConfig = field(default_factory=lambda: AlignerCTCConfig())
rnnt_cfg: AlignerRNNTConfig = field(default_factory=lambda: AlignerRNNTConfig())


@dataclass
class K2AlignerWrapperModelConfig(AlignerWrapperModelConfig):
decoder_module_cfg: GraphModuleConfig = GraphModuleConfig()
decoder_module_cfg: GraphModuleConfig = field(default_factory=lambda: GraphModuleConfig())
30 changes: 19 additions & 11 deletions nemo/collections/asr/models/configs/asr_models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from omegaconf import MISSING
Expand Down Expand Up @@ -74,24 +74,32 @@ class EncDecCTCConfig(model_cfg.ModelConfig):
labels: List[str] = MISSING

# Dataset configs
train_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=True)
validation_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=False)
test_ds: ASRDatasetConfig = ASRDatasetConfig(manifest_filepath=None, shuffle=False)
train_ds: ASRDatasetConfig = field(default_factory=lambda: ASRDatasetConfig(manifest_filepath=None, shuffle=True))
validation_ds: ASRDatasetConfig = field(
default_factory=lambda: ASRDatasetConfig(manifest_filepath=None, shuffle=False)
)
test_ds: ASRDatasetConfig = field(default_factory=lambda: ASRDatasetConfig(manifest_filepath=None, shuffle=False))

# Optimizer / Scheduler config
optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
optim: Optional[model_cfg.OptimConfig] = field(
default_factory=lambda: model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
)

# Model component configs
preprocessor: AudioToMelSpectrogramPreprocessorConfig = AudioToMelSpectrogramPreprocessorConfig()
spec_augment: Optional[SpectrogramAugmentationConfig] = SpectrogramAugmentationConfig()
encoder: ConvASREncoderConfig = ConvASREncoderConfig()
decoder: ConvASRDecoderConfig = ConvASRDecoderConfig()
decoding: CTCDecodingConfig = CTCDecodingConfig()
preprocessor: AudioToMelSpectrogramPreprocessorConfig = field(
default_factory=lambda: AudioToMelSpectrogramPreprocessorConfig()
)
spec_augment: Optional[SpectrogramAugmentationConfig] = field(
default_factory=lambda: SpectrogramAugmentationConfig()
)
encoder: ConvASREncoderConfig = field(default_factory=lambda: ConvASREncoderConfig())
decoder: ConvASRDecoderConfig = field(default_factory=lambda: ConvASRDecoderConfig())
decoding: CTCDecodingConfig = field(default_factory=lambda: CTCDecodingConfig())


@dataclass
class EncDecCTCModelConfig(model_cfg.NemoConfig):
model: EncDecCTCConfig = EncDecCTCConfig()
model: EncDecCTCConfig = field(default_factory=lambda: EncDecCTCConfig())


@dataclass
Expand Down
40 changes: 25 additions & 15 deletions nemo/collections/asr/models/configs/classification_models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from omegaconf import MISSING
Expand Down Expand Up @@ -72,30 +72,40 @@ class EncDecClassificationConfig(model_cfg.ModelConfig):
timesteps: int = MISSING

# Dataset configs
train_ds: EncDecClassificationDatasetConfig = EncDecClassificationDatasetConfig(
manifest_filepath=None, shuffle=True, trim_silence=False
train_ds: EncDecClassificationDatasetConfig = field(
default_factory=lambda: EncDecClassificationDatasetConfig(
manifest_filepath=None, shuffle=True, trim_silence=False
)
)
validation_ds: EncDecClassificationDatasetConfig = EncDecClassificationDatasetConfig(
manifest_filepath=None, shuffle=False
validation_ds: EncDecClassificationDatasetConfig = field(
default_factory=lambda: EncDecClassificationDatasetConfig(manifest_filepath=None, shuffle=False)
)
test_ds: EncDecClassificationDatasetConfig = EncDecClassificationDatasetConfig(
manifest_filepath=None, shuffle=False
test_ds: EncDecClassificationDatasetConfig = field(
default_factory=lambda: EncDecClassificationDatasetConfig(manifest_filepath=None, shuffle=False)
)

# Optimizer / Scheduler config
optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
optim: Optional[model_cfg.OptimConfig] = field(
default_factory=lambda: model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
)

# Model component configs
preprocessor: AudioToMFCCPreprocessorConfig = AudioToMFCCPreprocessorConfig()
spec_augment: Optional[SpectrogramAugmentationConfig] = SpectrogramAugmentationConfig()
crop_or_pad_augment: Optional[CropOrPadSpectrogramAugmentationConfig] = CropOrPadSpectrogramAugmentationConfig(
audio_length=timesteps
preprocessor: AudioToMFCCPreprocessorConfig = field(default_factory=lambda: AudioToMFCCPreprocessorConfig())
spec_augment: Optional[SpectrogramAugmentationConfig] = field(
default_factory=lambda: SpectrogramAugmentationConfig()
)
crop_or_pad_augment: Optional[CropOrPadSpectrogramAugmentationConfig] = field(
default_factory=lambda: CropOrPadSpectrogramAugmentationConfig(audio_length=-1)
)

encoder: ConvASREncoderConfig = ConvASREncoderConfig()
decoder: ConvASRDecoderClassificationConfig = ConvASRDecoderClassificationConfig()
encoder: ConvASREncoderConfig = field(default_factory=lambda: ConvASREncoderConfig())
decoder: ConvASRDecoderClassificationConfig = field(default_factory=lambda: ConvASRDecoderClassificationConfig())

def __post_init__(self):
if self.crop_or_pad_augment is not None:
self.crop_or_pad_augment.audio_length = self.timesteps


@dataclass
class EncDecClassificationModelConfig(model_cfg.NemoConfig):
model: EncDecClassificationConfig = EncDecClassificationConfig()
model: EncDecClassificationConfig = field(default_factory=lambda: EncDecClassificationConfig())
28 changes: 14 additions & 14 deletions nemo/collections/asr/models/configs/diarizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional, Tuple, Union


Expand Down Expand Up @@ -78,9 +78,9 @@ class ASRDiarizerParams(DiarizerComponentConfig):
@dataclass
class ASRDiarizerConfig(DiarizerComponentConfig):
model_path: Optional[str] = "stt_en_conformer_ctc_large"
parameters: ASRDiarizerParams = ASRDiarizerParams()
ctc_decoder_parameters: ASRDiarizerCTCDecoderParams = ASRDiarizerCTCDecoderParams()
realigning_lm_parameters: ASRRealigningLMParams = ASRRealigningLMParams()
parameters: ASRDiarizerParams = field(default_factory=lambda: ASRDiarizerParams())
ctc_decoder_parameters: ASRDiarizerCTCDecoderParams = field(default_factory=lambda: ASRDiarizerCTCDecoderParams())
realigning_lm_parameters: ASRRealigningLMParams = field(default_factory=lambda: ASRRealigningLMParams())


@dataclass
Expand All @@ -102,7 +102,7 @@ class VADParams(DiarizerComponentConfig):
class VADConfig(DiarizerComponentConfig):
model_path: str = "vad_multilingual_marblenet" # .nemo local model path or pretrained VAD model name
external_vad_manifest: Optional[str] = None
parameters: VADParams = VADParams()
parameters: VADParams = field(default_factory=lambda: VADParams())


@dataclass
Expand All @@ -121,7 +121,7 @@ class SpeakerEmbeddingsParams(DiarizerComponentConfig):
class SpeakerEmbeddingsConfig(DiarizerComponentConfig):
# .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet)
model_path: Optional[str] = None
parameters: SpeakerEmbeddingsParams = SpeakerEmbeddingsParams()
parameters: SpeakerEmbeddingsParams = field(default_factory=lambda: SpeakerEmbeddingsParams())


@dataclass
Expand All @@ -142,7 +142,7 @@ class ClusteringParams(DiarizerComponentConfig):

@dataclass
class ClusteringConfig(DiarizerComponentConfig):
parameters: ClusteringParams = ClusteringParams()
parameters: ClusteringParams = field(default_factory=lambda: ClusteringParams())


@dataclass
Expand All @@ -166,7 +166,7 @@ class MSDDParams(DiarizerComponentConfig):
@dataclass
class MSDDConfig(DiarizerComponentConfig):
model_path: Optional[str] = "diar_msdd_telephonic"
parameters: MSDDParams = MSDDParams()
parameters: MSDDParams = field(default_factory=lambda: MSDDParams())


@dataclass
Expand All @@ -176,16 +176,16 @@ class DiarizerConfig(DiarizerComponentConfig):
oracle_vad: bool = False # If True, uses RTTM files provided in the manifest file to get VAD timestamps
collar: float = 0.25 # Collar value for scoring
ignore_overlap: bool = True # Consider or ignore overlap segments while scoring
vad: VADConfig = VADConfig()
speaker_embeddings: SpeakerEmbeddingsConfig = SpeakerEmbeddingsConfig()
clustering: ClusteringConfig = ClusteringConfig()
msdd_model: MSDDConfig = MSDDConfig()
asr: ASRDiarizerConfig = ASRDiarizerConfig()
vad: VADConfig = field(default_factory=lambda: VADConfig())
speaker_embeddings: SpeakerEmbeddingsConfig = field(default_factory=lambda: SpeakerEmbeddingsConfig())
clustering: ClusteringConfig = field(default_factory=lambda: ClusteringConfig())
msdd_model: MSDDConfig = field(default_factory=lambda: MSDDConfig())
asr: ASRDiarizerConfig = field(default_factory=lambda: ASRDiarizerConfig())


@dataclass
class NeuralDiarizerInferenceConfig(DiarizerComponentConfig):
diarizer: DiarizerConfig = DiarizerConfig()
diarizer: DiarizerConfig = field(default_factory=lambda: DiarizerConfig())
device: str = "cpu"
verbose: bool = False
batch_size: int = 64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field

from nemo.collections.asr.models.configs.asr_models_config import EncDecCTCConfig
from nemo.collections.asr.parts.k2.classes import GraphModuleConfig as BackendConfig
Expand All @@ -26,14 +26,14 @@ class GraphModuleConfig:
split_batch_size: int = 0
dec_type: str = "topo"
transcribe_training: bool = True
backend_cfg: BackendConfig = BackendConfig()
backend_cfg: BackendConfig = field(default_factory=lambda: BackendConfig())


@dataclass
class EncDecK2SeqConfig(EncDecCTCConfig):
graph_module_cfg: GraphModuleConfig = GraphModuleConfig()
graph_module_cfg: GraphModuleConfig = field(default_factory=lambda: GraphModuleConfig())


@dataclass
class EncDecK2SeqModelConfig(NemoConfig):
model: EncDecK2SeqConfig = EncDecK2SeqConfig()
model: EncDecK2SeqConfig = field(default_factory=lambda: EncDecK2SeqConfig())
36 changes: 22 additions & 14 deletions nemo/collections/asr/models/configs/matchboxnet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,30 +107,38 @@ class MatchboxNetModelConfig(clf_cfg.EncDecClassificationConfig):
labels: List[str] = MISSING

# Dataset configs
train_ds: clf_cfg.EncDecClassificationDatasetConfig = clf_cfg.EncDecClassificationDatasetConfig(
manifest_filepath=None, shuffle=True, trim_silence=False
train_ds: clf_cfg.EncDecClassificationDatasetConfig = field(
default_factory=lambda: clf_cfg.EncDecClassificationDatasetConfig(
manifest_filepath=None, shuffle=True, trim_silence=False
)
)
validation_ds: clf_cfg.EncDecClassificationDatasetConfig = clf_cfg.EncDecClassificationDatasetConfig(
manifest_filepath=None, shuffle=False
validation_ds: clf_cfg.EncDecClassificationDatasetConfig = field(
default_factory=lambda: clf_cfg.EncDecClassificationDatasetConfig(manifest_filepath=None, shuffle=False)
)
test_ds: clf_cfg.EncDecClassificationDatasetConfig = clf_cfg.EncDecClassificationDatasetConfig(
manifest_filepath=None, shuffle=False
test_ds: clf_cfg.EncDecClassificationDatasetConfig = field(
default_factory=lambda: clf_cfg.EncDecClassificationDatasetConfig(manifest_filepath=None, shuffle=False)
)

# Optimizer / Scheduler config
optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
optim: Optional[model_cfg.OptimConfig] = field(
default_factory=lambda: model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
)

# Model general component configs
preprocessor: AudioToMFCCPreprocessorConfig = AudioToMFCCPreprocessorConfig(window_size=0.025)
spec_augment: Optional[SpectrogramAugmentationConfig] = SpectrogramAugmentationConfig(
freq_masks=2, time_masks=2, freq_width=15, time_width=25, rect_masks=5, rect_time=25, rect_freq=15
preprocessor: AudioToMFCCPreprocessorConfig = field(
default_factory=lambda: AudioToMFCCPreprocessorConfig(window_size=0.025)
)
spec_augment: Optional[SpectrogramAugmentationConfig] = field(
default_factory=lambda: SpectrogramAugmentationConfig(
freq_masks=2, time_masks=2, freq_width=15, time_width=25, rect_masks=5, rect_time=25, rect_freq=15
)
)
crop_or_pad_augment: Optional[CropOrPadSpectrogramAugmentationConfig] = CropOrPadSpectrogramAugmentationConfig(
audio_length=128
crop_or_pad_augment: Optional[CropOrPadSpectrogramAugmentationConfig] = field(
default_factory=lambda: CropOrPadSpectrogramAugmentationConfig(audio_length=128)
)

encoder: ConvASREncoderConfig = ConvASREncoderConfig(activation="relu")
decoder: ConvASRDecoderClassificationConfig = ConvASRDecoderClassificationConfig()
encoder: ConvASREncoderConfig = field(default_factory=lambda: ConvASREncoderConfig(activation="relu"))
decoder: ConvASRDecoderClassificationConfig = field(default_factory=lambda: ConvASRDecoderClassificationConfig())


@dataclass
Expand Down
Loading

0 comments on commit 6af30c9

Please sign in to comment.