Skip to content

Commit

Permalink
fix: numba.*_num_threads resets torch num_threads #8141 (#8145)
Browse files Browse the repository at this point in the history
temporary fix until numba/numba#9387 gets resolved.

Signed-off-by: Iztok Lebar Bajec <[email protected]>
  • Loading branch information
itzsimpl authored Jan 11, 2024
1 parent 24d4344 commit 03e7cf1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,13 @@ def __init__(
self.num_threads_ = num_threads
self.batch_first = batch_first

_torch_num_threads = torch.get_num_threads()
if num_threads > 0:
numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
self.num_threads_ = numba.get_num_threads()
else:
self.num_threads_ = numba.get_num_threads()
torch.set_num_threads(_torch_num_threads)

def cost_and_grad_kernel(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ def __init__(
self.num_threads_ = num_threads
self.stream_ = stream # type: cuda.cudadrv.driver.Stream

_torch_num_threads = torch.get_num_threads()
if num_threads > 0:
numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
self.num_threads_ = numba.get_num_threads()
else:
self.num_threads_ = numba.get_num_threads()
torch.set_num_threads(_torch_num_threads)

def log_softmax(self, acts: torch.Tensor, denom: torch.Tensor):
"""
Expand Down

0 comments on commit 03e7cf1

Please sign in to comment.