Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add RandomApply in gluon's transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
guanxinq committed Jan 7, 2020
1 parent 634f95e commit c51cfab
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
25 changes: 25 additions & 0 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,28 @@ def hybrid_forward(self, F, x):
if is_np_array():
F = F.npx
return F.image.random_lighting(x, self._alpha)

class RandomApply(Sequential):
"""Apply a list of transformations randomly given probability
Parameters
----------
Inputs:
- **transforms**: list of transformations
- **p**: probability
Outputs:
Transformed image.
"""

def __init__(self, transform, p=0.5):
super(RandomApply, self).__init__()
self.transform = transform
self.p = p

def forward(self, x):
import random
if self.p < random.random():
return x
x = self.transform(x)
return x
15 changes: 15 additions & 0 deletions tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,21 @@ def test_transformer():
transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()


@with_seed()
def test_randomtransforms():
from mxnet.gluon.data.vision import transforms

transform = transforms.Compose([transforms.RandomApply(transforms.Compose([transforms.Resize(300), transforms.RandomResizedCrop(224)]), 0.5)])

img = mx.nd.ones((245, 480, 3), dtype='uint8')
iteration = 1000
num_apply = 0
for i in range(iteration):
out = transform(img)
if out.shape[0] == 224:
num_apply += 1
assert(abs(num_apply/float(iteration)-0.5) < 1e-1)


if __name__ == '__main__':
import nose
Expand Down

0 comments on commit c51cfab

Please sign in to comment.