From d236d20ecff25a1e40f8fc3220eb1c0d3695aac4 Mon Sep 17 00:00:00 2001 From: guanxinq Date: Tue, 7 Jan 2020 19:36:31 +0000 Subject: [PATCH] Add RandomApply in gluon's transforms --- python/mxnet/gluon/data/vision/transforms.py | 31 +++++++++++++++++++ .../python/unittest/test_gluon_data_vision.py | 16 ++++++++++ 2 files changed, 47 insertions(+) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 935ce2738a6f..0be5578360df 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -19,6 +19,7 @@ # pylint: disable= arguments-differ "Image transforms." +import random from ...block import Block, HybridBlock from ...nn import Sequential, HybridSequential from .... import image @@ -581,3 +582,33 @@ def hybrid_forward(self, F, x): if is_np_array(): F = F.npx return F.image.random_lighting(x, self._alpha) + + +class RandomApply(Sequential): + """Apply a list of transformations randomly given probability + + Parameters + ---------- + transforms + List of transformations. + p : float + Probability of applying the transformations. + + + Inputs: + - **data**: input tensor. + + Outputs: + - **out**: transformed image. + """ + + def __init__(self, transforms, p=0.5): + super(RandomApply, self).__init__() + self.transforms = transforms + self.p = p + + def forward(self, x): + if self.p < random.random(): + return x + x = self.transforms(x) + return x diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index 8bc0f8072260..71efb72b9ce5 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -229,6 +229,22 @@ def test_transformer(): transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read() +@with_seed() +def test_random_transforms(): + from mxnet.gluon.data.vision import transforms + + tmp_t = transforms.Compose([transforms.Resize(300), transforms.RandomResizedCrop(224)]) + transform = transforms.Compose([transforms.RandomApply(tmp_t, 0.5)]) + + img = mx.nd.ones((10, 10, 3), dtype='uint8') + iteration = 1000 + num_apply = 0 + for _ in range(iteration): + out = transform(img) + if out.shape[0] == 224: + num_apply += 1 + assert_almost_equal(num_apply/float(iteration), 0.5, 0.1) + if __name__ == '__main__': import nose