from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTLossPytorch
from nemo.core.classes import Loss, typecheck
from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType
+from nemo.core.utils import numba_utils
from nemo.core.utils.k2_utils import K2_INSTALLATION_MESSAGE
from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE
-from nemo.utils import logging, model_utils
+from nemo.utils import logging, logging_mode, model_utils

try:
    import warprnnt_pytorch as warprnnt
@@ -98,7 +99,7 @@ class RNNTLossConfig:
        min_version='0.53.0',
        is_available=NUMBA_RNNT_AVAILABLE,
        installation_msg=NUMBA_INSTALLATION_MESSAGE,
-        force_float32=True,
+        force_float32=not numba_utils.NUMBA_FP16_SUPPORTED,
    ),
    "pytorch": RNNTLossConfig(
        loss_name="pytorch",
@@ -387,7 +388,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
            for the standard "blank" symbol. In particular, say V is the number of non-blank tokens in
            the vocabulary, then in the case of,
            standard RNNT: num_classes = V
-            multiblank RNNT: num_classes = V + number-big-blanks (since we store big-blanks before
+            multiblank RNNT: num_classes = V + number-big-blanks (since we store big-blanks before
            standard blank, and the standard blank is the last symbol in the vocab)
            TDT: num_classes = V. Note, V here does not include any of the "duration outputs".
@@ -413,6 +414,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
        self.reduction = reduction
        self._loss = resolve_rnnt_loss(loss_name, blank_idx=self._blank, loss_kwargs=loss_kwargs)
        self._force_float32 = RNNT_LOSS_RESOLVER[loss_name].force_float32
+        self._fp16_compat_checked = False

    def reduce(self, losses, target_lengths):
@@ -442,8 +444,22 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
        max_targets_len = target_lengths.max()

        # Force cast joint to float32
-        # TODO: Remove once Numba supports FP16
-        if self._force_float32 and log_probs.dtype != torch.float32:
+        if not self._force_float32 and numba_utils.NUMBA_FP16_SUPPORTED:
+            # Execute the kernel in fp16
+            pass
+        elif self._force_float32 and log_probs.dtype != torch.float32:
+            # Log just once if an fp16 tensor was passed but the fp16 Numba CUDA loss could not be used.
+            if log_probs.dtype == torch.float16 and not self._fp16_compat_checked:
+                _, reason = numba_utils.is_numba_cuda_fp16_supported(return_reason=True)
+                logging.warning(
+                    f"Provided RNNT Joint tensor is of dtype {log_probs.dtype}, but RNNT loss could not be calculated "
+                    f"in fp16 due to the reason stated below. Loss will be calculated in fp32. \n\n"
+                    f"{reason}",
+                    mode=logging_mode.ONCE,
+                )
+                self._fp16_compat_checked = True
+
+            # Upcast the activation tensor and compute loss and grads in fp32
            logits_orig = log_probs
            log_probs = log_probs.float()
            del logits_orig  # save memory *before* computing the loss
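For reference, a minimal usage sketch of the changed forward path. The shapes, values, and the "warprnnt_numba" loss name below are illustrative assumptions, not part of this diff: when Numba FP16 support is detected, the fp16 joint tensor is consumed directly; otherwise it is upcast to fp32 and a single warning explains why.

# Hypothetical usage sketch (not part of this patch).
import torch
from nemo.collections.asr.losses.rnnt import RNNTLoss

B, T, U, V = 2, 50, 10, 28  # batch, acoustic frames, target length + 1, non-blank vocab size (made-up values)
loss_fn = RNNTLoss(num_classes=V, loss_name="warprnnt_numba")  # assumes the Numba-backed loss is registered under this name

# Joint tensor in fp16; last dim is V + 1 since the standard blank is the final symbol in the vocab.
joint = torch.randn(B, T, U, V + 1, dtype=torch.float16, device="cuda")
targets = torch.randint(0, V, (B, U - 1), dtype=torch.int32, device="cuda")
input_lengths = torch.full((B,), T, dtype=torch.int32, device="cuda")
target_lengths = torch.full((B,), U - 1, dtype=torch.int32, device="cuda")

# With NUMBA_FP16_SUPPORTED the loss is computed in fp16; otherwise the joint is
# upcast to fp32 and a one-time warning (logging_mode.ONCE) is emitted.
loss = loss_fn(
    log_probs=joint, targets=targets, input_lengths=input_lengths, target_lengths=target_lengths
)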