From c51cfab3b88a702ee5c99321dcefcc079d6f8df4 Mon Sep 17 00:00:00 2001 From: guanxinq Date: Tue, 7 Jan 2020 19:36:31 +0000 Subject: [PATCH] add RandomApply in gluon's transforms --- python/mxnet/gluon/data/vision/transforms.py | 25 +++++++++++++++++++ .../python/unittest/test_gluon_data_vision.py | 15 +++++++++++ 2 files changed, 40 insertions(+) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 935ce2738a6f..acaa7a5d33d9 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -581,3 +581,28 @@ def hybrid_forward(self, F, x): if is_np_array(): F = F.npx return F.image.random_lighting(x, self._alpha) + +class RandomApply(Sequential): + """Apply a list of transformations randomly given probability + + Parameters + ---------- + Inputs: + - **transforms**: list of transformations + - **p**: probability + + Outputs: + Transformed image. + """ + + def __init__(self, transform, p=0.5): + super(RandomApply, self).__init__() + self.transform = transform + self.p = p + + def forward(self, x): + import random + if self.p < random.random(): + return x + x = self.transform(x) + return x diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index 8bc0f8072260..11a0b89eb136 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -229,6 +229,21 @@ def test_transformer(): transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read() +@with_seed() +def test_randomtransforms(): + from mxnet.gluon.data.vision import transforms + + transform = transforms.Compose([transforms.RandomApply(transforms.Compose([transforms.Resize(300), transforms.RandomResizedCrop(224)]), 0.5)]) + + img = mx.nd.ones((245, 480, 3), dtype='uint8') + iteration = 1000 + num_apply = 0 + for i in range(iteration): + out = transform(img) + if out.shape[0] == 224: + num_apply += 1 + assert(abs(num_apply/float(iteration)-0.5) < 1e-1) + if __name__ == '__main__': import nose