From 478c60a36e2d21aa49e6c5b0327e86b38ca3f454 Mon Sep 17 00:00:00 2001 From: mseth10 Date: Wed, 14 Aug 2019 01:52:46 +0000 Subject: [PATCH] API to trigger partitioning --- include/mxnet/c_api.h | 9 +++++++++ python/mxnet/symbol/symbol.py | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 5ab10b6b2204..41c0df19bc5a 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1987,6 +1987,15 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, const float* high_quantiles, SymbolHandle* ret_sym_handle); +/*! + * \brief Partitions symbol for given backend, potentially creating subgraphs + * \param sym_handle symbol to be partitioned + * \param backend backend name + * \param ret_sym_handle partitioned symbol returned + */ +MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle, + const char* backend, + SymbolHandle* ret_sym_handle); /*! * \brief Run subgraph pass based on the backend provided * \param sym_handle symbol to be converted diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 3ac44176a87b..4b71177afaa8 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1437,6 +1437,12 @@ def _gen_atomic_symbol(self): return Symbol(handle) + def optimizeFor(self, backend): + """Partition symbol and optimize it for a given backend""" + out = SymbolHandle() + check_call(_LIB.MXOptimizeForBackend(self.handle, c_str(backend), ctypes.byref(out))) + + # pylint: disable=too-many-locals def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None, group2ctx=None, shared_arg_names=None, shared_exec=None,