Skip to content

Commit

Permalink
Fix edge case when casting gluon Block before export
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Aug 3, 2020
1 parent 9fd2cce commit 644c730
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ def cast(self, dtype):
The new data type.
"""
self._dtype = dtype
self._var = None # Clear Symbol Variable as it caches the dtype
if self._data is None:
return
with autograd.pause():
Expand Down
7 changes: 6 additions & 1 deletion tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def test_hybrid_sequential_unique_internals():


@with_seed()
def test_symbol_block(tmpdir):
@pytest.mark.parametrize('compute_before_cast', [True, False])
def test_symbol_block(tmpdir, compute_before_cast):
model = nn.HybridSequential()
model.add(nn.Dense(128, activation='tanh'))
model.add(nn.Dropout(0.5))
Expand Down Expand Up @@ -309,6 +310,10 @@ def hybrid_forward(self, F, x):
ctx = mx.cpu(0)

net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
if compute_before_cast:
# Compute before casting to catch bugs where symbol dtype isn't casted correctly GH-18843
net_fp32.initialize()
net_fp32(mx.nd.zeros((1,3,224,224)))
net_fp32.cast('float64')
net_fp32.hybridize()
data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)
Expand Down

0 comments on commit 644c730

Please sign in to comment.