Skip to content

Commit

Permalink
Add missing default axis value to symbol.squeeze op (apache#15707)
Browse files Browse the repository at this point in the history
* Add missing default arg

* Add test

* add test
  • Loading branch information
leezu authored and Ubuntu committed Aug 20, 2019
1 parent ae4f93a commit e2fec53
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2539,7 +2539,7 @@ def softmin(self, *args, **kwargs):
"""
return op.softmin(self, *args, **kwargs)

def squeeze(self, axis, inplace=False, **kwargs): # pylint: disable=unused-argument
def squeeze(self, axis=None, inplace=False, **kwargs): # pylint: disable=unused-argument
"""Convenience fluent method for :py:func:`squeeze`.
The arguments are the same as for :py:func:`squeeze`, with
Expand Down
24 changes: 20 additions & 4 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_parameter_dict():
params1.get('w1', shape=(10, 10), stype='row_sparse')
params1.load('test_parameter_dict.params', ctx)
trainer1 = mx.gluon.Trainer(params1, 'sgd')

# compare the values before and after save/load
cur_w0 = params1.get('w0').data(ctx)
cur_w1 = params1.get('w1').row_sparse_data(all_row_ids)
Expand All @@ -134,7 +134,7 @@ def test_parameter_dict():
cur_w1 = params2.get('w1').data(ctx)
mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())

# test the dtype casting functionality
params0 = gluon.ParameterDict('')
params0.get('w0', shape=(10, 10), dtype='float32')
Expand Down Expand Up @@ -386,7 +386,7 @@ def hybrid_forward(self, F, x):
if 'conv' in param_name and 'weight' in param_name:
break
assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)

# 3.b Verify same functionnality with the imports API
net_fp_64 = mx.gluon.SymbolBlock.imports(sym_file, 'data', params_file, ctx=ctx)

Expand Down Expand Up @@ -2788,7 +2788,7 @@ def test_gluon_param_load():
net.cast('float16')
net.load_parameters('test_gluon_param_load.params', cast_dtype=True)
mx.nd.waitall()

@with_seed()
def test_gluon_param_load_dtype_source():
net = mx.gluon.nn.Dense(10, in_units=10)
Expand All @@ -2800,6 +2800,22 @@ def test_gluon_param_load_dtype_source():
assert net.weight.dtype == np.float16
mx.nd.waitall()

@with_seed()
def test_squeeze_consistency():
class Foo(gluon.HybridBlock):
def __init__(self, inplace, **kwargs):
super(Foo, self).__init__(**kwargs)
self.inplace = inplace

def forward(self, x):
return x.squeeze(inplace=self.inplace)

for inplace in (True, False):
block = Foo(inplace)
block.hybridize()
shape = (np.random.randint(1, 10), np.random.randint(1, 10), 1)
block(mx.nd.ones(shape))

if __name__ == '__main__':
import nose
nose.runmodule()
1 change: 1 addition & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
check_fluent_regular('reshape', {'shape': (17, 1, 5)})
check_fluent_regular('broadcast_to', {'shape': (5, 17, 47)})
check_fluent_regular('squeeze', {'axis': (1, 3)}, shape=(2, 1, 3, 1, 4))
check_fluent_regular('squeeze', {}, shape=(2, 1, 3, 1, 4))

def check_symbol_consistency(sym1, sym2, ctx, skip_grad=False, equal_nan=False):
assert sym1.list_arguments() == sym2.list_arguments()
Expand Down

0 comments on commit e2fec53

Please sign in to comment.