diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 861542220927..83a5674c45bd 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -38,6 +38,7 @@ class requests_failed_to_import(object): import numpy as np from .. import ndarray +from ..util import is_np_shape def split_data(data, num_slice, batch_axis=0, even_split=True): """Splits an NDArray into `num_slice` slices along `batch_axis`. @@ -412,3 +413,17 @@ def __enter__(self): def __exit__(self, ptype, value, trace): self.detach() + +def shape_is_known(shape): + """Check whether a shape is completely known w/ or w/o np semantics.""" + if shape is None: + return False + unknown_dim_size = -1 if is_np_shape() else 0 + if len(shape) == 0: + return unknown_dim_size == -1 + for dim_size in shape: + if dim_size == unknown_dim_size: + return False + assert dim_size > unknown_dim_size, "shape dimension size cannot be less than {}, while " \ + "received {}".format(unknown_dim_size, dim_size) + return True diff --git a/python/mxnet/util.py b/python/mxnet/util.py index 97442b126d6a..29f5b78e454e 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -245,17 +245,3 @@ def _with_np_shape(*args, **kwargs): return func(*args, **kwargs) return _with_np_shape - -def shape_is_known(shape): - """Check whether a shape is completely known w/ or w/o np semantics.""" - if shape is None: - return False - unknown_dim_size = -1 if is_np_shape() else 0 - if len(shape) == 0: - return unknown_dim_size == -1 - for dim_size in shape: - if dim_size == unknown_dim_size: - return False - assert dim_size > unknown_dim_size, "shape dimension size cannot be less than {}, while " \ - "received {}".format(unknown_dim_size, dim_size) - return True