Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix edge case when casting gluon Block before export #18853

Merged
merged 2 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ def cast(self, dtype):
The new data type.
"""
self._dtype = dtype
self._var = None # Clear Symbol Variable as it caches the dtype
szha marked this conversation as resolved.
Show resolved Hide resolved
if self._data is None:
return
with autograd.pause():
Expand Down
7 changes: 6 additions & 1 deletion tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def test_hybrid_sequential_unique_internals():


@with_seed()
def test_symbol_block(tmpdir):
@pytest.mark.parametrize('compute_before_cast', [True, False])
def test_symbol_block(tmpdir, compute_before_cast):
model = nn.HybridSequential()
model.add(nn.Dense(128, activation='tanh'))
model.add(nn.Dropout(0.5))
Expand Down Expand Up @@ -309,6 +310,10 @@ def hybrid_forward(self, F, x):
ctx = mx.cpu(0)

net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
if compute_before_cast:
# Compute before casting to catch bugs where symbol dtype isn't casted correctly GH-18843
net_fp32.initialize()
net_fp32(mx.nd.zeros((1,3,224,224)))
net_fp32.cast('float64')
net_fp32.hybridize()
data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)
Expand Down