From 644c730bcae38543cacf1e7cfaa54fe1bdf9236d Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Mon, 3 Aug 2020 18:15:16 +0000 Subject: [PATCH] Fix edge case when casting gluon Block before export Fixes https://github.com/apache/incubator-mxnet/issues/18843 --- python/mxnet/gluon/parameter.py | 1 + tests/python/unittest/test_gluon.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 4d1efdc9fe65..2f1f1152a4c2 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -662,6 +662,7 @@ def cast(self, dtype): The new data type. """ self._dtype = dtype + self._var = None # Clear Symbol Variable as it caches the dtype if self._data is None: return with autograd.pause(): diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index a934fbe12e58..da209f473ead 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -259,7 +259,8 @@ def test_hybrid_sequential_unique_internals(): @with_seed() -def test_symbol_block(tmpdir): +@pytest.mark.parametrize('compute_before_cast', [True, False]) +def test_symbol_block(tmpdir, compute_before_cast): model = nn.HybridSequential() model.add(nn.Dense(128, activation='tanh')) model.add(nn.Dropout(0.5)) @@ -309,6 +310,10 @@ def hybrid_forward(self, F, x): ctx = mx.cpu(0) net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp) + if compute_before_cast: + # Compute before casting to catch bugs where symbol dtype isn't casted correctly GH-18843 + net_fp32.initialize() + net_fp32(mx.nd.zeros((1,3,224,224))) net_fp32.cast('float64') net_fp32.hybridize() data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)