Skip to content

Commit bb26e98

Browse files
authored
Enable CUDA graphs by default only for transcription (#9196)
* Enable CUDA graphs only for transcription. Sync streams before capture.

Signed-off-by: Vladimir Bataev <[email protected]>

* Apply isort and black reformatting

Signed-off-by: artbataev <[email protected]>

---------

Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: artbataev <[email protected]>
Co-authored-by: artbataev <[email protected]>
1 parent da34ee7 commit bb26e98

File tree

6 files changed

+266
-138
lines changed

6 files changed

+266
-138
lines changed

examples/asr/transcribe_speech.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
3030
from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig
3131
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
32+
from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig
3233
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
3334
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
3435
from nemo.collections.asr.parts.utils.transcribe_utils import (
@@ -121,9 +122,9 @@ class TranscriptionConfig:
121122
pretrained_name: Optional[str] = None # Name of a pretrained model
122123
audio_dir: Optional[str] = None # Path to a directory which contains audio files
123124
dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest
124-
channel_selector: Optional[
125-
Union[int, str]
126-
] = None # Used to select a single channel from multichannel audio, or use average across channels
125+
channel_selector: Optional[Union[int, str]] = (
126+
None # Used to select a single channel from multichannel audio, or use average across channels
127+
)
127128
audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest
128129
eval_config_yaml: Optional[str] = None # Path to a yaml file of config of evaluation
129130
presort_manifest: bool = True # Significant inference speedup on short-form data due to padding reduction
@@ -161,7 +162,10 @@ class TranscriptionConfig:
161162
ctc_decoding: CTCDecodingConfig = CTCDecodingConfig()
162163

163164
# Decoding strategy for RNNT models
164-
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)
165+
# enable CUDA graphs for transcription
166+
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(
167+
fused_batch_size=-1, greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True)
168+
)
165169

166170
# Decoding strategy for AED models
167171
multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig()
@@ -407,7 +411,10 @@ def autocast(dtype=None):
407411
override_cfg.augmentor = augmentor
408412
override_cfg.text_field = cfg.gt_text_attr_name
409413
override_cfg.lang_field = cfg.gt_lang_attr_name
410-
transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)
414+
transcriptions = asr_model.transcribe(
415+
audio=filepaths,
416+
override_config=override_cfg,
417+
)
411418

412419
if cfg.dataset_manifest is not None:
413420
logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}")

examples/asr/transcribe_speech_parallel.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel
8585
from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig
8686
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
87+
from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig
8788
from nemo.core.config import TrainerConfig, hydra_runner
8889
from nemo.utils import logging
8990
from nemo.utils.get_rank import is_global_rank_zero
@@ -100,7 +101,10 @@ class ParallelTranscriptionConfig:
100101
use_cer: bool = False
101102

102103
# decoding strategy for RNNT models
103-
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig()
104+
# enable CUDA graphs for transcription
105+
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(
106+
fused_batch_size=-1, greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True)
107+
)
104108

105109
# decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models
106110
decoder_type: Optional[str] = None

nemo/collections/asr/parts/submodules/rnnt_decoding.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
331331
preserve_frame_confidence=self.preserve_frame_confidence,
332332
confidence_method_cfg=self.confidence_method_cfg,
333333
loop_labels=self.cfg.greedy.get('loop_labels', True),
334-
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
334+
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
335335
)
336336
else:
337337
self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer(
@@ -347,7 +347,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
347347
preserve_frame_confidence=self.preserve_frame_confidence,
348348
include_duration_confidence=self.tdt_include_duration_confidence,
349349
confidence_method_cfg=self.confidence_method_cfg,
350-
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
350+
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
351351
)
352352

