diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index d4f756f5333c..846f90ded116 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -113,6 +113,11 @@ typedef void (*EngineFuncParamDeleter)(void*);
 typedef void (*ExecutorMonitorCallback)(const char*, NDArrayHandle, void*);
+/*! \brief Monitor callback called at operator level for cached op */
+typedef void (*CachedOpMonitorCallback)(const char*,
+                                        const char*,
+                                        NDArrayHandle);
+
 struct NativeOpInfo {
   void (*forward)(int, float**, int*, unsigned**, int*, void*);
@@ -1286,6 +1291,13 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
                                  NDArrayHandle **outputs,
                                  const int** out_stypes);
 
+/*!
+ * \brief cached op set monitor callback
+ */
+MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle,
+                                       CachedOpMonitorCallback callback,
+                                       bool monitor_all);
+
 //--------------------------------------------
 // Part 3: symbolic configuration generation
 //--------------------------------------------
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index b1a38c1d2621..0d5dade2f163 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -29,6 +29,13 @@
 from ..base import check_call
 
 
+def _monitor_callback_wrapper(callback):
+    """A wrapper for the user-defined handle."""
+    def callback_handle(name, opr_name, array, _):
+        """ ctypes function """
+        callback(name, opr_name, array)
+    return callback_handle
+
 class NDArrayBase(object):
     """Base data structure for ndarray"""
     __slots__ = ["handle", "writable"]
@@ -112,10 +119,11 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):
 
 class CachedOp(object):
     """Cached operator handle."""
-    __slots__ = ["handle", "is_np_sym"]
+    __slots__ = ["handle", "is_np_sym", "_monitor_callback"]
 
     def __init__(self, sym, flags=()):
         self.handle = CachedOpHandle()
+        self._monitor_callback = None
         from ..symbol.numpy._symbol import _Symbol
         self.is_np_sym = bool(isinstance(sym, _Symbol))
 
@@ -170,3 +178,21 @@ def __call__(self, *args, **kwargs):
         else:
             return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
                                       stype=out_stypes[i]) for i in range(num_output.value)]
+
+    def _register_op_hook(self, callback, monitor_all=False):
+        """Install callback for monitor.
+
+        Parameters
+        ----------
+        callback : function
+            Takes a string for node_name, a string for op_name, and an NDArrayHandle.
+        monitor_all : bool, default False
+            If true, monitor both input and output, otherwise monitor output only.
+ """ + cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p) + if callback: + self._monitor_callback = cb_type(_monitor_callback_wrapper(callback)) + check_call(_LIB.MXCachedOpRegisterOpHook( + self.handle, + self._monitor_callback, + ctypes.c_int(monitor_all))) diff --git a/python/mxnet/cython/base.pyi b/python/mxnet/cython/base.pyi index 548afc782763..0a35555aabb3 100644 --- a/python/mxnet/cython/base.pyi +++ b/python/mxnet/cython/base.pyi @@ -2,13 +2,18 @@ from ..base import MXNetError from libcpp.vector cimport vector from libcpp.string cimport string +from libcpp cimport bool as _bool from cpython.version cimport PY_MAJOR_VERSION ctypedef void* SymbolHandle ctypedef void* NDArrayHandle ctypedef void* OpHandle ctypedef void* CachedOpHandle +ctypedef void* MonitorCallbackHandle ctypedef unsigned nn_uint +ctypedef void (*CachedOpMonitorCallback)(const char*, + const char*, + NDArrayHandle) cdef py_str(const char* x): if PY_MAJOR_VERSION < 3: @@ -112,3 +117,6 @@ cdef extern from "mxnet/c_api.h": int *num_outputs, NDArrayHandle **outputs, const int **out_stypes); + int MXCachedOpRegisterOpHook(NDArrayHandle handle, + CachedOpMonitorCallback callback, + _bool monitor_all); diff --git a/python/mxnet/cython/ndarray.pyx b/python/mxnet/cython/ndarray.pyx index 50791e9b9a86..74ad1c4c3d49 100644 --- a/python/mxnet/cython/ndarray.pyx +++ b/python/mxnet/cython/ndarray.pyx @@ -22,6 +22,7 @@ import ctypes as _ctypes import numpy as np from ..ndarray_doc import _build_doc from libc.stdint cimport uint32_t, int64_t +from ..base import _LIB include "./base.pyi" @@ -47,7 +48,6 @@ cdef class NDArrayBase: return _ctypes.cast(self.chandle, _ctypes.c_void_p) def __set__(self, value): self._set_handle(value) - property writable: def __get__(self): return bool(self.cwritable) @@ -75,6 +75,10 @@ def _set_np_ndarray_class(cls): global _np_ndarray_cls _np_ndarray_cls = cls +def _monitor_callback_wrapper(callback): + def callback_handle(name, opr_name, arr, _): + callback(name, opr_name, arr) + return callback_handle cdef NewArray(NDArrayHandle handle, int stype=-1, int is_np_array=0): """Create a new array given handle""" @@ -103,6 +107,7 @@ cdef class CachedOp: self._set_handle(value) cdef int is_np_sym + cdef readonly object mhandle def __init__(self, sym, flags=()): cdef vector[string] s_flag_keys @@ -169,6 +174,15 @@ cdef class CachedOp: else: return [NewArray(p_output_vars[i], p_output_stypes[i], self.is_np_sym) for i in range(num_output)] + def _register_op_hook(self, callback, monitor_all=False): + cb_type = _ctypes.CFUNCTYPE(None, _ctypes.c_char_p, _ctypes.c_char_p, _ctypes.c_void_p, _ctypes.c_void_p) + if callback: + self.mhandle = cb_type(_monitor_callback_wrapper(callback)) + chandle = _ctypes.cast(self.chandle, _ctypes.c_void_p) + CALL(_LIB.MXCachedOpRegisterOpHook(chandle, + self.mhandle, + _ctypes.c_int(monitor_all))) + def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0): """cython implementation of imperative invoke wrapper""" diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 97e6e8b68453..fc08b4c6bd32 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -590,6 +590,19 @@ def forward(self, *args): # pylint: disable= invalid-name raise NotImplementedError + def register_op_hook(self, callback, monitor_all=False): + """Install callback monitor. + + Parameters + ---------- + callback : function + Takes a string and a NDArrayHandle. 
+        monitor_all : bool, default False
+            If true, monitor both input and output, otherwise monitor output only.
+        """
+        for cld in self._children.values():
+            cld.register_op_hook(callback, monitor_all)
+
     def summary(self, *inputs):
         """Print the summary of the model's output and parameters.
 
@@ -754,6 +767,8 @@ def __init__(self, prefix=None, params=None):
         self._in_format = None
         self._active = False
         self._flags = []
+        self._callback = None
+        self._monitor_all = False
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -833,6 +848,12 @@ def _deferred_infer_shape(self, *args):
     def _call_cached_op(self, *args):
         if self._cached_op is None:
             self._build_cache(*args)
+        assert self._cached_op, "cached op is not None"
+        if self._callback:
+            self._cached_op._register_op_hook(self._callback, self._monitor_all)
+            if len(self._flags) >= 2 and (self._flags[1] or self._flags[0]):
+                warnings.warn("register_op_hook is experimental when static_alloc=True / static_shape=True "
+                              "and may not work correctly")
 
         args, fmt = _flatten(args, "input")
         assert fmt == self._in_format, "Invalid input format"
@@ -938,6 +959,22 @@ def export(self, path, epoch=0, remove_amp_cast=True):
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
         save_fn('%s-%04d.params'%(path, epoch), arg_dict)
 
+    def register_op_hook(self, callback, monitor_all=False):
+        """Install op hook for block recursively.
+
+        Parameters
+        ----------
+        callback : function
+            Takes a string for node_name, a string for op_name, and an NDArrayHandle.
+        monitor_all : bool, default False
+            If true, monitor both input and output, otherwise monitor output only.
+        """
+        self._callback = callback
+        self._monitor_all = monitor_all
+        for cld in self._children.values():
+            cld._callback = callback
+            cld._monitor_all = monitor_all
+
     def forward(self, x, *args):
         """Defines the forward computation.
         Arguments can be either :py:class:`NDArray` or :py:class:`Symbol`."""
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 4546659ca64e..e661c2268500 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -378,3 +378,23 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) {
   *out = reinterpret_cast<SymbolHandle>(sym);
   API_END();
 }
+
+int MXCachedOpRegisterOpHook(NDArrayHandle handle,
+                             CachedOpMonitorCallback callback,
+                             bool monitor_all) {
+  API_BEGIN();
+  CachedOpMonitorCallback callback_temp = nullptr;
+  std::function<void(const char *, const char *, void *)> clbk;
+  if (callback) {
+    callback_temp = callback;
+    clbk = [callback_temp](const char *name, const char *opr_name,
+                           void *handle) {
+      callback_temp(name, opr_name, handle);
+    };
+  } else {
+    clbk = nullptr;
+  }
+  CachedOpPtr op = *static_cast<CachedOpPtr *>(handle);
+  op->RegisterOpHook(clbk, monitor_all);
+  API_END();
+}
diff --git a/src/common/utils.cc b/src/common/utils.cc
index 9fe46d94d036..032a324c96b0 100644
--- a/src/common/utils.cc
+++ b/src/common/utils.cc
@@ -51,5 +51,62 @@ void CastStorageDispatch(const OpContext& ctx,
   mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
 }
 
+void ExecuteMonInputCallback(
+    const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
+    size_t nid, const std::function<void(const char *, const char *, void *)>
+                    &monitor_callback) {
+  static const auto &flist_inputs =
+      nnvm::Op::GetAttr<nnvm::FListInputNames>("FListInputNames");
+  std::vector<std::string> input_names;
+  const nnvm::IndexedGraph::Node &inode = idx[nid];
+  const nnvm::Node *node = inode.source;
+  if (flist_inputs.count(node->op())) {
+    input_names = flist_inputs[node->op()](node->attrs);
+  } else {
+    for (size_t i = 0; i < node->num_inputs(); ++i) {
+      input_names.emplace_back("input" + std::to_string(i));
+    }
+  }
+
+  for (size_t i = 0; i < node->num_inputs(); ++i) {
+    const nnvm::NodeEntry &input = node->inputs[i];
+    if (state_arrays[idx.entry_id(input)]->is_none()) {
+      continue;
+    }
+    NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(input)]);
+    std::string name = inode.source->attrs.name + "_" + input_names[i];
+    monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
+                     reinterpret_cast<NDArrayHandle>(cpy));
+  }
+}
+
+void ExecuteMonOutputCallback(
+    const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
+    size_t nid, const std::function<void(const char *, const char *, void *)>
+                    &monitor_callback) {
+  static const auto &flist_outputs =
+      nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
+  std::vector<std::string> output_names;
+  const nnvm::IndexedGraph::Node &inode = idx[nid];
+  const nnvm::Node *node = inode.source;
+  if (flist_outputs.count(node->op())) {
+    output_names = flist_outputs[node->op()](node->attrs);
+  } else {
+    for (size_t i = 0; i < node->num_outputs(); ++i) {
+      output_names.emplace_back(std::to_string(i));
+    }
+  }
+
+  for (size_t i = 0; i < node->num_outputs(); ++i) {
+    if (state_arrays[idx.entry_id(nid, i)]->is_none()) {
+      continue;
+    }
+    NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(nid, i)]);
+    std::string name = inode.source->attrs.name + "_" + output_names[i];
+    monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
+                     reinterpret_cast<NDArrayHandle>(cpy));
+  }
+}
+
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/common/utils.h b/src/common/utils.h
index 9dad5df84fd2..c94b9fa9718a 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -803,6 +803,15 @@ inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) {
     ConvertToLegacyShape(&(shapes->at(i)));
   }
 }
+void ExecuteMonInputCallback(
+    const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
+    size_t nid, const std::function<void(const char *, const char *, void *)>
+                    &monitor_callback);
+
+void ExecuteMonOutputCallback(
+    const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
+    size_t nid, const std::function<void(const char *, const char *, void *)>
+                    &monitor_callback);
 
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index efe38019cfda..6818d757ab79 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -697,6 +697,9 @@ void CachedOp::StaticRunOps(
         ndinputs.emplace_back(state_arrays[idx.entry_id(j)]);
         CHECK(!ndinputs.back()->is_none());
       }
+      if (monitor_callback_ && monitor_all_) {
+        mxnet::common::ExecuteMonInputCallback(idx, state_arrays, i, monitor_callback_);
+      }
       ndoutputs.clear();
       ndoutputs.reserve(num_outputs);
       req.clear();
@@ -708,6 +711,7 @@ void CachedOp::StaticRunOps(
         CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
       }
       const DispatchMode dispatch_mode = dispatch_modes[i];
+
       if (createop.count(node.source->op())) {
         arg_shapes.clear();
         arg_dtypes.clear();
@@ -735,6 +739,9 @@ void CachedOp::StaticRunOps(
             default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
             dispatch_mode);
       }
+      if (monitor_callback_) {
+        mxnet::common::ExecuteMonOutputCallback(idx, state_arrays, i, monitor_callback_);
+      }
     }
   }
 }
@@ -883,12 +890,12 @@ OpStatePtr CachedOp::DynamicForward(
     // So if it's not the inline mode, we disable recording.
     RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
              std::move(ref_count), &states, dispatch_modes,
-             recording && inlining_);
+             recording && inlining_, nullptr, monitor_callback_, monitor_all_);
   } else {
     mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
     NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(),
                   std::move(array_reqs), std::move(ref_count), &states,
-                  dispatch_modes, recording && inlining_, &shapes);
+                  dispatch_modes, recording && inlining_, &shapes, monitor_callback_, monitor_all_);
     {
       auto state_ptr = GetCachedOpState(default_ctx);
       auto& state = state_ptr.get_state<CachedOpState>();
@@ -1028,7 +1035,7 @@ void CachedOp::DynamicBackward(
 
   RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
            std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
-           Imperative::Get()->is_recording());
+           Imperative::Get()->is_recording(), nullptr, monitor_callback_);
 
   if (retain_graph) {
     buff.resize(num_forward_entries);
@@ -1295,6 +1302,16 @@ void CachedOpBackward(const OpStatePtr& state_ptr,
     CopyFromTo(out_bufs[i], outputs[i]);
 }
 
+/*
+ * Register the callback to be called when the operator is executed
+ */
+void CachedOp::RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
+                              bool monitor_all) {
+  CHECK(callback) << "invalid callback";
+  monitor_callback_ = callback;
+  monitor_all_ = monitor_all;
+}
+
 OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
                                Context ctx,
                                const mxnet::ShapeVector& in_shapes,
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index c45f137b2d63..db049d59ed80 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -74,6 +74,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
 };
 
 class CachedOp {
+  using CachedOpMonCallback =
+      std::function<void(const char *, const char *, void *)>;
+
  public:
   CachedOp(
       const nnvm::Symbol& sym,
@@ -134,6 +137,8 @@ class CachedOp {
     sym.outputs = fwd_graph_.outputs;
     return sym;
   }
+  void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
+                      bool monitor_all = false);
 
  private:
   struct GraphInfo;
@@ -203,6 +208,9 @@ class CachedOp {
 
   std::vector<bool> save_inputs_, save_outputs_;
   std::vector<OpReqType> bwd_output_reqs_;
+  std::function<void(const char *, const char *, void *)> monitor_callback_{nullptr};
+  bool monitor_all_{false};
+
   std::mutex mutex_;
   std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;
 };
diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc
index 568d39fc8043..5491457b188f 100644
--- a/src/imperative/imperative_utils.cc
+++ b/src/imperative/imperative_utils.cc
@@ -137,7 +137,9 @@ void RunGraph(
     std::vector<OpStatePtr> *p_states,
     const DispatchModeVector &dispatch_modes,
     bool recording,
-    mxnet::ShapeVector *shapes) {
+    mxnet::ShapeVector *shapes,
+    const imperative::CachedOpMonCallback& callback,
+    const bool monitor_all) {
   CHECK(shapes == nullptr);
   for (size_t i = node_start; i < node_end; ++i) {
     const nnvm::IndexedGraph::Node& node = idx[i];
@@ -148,6 +150,9 @@ void RunGraph(
     std::vector<NDArray*> ndoutputs = NodeOutputs(idx, i, arrays);
     std::vector<OpReqType> req = NodeReq(idx, i, array_reqs);
     Context ctx = ndoutputs[0]->ctx();
+    if (callback && monitor_all) {
+      mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback);
+    }
     auto invoke = [&](const OpStatePtr &state) {
       const nnvm::IndexedGraph::Node& node = idx[i];
       DispatchMode dispatch_mode = dispatch_modes[i];
@@ -159,6 +164,9 @@ void RunGraph(
     };
     InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs,
                    &req, &ref_count, invoke);
+    if (callback) {
+      mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
+    }
   }
 }
 
@@ -173,7 +181,9 @@ void NaiveRunGraph(
     std::vector<OpStatePtr> *p_states,
     const DispatchModeVector &dispatch_modes,
     bool recording,
-    mxnet::ShapeVector *shapes) {
+    mxnet::ShapeVector *shapes,
+    const imperative::CachedOpMonCallback& callback,
+    const bool monitor_all) {
   for (size_t i = node_start; i < node_end; ++i) {
     const nnvm::IndexedGraph::Node& node = idx[i];
     if (node.source->op() == nullptr) {
@@ -183,6 +193,9 @@ void NaiveRunGraph(
     std::vector<NDArray*> ndoutputs = NodeOutputs(idx, i, arrays);
     std::vector<OpReqType> req;
     Context ctx = GetContext(node.source->attrs, ndinputs, ndoutputs, default_ctx);
+    if (callback && monitor_all) {
+      mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback);
+    }
     auto invoke = [&](const OpStatePtr &state) {
       const nnvm::IndexedGraph::Node& node = idx[i];
       DispatchMode dispatch_mode = DispatchMode::kUndefined;
@@ -205,6 +218,9 @@ void NaiveRunGraph(
     };
     InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs,
                    &req, &ref_count, invoke);
+    if (callback) {
+      mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
+    }
   }
 }
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 21caafa124f9..c5932bb3bbfe 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -59,6 +59,7 @@ struct EngineOprSeg {
 };
 
 using MemoryPlanVector = std::vector<MemoryPlanInfo>;
+using CachedOpMonCallback = std::function<void(const char *, const char *, void *)>;
 
 inline Context GetContext(const nnvm::NodeAttrs& attrs,
                           const std::vector<NDArray*>& inputs,
@@ -1056,7 +1057,9 @@ void RunGraph(const bool retain_graph,
               std::vector<OpStatePtr> *p_states,
               const DispatchModeVector &dispatch_modes,
               bool recording,
-              mxnet::ShapeVector *shapes = nullptr);
+              mxnet::ShapeVector *shapes = nullptr,
+              const CachedOpMonCallback& callback = nullptr,
+              const bool monitor_all_ = false);
 
 void NaiveRunGraph(const bool retain_graph,
                    const Context& default_ctx,
@@ -1068,7 +1071,9 @@ void NaiveRunGraph(const bool retain_graph,
                    std::vector<OpStatePtr> *p_states,
                    const DispatchModeVector &dispatch_modes,
                    bool recording,
-                   mxnet::ShapeVector *shapes);
+                   mxnet::ShapeVector *shapes,
+                   const CachedOpMonCallback& callback = nullptr,
+                   const bool monitor_all_ = false);
 
 }  // namespace imperative
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index af30980b10ea..46e976432fa8 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -21,6 +21,7 @@
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
+from mxnet.base import py_str
 from mxnet.test_utils import assert_almost_equal
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from common import (setup_module, with_seed, assertRaises, teardown,
@@ -1503,6 +1504,74 @@ def call_pre_hook(block, x):
     assert hook_call_count == 1
     assert pre_hook_call_count == 2
 
+@with_seed()
+def test_op_hook_output_names():
+    def check_name(block, expected_names, inputs=None, expected_opr_names=None, monitor_all=False):
+        opr_names = []
+        output_names = []
+
+        def mon_callback(node_name, opr_name, arr):
+            output_names.append(py_str(node_name))
+            opr_names.append(py_str(opr_name))
+
+        block.register_op_hook(mon_callback, monitor_all)
+        if not inputs:
+            block(mx.nd.ones((2, 3, 4)))
+        else:
+            block(inputs)
+
+        for output_name, expected_name in zip(output_names, expected_names):
+            print(output_name)
+            assert output_name == expected_name
+
+        if expected_opr_names:
+            for opr_name, expected_opr_name in zip(opr_names, expected_opr_names):
+                assert opr_name == expected_opr_name
+
+    # Test with Dense layer
+    model = mx.gluon.nn.HybridSequential(prefix="dense_")
+    with model.name_scope():
+        model.add(mx.gluon.nn.Dense(2))
+    model.initialize()
+    model.hybridize()
+    check_name(model, ["dense_dense0_fwd_output"])
+
+    # Test with Activation, FListInputNames not registered, input name will have _input appended
+    model = mx.gluon.nn.HybridSequential(prefix="relu_")
+    with model.name_scope():
+        model.add(mx.gluon.nn.Activation("relu"))
+    model.initialize()
+    model.hybridize()
+    check_name(model, ["relu_relu0_fwd_output"])
+
+    # Test with Pooling, monitor_all is set to True
+    model = mx.gluon.nn.HybridSequential("pool_")
+    with model.name_scope():
+        model.add(mx.gluon.nn.AvgPool1D())
+    model.initialize()
+    model.hybridize()
+    check_name(model, ['pool_pool0_fwd_data', 'pool_pool0_fwd_output'], expected_opr_names=["Pooling"],
+               monitor_all=True)
+
+    # stack two layers and test
+    model = mx.gluon.nn.HybridSequential("dense_")
+    with model.name_scope():
+        model.add(mx.gluon.nn.Dense(2))
+        model.add(mx.gluon.nn.Activation("relu"))
+    model.initialize()
+    model.hybridize()
+    check_name(model,
+               ['dense_dense0_fwd_data', 'dense_dense0_fwd_weight',
+                'dense_dense0_fwd_bias', 'dense_dense0_fwd_output',
+                'dense_relu0_fwd_input0', 'dense_relu0_fwd_output'], monitor_all=True)
+
+    # check with different hybridize modes
+    model.hybridize(static_alloc=True)
+    check_name(model,
+               ['dense_dense0_fwd_data', 'dense_dense0_fwd_weight',
+                'dense_dense0_fwd_bias', 'dense_dense0_fwd_output',
+                'dense_relu0_fwd_input0', 'dense_relu0_fwd_output'], monitor_all=True)
+
 @with_seed()
 def test_apply():
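
Usage sketch (not part of the patch): the snippet below registers the new hook on a hybridized Gluon block and prints the node name, operator name, and mean absolute value of each monitored array. The `register_op_hook` call, the callback signature, and the expected `dense_dense0_fwd_*` names follow the API and test added above; wrapping the raw NDArrayHandle into an `NDArray` mirrors the pattern used by MXNet's executor monitor and is an assumption here rather than part of this change.

```python
import ctypes
import mxnet as mx
from mxnet.base import NDArrayHandle, py_str


def stat_callback(node_name, opr_name, arr):
    """Print the monitored entry's name, the operator that produced it,
    and the mean absolute value of the array."""
    # The backend hands the callback a raw NDArrayHandle; wrap it read-only.
    array = mx.nd.NDArray(ctypes.cast(arr, NDArrayHandle), writable=False)
    print(py_str(node_name), py_str(opr_name), array.abs().mean().asscalar())


net = mx.gluon.nn.HybridSequential(prefix="dense_")
with net.name_scope():
    net.add(mx.gluon.nn.Dense(2))
net.initialize()
net.hybridize()

# monitor_all=True also monitors operator inputs, so the callback fires for
# dense_dense0_fwd_data / _weight / _bias as well as dense_dense0_fwd_output.
net.register_op_hook(stat_callback, monitor_all=True)
net(mx.nd.ones((2, 3, 4)))
```

With `hybridize(static_alloc=True)` or `static_shape=True` the hook is still registered, but `_call_cached_op` emits the experimental warning added in this patch.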