diff --git a/cpp-package/include/mxnet-cpp/symbol.h b/cpp-package/include/mxnet-cpp/symbol.h index 1c825c1502af..a25824cad602 100644 --- a/cpp-package/include/mxnet-cpp/symbol.h +++ b/cpp-package/include/mxnet-cpp/symbol.h @@ -178,6 +178,8 @@ class Symbol { std::vector ListOutputs() const; /*! \return get the descriptions of auxiliary data for this symbol */ std::vector ListAuxiliaryStates() const; + /*! \return get the name of the symbol */ + std::string GetName() const; /*! * \brief infer and construct all the arrays to bind to executor by providing * some known arrays. diff --git a/cpp-package/include/mxnet-cpp/symbol.hpp b/cpp-package/include/mxnet-cpp/symbol.hpp index 11590fad6041..b82e060ca8da 100644 --- a/cpp-package/include/mxnet-cpp/symbol.hpp +++ b/cpp-package/include/mxnet-cpp/symbol.hpp @@ -172,6 +172,14 @@ inline std::vector Symbol::ListAuxiliaryStates() const { return ret; } +inline std::string Symbol::GetName() const { + int success; + const char* out_name; + CHECK_EQ(MXSymbolGetName(GetHandle(), &out_name, &success), 0); + CHECK_EQ(success, 1); + return std::string(out_name); +} + inline void Symbol::InferShape( const std::map > &arg_shapes, std::vector > *in_shape,