Skip to content

Commit 507f2c7

Browse files
committed
Fixed a bug that made StaticVolume crash on cpu; updated test_devices to validate differences between devices
1 parent 234f97c commit 507f2c7

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

tests/test_devices.py

+31-9
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
import numpy as np
33
import cupy as cp
44
import matplotlib.pyplot as plt
5+
plt.ioff()
56

67
size = (50, 50, 50)
78
data = np.random.random(size).astype(np.float32)
9+
rotation = np.random.uniform(0, 180, 3)
10+
scale = np.random.uniform(0.7, 1.3, 3)
811

912
rows = [
1013
[
@@ -24,33 +27,52 @@
2427
[
2528
{
2629
'interpolation': 'linear',
27-
'device': 'gpu:0'
30+
'device': 'gpu'
2831
},
2932
{
3033
'interpolation': 'bspline',
31-
'device': 'gpu:0'
34+
'device': 'gpu'
3235
},
3336
{
3437
'interpolation': 'filt_bspline',
35-
'device': 'gpu:0'
38+
'device': 'gpu'
3639
},
3740
]
3841
]
3942

43+
### testing transforms methods
4044
fig, ax = plt.subplots(len(rows), len(rows[0]), sharex=True, sharey=True)
4145
for i, r in enumerate(rows):
4246
for j, case in enumerate(r):
4347
print(f'Test case: {case["interpolation"]} / {case["device"]}')
44-
tf = vt.transform(data, rotation=(0, 30, 0), interpolation=case['interpolation'],
48+
tf = vt.transform(data, rotation=rotation, scale=scale, interpolation=case['interpolation'],
4549
profile=True, device=case['device'])
4650

4751
ax[i][j].set_title(f'{case["interpolation"]} / {case["device"]}')
48-
49-
if isinstance(tf, cp.ndarray):
50-
ax[i][j].imshow(tf[size[0] // 2].get())
51-
else:
52-
ax[i][j].imshow(tf[size[0] // 2])
52+
ax[i][j].imshow(tf[size[0] // 2])
5353

5454
plt.show()
5555

5656

57+
### testing static volume methods
58+
print('\n\n\n')
59+
st_volumes = [
60+
vt.StaticVolume(data, interpolation='linear', device='cpu'),
61+
vt.StaticVolume(data, interpolation='bspline', device='cpu'),
62+
vt.StaticVolume(data, interpolation='filt_bspline', device='cpu'),
63+
vt.StaticVolume(data, interpolation='linear', device='gpu'),
64+
vt.StaticVolume(data, interpolation='bspline', device='gpu'),
65+
vt.StaticVolume(data, interpolation='filt_bspline', device='gpu')
66+
]
67+
68+
fig, ax = plt.subplots(2, 3, sharex=True, sharey=True)
69+
70+
for n, v in enumerate(st_volumes):
71+
print(f'Test case: {v.interpolation} / {v.device}')
72+
tf = v.transform(scale=scale, rotation=rotation, profile=True)
73+
i, j = int(n / 3), n % 3
74+
ax[i][j].set_title(f'{v.interpolation} / {v.device}')
75+
ax[i][j].imshow(tf[size[0] // 2])
76+
77+
plt.show()
78+

voltools/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '0.3.2'
1+
__version__ = '0.3.3'
22

33
from .transforms import AVAILABLE_INTERPOLATIONS, AVAILABLE_DEVICES, scale, shear, rotate, translate, transform, affine
44
from .volume import StaticVolume

voltools/volume.py

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self, data: np.ndarray, interpolation: str = 'linear', device: str
5252
del data
5353

5454
elif device == 'cpu':
55+
self.shape = data.shape
5556
self.data = data
5657

5758
def affine(self, transform_m: np.ndarray, profile: bool = False, output: cp.ndarray = None) -> Union[np.ndarray, None]:

0 commit comments

Comments
 (0)