Skip to content

Commit 0212ba8

Browse files
committed
just embrace the numba and add it in the test...
1 parent 47a0beb commit 0212ba8

File tree

3 files changed

+6
-13
lines changed

3 files changed

+6
-13
lines changed

.github/workflows/ci.yml

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ jobs:
5353
run: |
5454
pip install pytest
5555
pip install torch torchvision
56+
pip install numba
5657
pytest tests
5758
5859
# - name: Install from dist

tests/equi2pers/numpy_inv.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55

66
from timeit import timeit
77

8-
try:
9-
from numba import jit
10-
except ImportError:
11-
print("numba not available")
12-
jit = None
8+
from numba import jit
139

1410
import numpy as np
1511

@@ -59,7 +55,7 @@ def hdinv(A):
5955
return invA
6056

6157

62-
# @jit("float64[:,:](float64[:,:])", cache=True, nopython=True, nogil=True)
58+
@jit("float64[:,:](float64[:,:])", cache=True, nopython=True, nogil=True)
6359
def fast_inverse(A):
6460
inv = np.empty_like(A)
6561
a = A[0, 0]
@@ -89,7 +85,7 @@ def fast_inverse(A):
8985
return inv
9086

9187

92-
# @jit(cache=True, nopython=True, nogil=True)
88+
@jit(cache=True, nopython=True, nogil=True)
9389
def vecinv(A):
9490
invA = np.zeros_like(A)
9591
for i in range(A.shape[0]):

tests/grid_sample/numpy/nearest.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44
55
"""
66

7-
try:
8-
from numba import njit
9-
except ImportError:
10-
print("numba not available")
11-
njit = None
7+
from numba import njit
128

139
import numpy as np
1410

@@ -80,7 +76,7 @@ def faster_nearest(
8076
return out
8177

8278

83-
# @njit
79+
@njit
8480
def run(img, grid, out, b, h, w):
8581
for i in range(b):
8682
for y_out in range(h):

0 commit comments

Comments
 (0)