Skip to content

Commit 574c047

Browse files
committed
discretization added to trainer class
1 parent d831779 commit 574c047

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

bayesmedaug/augmentations/functional.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ class ShiftY():
392392
* x>=0: 100, x<0: -100
393393
"""
394394
def __init__(self, shift_amount_y):
395-
self.shift_amount_y = int(shift_amount_y)
395+
self.shift_amount_y = shift_amount_y
396396
self.to_mask = True
397397
self.to_img = True
398398

@@ -428,7 +428,7 @@ class ShiftX():
428428
* x>=0: 100, x<0: -100
429429
"""
430430
def __init__(self, shift_amount_x):
431-
self.shift_amount_x = int(shift_amount_x)
431+
self.shift_amount_x = shift_amount_x
432432
self.to_mask = True
433433
self.to_img = True
434434

bayesmedaug/callbacks/trainer.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
roundup,
2626
discrete_angle,
2727
discrete_shift,
28-
discrete_rcrop
28+
discrete_rcrop,
29+
discrete_shear
2930
)
3031

3132
class Trainer():
@@ -107,13 +108,17 @@ def train(self, **params):
107108

108109
if "angle" in params.keys():
109110
params["angle"] = discrete_angle(params["angle"])
110-
if "shift_x" in params.keys():
111-
params["shift_x"] = discrete_shift(params["shift_x"])
112-
if "shift_y" in params.keys():
113-
params["shift_y"] = discrete_shift(params["shift_y"])
114-
111+
if "shift_amount_x" in params.keys():
112+
params["shift_amount_x"] = discrete_shift(params["shift_amount_x"])
113+
if "shift_amount_y" in params.keys():
114+
params["shift_amount_y"] = discrete_shift(params["shift_amount_y"])
115115
if "crop_height" in params.keys():
116116
params["crop_height"] = discrete_rcrop(params["crop_height"])
117+
if "shear_amount_y" in params.keys():
118+
params["shear_amount_y"] = discrete_shear(params["shear_amount_y"])
119+
if "shear_amount_x" in params.keys():
120+
params["shear_amount_x"] = discrete_shear(params["shear_amount_x"])
121+
117122

118123
if type(self.augmentations) == Listed:
119124
transform = self.augmentations(**params)

bayesmedaug/utils/discretize.py

+3
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ def discrete_shift(x):
1818

1919
def discrete_rcrop(x):
2020
return int(x * 1000)
21+
22+
def discrete_shear(x):
23+
return x * 10

0 commit comments

Comments
 (0)