@@ -99,7 +99,7 @@ class RNNTLossConfig:
         min_version='0.53.0',
         is_available=NUMBA_RNNT_AVAILABLE,
         installation_msg=NUMBA_INSTALLATION_MESSAGE,
-        force_float32=not numba_utils.NUMBA_FP16_SUPPORTED,
+        force_float32=False,  # This is only temporarily false, will be dynamically updated during resolution
     ),
     "pytorch": RNNTLossConfig(
         loss_name="pytorch",
@@ -258,6 +258,9 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)
         _warn_unused_additional_kwargs(loss_name, loss_kwargs)
 
     elif loss_name == 'warprnnt_numba':
+        # Dynamically update the loss config's force_float32 placeholder based on runtime Numba FP16 support
+        loss_config.force_float32 = not numba_utils.is_numba_cuda_fp16_supported()
+
         fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0)
         clamp = loss_kwargs.pop('clamp', -1.0)
         loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda, clamp=clamp)
@@ -444,7 +447,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
         max_targets_len = target_lengths.max()
 
         # Force cast joint to float32
-        if not self._force_float32 and numba_utils.NUMBA_FP16_SUPPORTED:
+        if not self._force_float32 and numba_utils.is_numba_cuda_fp16_supported():
            # Execute the kernel in fp16
            pass
         elif self._force_float32 and log_probs.dtype != torch.float32:
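
The substance of this change is moving the FP16-support decision from import time (the NUMBA_FP16_SUPPORTED constant) to resolution/call time via numba_utils.is_numba_cuda_fp16_supported(), so the flag reflects the CUDA runtime that is actually present when the loss is constructed. As a rough illustration of what such a runtime probe involves, here is a minimal, self-contained sketch; the helper name check_numba_cuda_fp16_supported and the (0, 57) version threshold are assumptions made for the example and are not NeMo's actual numba_utils implementation.

# Minimal sketch (not NeMo's actual numba_utils code) of a runtime FP16-support probe:
# FP16 kernels are only attempted when Numba is new enough AND a CUDA device is visible
# at call time. The helper name and the (0, 57) threshold are illustrative assumptions.
from numba import cuda
import numba


def check_numba_cuda_fp16_supported() -> bool:
    """Return True only if this Numba build and the current machine can run FP16 CUDA kernels."""
    # Parse "major.minor" from the version string; older Numba releases lack fp16 intrinsics.
    major, minor = (int(part) for part in numba.__version__.split('.')[:2])
    if (major, minor) < (0, 57):
        return False
    # cuda.is_available() returns False (rather than raising) on CPU-only hosts.
    return cuda.is_available()


# Resolution-time usage mirroring the change above: force float32 only when FP16 is unsupported.
# loss_config.force_float32 = not check_numba_cuda_fp16_supported()

Deferring the check to resolution time means the decision tracks the environment in which the loss is actually built, rather than whatever happened to be visible when the module was first imported.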
0 commit comments