diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 4d1efdc9fe65..2f1f1152a4c2 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -662,6 +662,7 @@ def cast(self, dtype): The new data type. """ self._dtype = dtype + self._var = None # Clear Symbol Variable as it caches the dtype if self._data is None: return with autograd.pause(): diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index a934fbe12e58..71ae41c2cfa6 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -259,7 +259,8 @@ def test_hybrid_sequential_unique_internals(): @with_seed() -def test_symbol_block(tmpdir): +@pytest.mark.parametrize('compute_before_cast', [True, False]) +def test_symbol_block(tmpdir, compute_before_cast): model = nn.HybridSequential() model.add(nn.Dense(128, activation='tanh')) model.add(nn.Dropout(0.5)) @@ -309,6 +310,10 @@ def hybrid_forward(self, F, x): ctx = mx.cpu(0) net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp) + if compute_before_cast: + # Compute before casting to catch bugs where symbol dtype isn't casted correctly GH-18843 + net_fp32.initialize() + net_fp32(mx.nd.zeros((1,3,224,224), ctx=ctx)) net_fp32.cast('float64') net_fp32.hybridize() data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)