Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
check unknown shape correctly.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed May 29, 2019
1 parent dc690b6 commit d1403d7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .. import symbol, ndarray, initializer, context
from ..context import Context, cpu
from .. import autograd
from .utils import _indent, _brief_print_list
from .utils import _indent, _brief_print_list, shape_is_known
from .. import is_np_shape

# pylint: disable= invalid-name
Expand Down Expand Up @@ -273,7 +273,7 @@ def _finish_deferred_init(self):
return
init, ctx, default_init, data = self._deferred_init
self._deferred_init = ()
assert self.shape is not None and np.prod(self.shape) > 0, \
assert shape_is_known(self.shape), \
"Cannot initialize Parameter '%s' because it has " \
"invalid shape: %s. Please specify in_units, " \
"in_channels, etc for `Block`s."%(
Expand Down Expand Up @@ -384,7 +384,7 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
ctx = [ctx]
if init is None:
init = default_init if self.init is None else self.init
if not self.shape or np.prod(self.shape) <= 0:
if not shape_is_known(self.shape):
if self._allow_deferred_init:
self._deferred_init = (init, ctx, default_init, None)
return
Expand Down
14 changes: 14 additions & 0 deletions python/mxnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,17 @@ def _with_np_shape(*args, **kwargs):
return func(*args, **kwargs)

return _with_np_shape

def shape_is_known(shape):
"""Check whether a shape is completely known w/ or w/o np semantics."""
if shape is None:
return False
unknown_dim_size = -1 if is_np_shape() else 0
if len(shape) == 0:
return unknown_dim_size == -1
for dim_size in shape:
if dim_size == unknown_dim_size:
return False
assert dim_size > unknown_dim_size, "shape dimension size cannot be less than {}, while " \
"received {}".format(unknown_dim_size, dim_size)
return True

0 comments on commit d1403d7

Please sign in to comment.