diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py
index 3feb7b513a50..bcc1865ab78e 100644
--- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py
+++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py
@@ -203,10 +203,13 @@ def __init__(
         self.num_threads_ = num_threads
         self.batch_first = batch_first
 
+        _torch_num_threads = torch.get_num_threads()
         if num_threads > 0:
             numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
+            self.num_threads_ = numba.get_num_threads()
         else:
             self.num_threads_ = numba.get_num_threads()
+        torch.set_num_threads(_torch_num_threads)
 
     def cost_and_grad_kernel(
         self,
diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py
index 70ffb459cb97..87d6ee147dea 100644
--- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py
+++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py
@@ -82,11 +82,13 @@ def __init__(
         self.num_threads_ = num_threads
         self.stream_ = stream  # type: cuda.cudadrv.driver.Stream
 
+        _torch_num_threads = torch.get_num_threads()
         if num_threads > 0:
             numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
             self.num_threads_ = numba.get_num_threads()
         else:
             self.num_threads_ = numba.get_num_threads()
+        torch.set_num_threads(_torch_num_threads)
 
     def log_softmax(self, acts: torch.Tensor, denom: torch.Tensor):
         """
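
Note on the change: depending on the active threading layer, numba.set_num_threads can alter the process-wide thread settings that PyTorch's intra-op pool also consults, so both constructors now snapshot torch.get_num_threads() before configuring Numba and restore it afterward. Below is a minimal, self-contained sketch of that save/restore pattern, factored into a hypothetical preserve_torch_threads context manager; the patch itself inlines this logic directly in each __init__ rather than using a helper.

import multiprocessing
from contextlib import contextmanager

import numba
import torch


@contextmanager
def preserve_torch_threads():
    # Snapshot torch's intra-op thread count before touching Numba's
    # threading layer, and restore it on exit even if an error occurs.
    # (Hypothetical helper, not part of the patch.)
    saved = torch.get_num_threads()
    try:
        yield
    finally:
        torch.set_num_threads(saved)


def resolve_num_threads(requested: int) -> int:
    # Mirrors the patched logic: clamp an explicit request to the CPU
    # count, otherwise leave Numba's current setting alone; in both
    # cases report the thread count Numba actually ends up using.
    with preserve_torch_threads():
        if requested > 0:
            numba.set_num_threads(min(multiprocessing.cpu_count(), requested))
        return numba.get_num_threads()

The if-branch in cpu_rnnt.py now also reads back numba.get_num_threads() instead of trusting the requested value, matching what gpu_rnnt.py already did, so num_threads_ reflects the count Numba actually granted.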