From d1403d718431686aa4183f33978895808276c065 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 29 May 2019 22:47:28 +0000 Subject: [PATCH] check unknown shape correctly. --- python/mxnet/gluon/parameter.py | 6 +++--- python/mxnet/util.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index b6b58a24f8fe..9d6f8a3bb88b 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -30,7 +30,7 @@ from .. import symbol, ndarray, initializer, context from ..context import Context, cpu from .. import autograd -from .utils import _indent, _brief_print_list +from .utils import _indent, _brief_print_list, shape_is_known from .. import is_np_shape # pylint: disable= invalid-name @@ -273,7 +273,7 @@ def _finish_deferred_init(self): return init, ctx, default_init, data = self._deferred_init self._deferred_init = () - assert self.shape is not None and np.prod(self.shape) > 0, \ + assert shape_is_known(self.shape), \ "Cannot initialize Parameter '%s' because it has " \ "invalid shape: %s. Please specify in_units, " \ "in_channels, etc for `Block`s."%( @@ -384,7 +384,7 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), ctx = [ctx] if init is None: init = default_init if self.init is None else self.init - if not self.shape or np.prod(self.shape) <= 0: + if not shape_is_known(self.shape): if self._allow_deferred_init: self._deferred_init = (init, ctx, default_init, None) return diff --git a/python/mxnet/util.py b/python/mxnet/util.py index 29f5b78e454e..97442b126d6a 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -245,3 +245,17 @@ def _with_np_shape(*args, **kwargs): return func(*args, **kwargs) return _with_np_shape + +def shape_is_known(shape): + """Check whether a shape is completely known w/ or w/o np semantics.""" + if shape is None: + return False + unknown_dim_size = -1 if is_np_shape() else 0 + if len(shape) == 0: + return unknown_dim_size == -1 + for dim_size in shape: + if dim_size == unknown_dim_size: + return False + assert dim_size > unknown_dim_size, "shape dimension size cannot be less than {}, while " \ + "received {}".format(unknown_dim_size, dim_size) + return True