From 4530ad8f4b7b661a985ad668f2668f9476848a81 Mon Sep 17 00:00:00 2001 From: Abhinav Sharma Date: Mon, 8 Apr 2019 11:07:15 -0700 Subject: [PATCH] Fix aspect ratio sampling for RandomResizedCrop (#14585) * added log sampling for aspect ratio * added test * added comments * fix test * remove math, fix test --- python/mxnet/image/image.py | 6 ++---- tests/python/unittest/test_image.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index d2631e810529..8bcf724ac4d2 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -585,14 +585,12 @@ def random_size_crop(src, size, area, ratio, interp=2, **kwargs): area = (area, 1.0) for _ in range(10): target_area = random.uniform(area[0], area[1]) * src_area - new_ratio = random.uniform(*ratio) + log_ratio = (np.log(ratio[0]), np.log(ratio[1])) + new_ratio = np.exp(random.uniform(*log_ratio)) new_w = int(round(np.sqrt(target_area * new_ratio))) new_h = int(round(np.sqrt(target_area / new_ratio))) - if random.random() < 0.5: - new_h, new_w = new_w, new_h - if new_w <= w and new_h <= h: x0 = random.randint(0, w - new_w) y0 = random.randint(0, h - new_h) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index e0abbd75ef8e..6b212da26d62 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -355,6 +355,19 @@ def test_det_augmenters(self): for batch in det_iter: pass + @with_seed() + def test_random_size_crop(self): + # test aspect ratio within bounds + width = np.random.randint(100, 500) + height = np.random.randint(100, 500) + src = np.random.rand(height, width, 3) * 255. + ratio = (0.75, 1) + out, (x0, y0, new_w, new_h) = mx.image.random_size_crop(mx.nd.array(src), size=(width, height), area=0.08, ratio=ratio) + _, pts = mx.image.center_crop(mx.nd.array(src), size=(width, height)) + if (x0, y0, new_w, new_h) != pts: + assert ratio[0] <= float(new_w)/new_h <= ratio[1] + + if __name__ == '__main__': import nose nose.runmodule()