diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index deedf0fe83d2..1e2defab3713 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -2539,7 +2539,7 @@ def softmin(self, *args, **kwargs): """ return op.softmin(self, *args, **kwargs) - def squeeze(self, axis, inplace=False, **kwargs): # pylint: disable=unused-argument + def squeeze(self, axis=None, inplace=False, **kwargs): # pylint: disable=unused-argument """Convenience fluent method for :py:func:`squeeze`. The arguments are the same as for :py:func:`squeeze`, with diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index b59ce2d0864c..af30980b10ea 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -115,7 +115,7 @@ def test_parameter_dict(): params1.get('w1', shape=(10, 10), stype='row_sparse') params1.load('test_parameter_dict.params', ctx) trainer1 = mx.gluon.Trainer(params1, 'sgd') - + # compare the values before and after save/load cur_w0 = params1.get('w0').data(ctx) cur_w1 = params1.get('w1').row_sparse_data(all_row_ids) @@ -134,7 +134,7 @@ def test_parameter_dict(): cur_w1 = params2.get('w1').data(ctx) mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy()) mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy()) - + # test the dtype casting functionality params0 = gluon.ParameterDict('') params0.get('w0', shape=(10, 10), dtype='float32') @@ -386,7 +386,7 @@ def hybrid_forward(self, F, x): if 'conv' in param_name and 'weight' in param_name: break assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64) - + # 3.b Verify same functionnality with the imports API net_fp_64 = mx.gluon.SymbolBlock.imports(sym_file, 'data', params_file, ctx=ctx) @@ -2788,7 +2788,7 @@ def test_gluon_param_load(): net.cast('float16') net.load_parameters('test_gluon_param_load.params', cast_dtype=True) mx.nd.waitall() - + @with_seed() def test_gluon_param_load_dtype_source(): net = mx.gluon.nn.Dense(10, in_units=10) @@ -2800,6 +2800,22 @@ def test_gluon_param_load_dtype_source(): assert net.weight.dtype == np.float16 mx.nd.waitall() +@with_seed() +def test_squeeze_consistency(): + class Foo(gluon.HybridBlock): + def __init__(self, inplace, **kwargs): + super(Foo, self).__init__(**kwargs) + self.inplace = inplace + + def forward(self, x): + return x.squeeze(inplace=self.inplace) + + for inplace in (True, False): + block = Foo(inplace) + block.hybridize() + shape = (np.random.randint(1, 10), np.random.randint(1, 10), 1) + block(mx.nd.ones(shape)) + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 2dfe3e44eedb..0c97c68b0880 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -242,6 +242,7 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False): check_fluent_regular('reshape', {'shape': (17, 1, 5)}) check_fluent_regular('broadcast_to', {'shape': (5, 17, 47)}) check_fluent_regular('squeeze', {'axis': (1, 3)}, shape=(2, 1, 3, 1, 4)) + check_fluent_regular('squeeze', {}, shape=(2, 1, 3, 1, 4)) def check_symbol_consistency(sym1, sym2, ctx, skip_grad=False, equal_nan=False): assert sym1.list_arguments() == sym2.list_arguments()