diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index b667542bffb5..b4022bae9431 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -34,6 +34,7 @@ #include "dmlc/logging.h" #include "mxnet-cpp/ndarray.h" #include "mxnet-cpp/operator.h" +#include "mxnet/ndarray.h" namespace mxnet { namespace cpp { @@ -391,6 +392,9 @@ inline mx_float NDArray::At(size_t c, size_t h, size_t w) const { } inline size_t NDArray::Size() const { + NDArrayHandle handle = GetHandle(); + if (static_cast(handle)->is_none()) + return 0u; size_t ret = 1; for (auto &i : GetShape()) ret *= i; return ret;