353353
else:
@@ -1175,7 +1175,11 @@ class RNNTDecoding(AbstractRNNTDecoding):
11751175
"""
11761176

11771177
def __init__(
1178-
self, decoding_cfg, decoder, joint, vocabulary,
1178+
self,
1179+
decoding_cfg,
1180+
decoder,
1181+
joint,
1182+
vocabulary,
11791183
):
11801184
# we need to ensure blank is the last token in the vocab for the case of RNNT and Multi-blank RNNT.
11811185
blank_id = len(vocabulary) + joint.num_extra_outputs
@@ -1186,7 +1190,10 @@ def __init__(
11861190
self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))])
11871191

11881192
super(RNNTDecoding, self).__init__(
1189-
decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id,
1193+
decoding_cfg=decoding_cfg,
1194+
decoder=decoder,
1195+
joint=joint,
1196+
blank_id=blank_id,
11901197
)
11911198

11921199
if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer):

nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@
4545
from nemo.utils import logging
4646

4747

48-
def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor,) -> List[rnnt_utils.Hypothesis]:
48+
def pack_hypotheses(
49+
hypotheses: List[rnnt_utils.Hypothesis],
50+
logitlen: torch.Tensor,
51+
) -> List[rnnt_utils.Hypothesis]:
4952

5053
if hasattr(logitlen, 'cpu'):
5154
logitlen_cpu = logitlen.to('cpu')
@@ -139,8 +142,7 @@ class _GreedyRNNTInfer(Typing, ConfidenceMethodMixin):
139142

140143
@property
141144
def input_types(self):
142-
"""Returns definitions of module input ports.
143-
"""
145+
"""Returns definitions of module input ports."""
144146
return {
145147
"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
146148
"encoded_lengths": NeuralType(tuple('B'), LengthsType()),
@@ -149,8 +151,7 @@ def input_types(self):
149151

150152
@property
151153
def output_types(self):
152-
"""Returns definitions of module output ports.
153-
"""
154+
"""Returns definitions of module output ports."""
154155
return {"predictions": [NeuralType(elements_type=HypothesisType())]}
155156

156157
def __init__(
@@ -578,6 +579,7 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs):
578579
(evaluating Joint multiple times in inner loop); It uses a minimal possible amount of calls
579580
to prediction network (with maximum possible batch size),
580581
which makes it especially useful for scaling the prediction network.
582+
use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding (currently recommended only for inference)
581583
"""
582584

583585
def __init__(
@@ -590,7 +592,7 @@ def __init__(
590592
preserve_frame_confidence: bool = False,
591593
confidence_method_cfg: Optional[DictConfig] = None,
592594
loop_labels: bool = True,
593-
use_cuda_graph_decoder: bool = True,
595+
use_cuda_graph_decoder: bool = False,
594596
):
595597
super().__init__(
596598
decoder_model=decoder_model,
@@ -2358,7 +2360,7 @@ class GreedyBatchedRNNTInferConfig:
23582360
tdt_include_duration_confidence: bool = False
23592361
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
23602362
loop_labels: bool = True
2361-
use_cuda_graph_decoder: bool = True
2363+
use_cuda_graph_decoder: bool = False
23622364

23632365
def __post_init__(self):
23642366
# OmegaConf.structured ensures that post_init check is always executed
@@ -2695,6 +2697,8 @@ class GreedyBatchedTDTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs):
26952697
Supported values:
26962698
- 'lin' for using the linear mapping.
26972699
- 'exp' for using exponential mapping with linear shift.
2700+
2701+
use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding (currently recommended only for inference)
26982702
"""
26992703

27002704
def __init__(
@@ -2708,7 +2712,7 @@ def __init__(
27082712
preserve_frame_confidence: bool = False,
27092713
include_duration_confidence: bool = False,
27102714
confidence_method_cfg: Optional[DictConfig] = None,
2711-
use_cuda_graph_decoder: bool = True,
2715+
use_cuda_graph_decoder: bool = False,
27122716
):
27132717
super().__init__(
27142718
decoder_model=decoder_model,

0 commit comments

Comments (0)