Commit

From PR:apache#15839 of apache mxnet.
joapolarbear committed Dec 19, 2019
1 parent 1c80de8 commit 15b1b68
Showing 13 changed files with 305 additions and 8 deletions.
12 changes: 12 additions & 0 deletions include/mxnet/c_api.h
@@ -112,6 +112,11 @@ typedef void (*ExecutorMonitorCallback)(const char*,
NDArrayHandle,
void*);

/*! \brief Monitor callback called at operator level for cached op */
typedef void (*CachedOpMonitorCallback)(const char*,
const char*,
NDArrayHandle);

struct NativeOpInfo {
void (*forward)(int, float**, int*, unsigned**, int*, void*);
void (*backward)(int, float**, int*, unsigned**, int*, void*);
@@ -1206,6 +1211,13 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
NDArrayHandle **outputs,
const int** out_stypes);

/*!
 * \brief Set a monitor callback on a cached op, invoked at operator level.
*/
MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
bool monitor_all);

//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
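
A minimal ctypes sketch of driving the new entry point directly, assuming the patched libmxnet is the one already loaded by the Python package; the three-argument CFUNCTYPE mirrors the CachedOpMonitorCallback typedef above, and the toy symbol and inputs are purely illustrative:

import ctypes
import mxnet as mx
from mxnet.base import _LIB, check_call, NDArrayHandle

# Mirrors: void (*CachedOpMonitorCallback)(const char*, const char*, NDArrayHandle)
CachedOpMonitorCallback = ctypes.CFUNCTYPE(
    None, ctypes.c_char_p, ctypes.c_char_p, NDArrayHandle)

def raw_hook(name, opr_name, nd_handle):
    # name/opr_name arrive as C strings (bytes); nd_handle is an opaque NDArrayHandle
    print(name.decode("utf-8"), opr_name.decode("utf-8"))

c_hook = CachedOpMonitorCallback(raw_hook)   # keep a reference so it is not garbage-collected

op = mx.nd.CachedOp(mx.sym.var("x") * 2)
check_call(_LIB.MXCachedOpRegisterOpHook(op.handle, c_hook, ctypes.c_int(1)))  # monitor_all = true
op(mx.nd.ones((2,)))                         # the hook fires once per monitored array
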
26 changes: 25 additions & 1 deletion python/mxnet/_ctypes/ndarray.py
@@ -28,6 +28,12 @@
from ..base import NDArrayHandle, CachedOpHandle
from ..base import check_call

def _monitor_callback_wrapper(callback):
"""A wrapper for the user-defined handle."""
def callback_handle(name, opr_name, array, _):
""" ctypes function """
callback(name, opr_name, array)
return callback_handle

class NDArrayBase(object):
"""Base data structure for ndarray"""
@@ -104,9 +110,10 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):

class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]
__slots__ = ["handle", "is_np_sym", "_monitor_callback"]
def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
self._monitor_callback = None

check_call(_LIB.MXCreateCachedOpEx(
sym.handle,
@@ -158,3 +165,20 @@ def __call__(self, *args, **kwargs):
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
stype=out_stypes[i])
for i in range(num_output.value)]

def _register_op_hook(self, callback, monitor_all=False):
"""Install callback for monitor.
Parameters
----------
callback : function
Takes a string for node_name, a string for op_name, and an NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
if callback:
self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
check_call(_LIB.MXCachedOpRegisterOpHook(
self.handle,
self._monitor_callback,
ctypes.c_int(monitor_all)))
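
A hedged usage sketch of the _register_op_hook method added above; the toy graph is illustrative, and CachedOp is assumed reachable as mx.nd.CachedOp, as elsewhere in the code base:

import mxnet as mx

records = []

def mon_callback(name, opr_name, array):
    # name/opr_name are delivered as bytes; array is a raw NDArrayHandle, not a wrapped NDArray
    records.append((name.decode("utf-8"), opr_name.decode("utf-8")))

a = mx.sym.var("a")
b = mx.sym.var("b")
op = mx.nd.CachedOp(a + b)
op._register_op_hook(mon_callback, monitor_all=True)   # monitor inputs as well as outputs

op(mx.nd.ones((3,)), mx.nd.ones((3,)))
print(records)   # one (name, op name) pair per monitored input/output of each executed op
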
8 changes: 8 additions & 0 deletions python/mxnet/cython/base.pyi
@@ -2,13 +2,18 @@ from ..base import MXNetError

from libcpp.vector cimport vector
from libcpp.string cimport string
from libcpp cimport bool as _bool
from cpython.version cimport PY_MAJOR_VERSION

ctypedef void* SymbolHandle
ctypedef void* NDArrayHandle
ctypedef void* OpHandle
ctypedef void* CachedOpHandle
ctypedef void* MonitorCallbackHandle
ctypedef unsigned nn_uint
ctypedef void (*CachedOpMonitorCallback)(const char*,
const char*,
NDArrayHandle)

cdef py_str(const char* x):
if PY_MAJOR_VERSION < 3:
@@ -112,3 +117,6 @@ cdef extern from "mxnet/c_api.h":
int *num_outputs,
NDArrayHandle **outputs,
const int **out_stypes);
int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
_bool monitor_all);
16 changes: 16 additions & 0 deletions python/mxnet/cython/ndarray.pyx
@@ -22,6 +22,7 @@ import ctypes as _ctypes
import numpy as np
from ..ndarray_doc import _build_doc
from libc.stdint cimport uint32_t, int64_t
from ..base import _LIB

include "./base.pyi"

@@ -69,6 +70,10 @@ def _set_ndarray_class(cls):
global _ndarray_cls
_ndarray_cls = cls

def _monitor_callback_wrapper(callback):
def callback_handle(name, opr_name, arr, _):
callback(name, opr_name, arr)
return callback_handle

cdef NewArray(NDArrayHandle handle, int stype=-1):
"""Create a new array given handle"""
@@ -96,6 +101,9 @@ cdef class CachedOp:
def __set__(self, value):
self._set_handle(value)

cdef int is_np_sym
cdef readonly object mhandle

def __init__(self, sym, flags=()):
cdef vector[string] s_flag_keys
cdef vector[string] s_flag_vals
@@ -158,6 +166,14 @@ cdef class CachedOp:
else:
return [NewArray(p_output_vars[i], p_output_stypes[i]) for i in range(num_output)]

def _register_op_hook(self, callback, monitor_all=False):
cb_type = _ctypes.CFUNCTYPE(None, _ctypes.c_char_p, _ctypes.c_char_p, _ctypes.c_void_p, _ctypes.c_void_p)
if callback:
self.mhandle = cb_type(_monitor_callback_wrapper(callback))
chandle = _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
CALL(_LIB.MXCachedOpRegisterOpHook(chandle,
self.mhandle,
_ctypes.c_int(monitor_all)))

def _imperative_invoke(handle, ndargs, keys, vals, out):
"""cython implementation of imperative invoke wrapper"""
36 changes: 36 additions & 0 deletions python/mxnet/gluon/block.py
@@ -564,6 +564,18 @@ def forward(self, *args):
# pylint: disable= invalid-name
raise NotImplementedError

def register_op_hook(self, callback, monitor_all=False):
"""Install callback monitor.
Parameters
----------
callback : function
Takes a string for the node name, a string for the op name, and an NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
for cld in self._children.values():
cld.register_op_hook(callback, monitor_all)

def summary(self, *inputs):
"""Print the summary of the model's output and parameters.
@@ -728,6 +740,8 @@ def __init__(self, prefix=None, params=None):
self._in_format = None
self._active = False
self._flags = []
self._callback = None
self._monitor_all = False

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -804,6 +818,13 @@ def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)

assert self._cached_op, "cached op is None"
if self._callback:
self._cached_op._register_op_hook(self._callback, self._monitor_all)
if len(self._flags) >= 2 and (self._flags[1] or self._flags[0]):
warnings.warn("register_op_hook is experimental when static_alloc=True / static_shape=True "
" and may not work correctly")

args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
try:
@@ -906,6 +927,21 @@ def export(self, path, epoch=0, remove_amp_cast=True):
arg_dict['aux:%s'%name] = param._reduce()
ndarray.save('%s-%04d.params'%(path, epoch), arg_dict)

def register_op_hook(self, callback, monitor_all=False):
"""Install op hook for block recursively.
Parameters
----------
callback : function
Takes a string for the node name, a string for the op name, and an NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
self._callback = callback
self._monitor_all = monitor_all
for cld in self._children.values():
cld._callback = callback
cld._monitor_all = monitor_all

def forward(self, x, *args):
"""Defines the forward computation. Arguments can be either
:py:class:`NDArray` or :py:class:`Symbol`."""
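
A hedged end-to-end sketch of the Gluon API added here: register the hook on a HybridBlock, hybridize so the forward pass runs through CachedOp, and execute once; the network below is illustrative and not part of this change:

import mxnet as mx
from mxnet.gluon import nn

seen = []

def mon_callback(name, opr_name, arr):
    # names are composed as "<node name>_<tensor name>" by ExecuteMonInput/OutputCallback (see src/common/utils.cc below)
    seen.append((name.decode("utf-8"), opr_name.decode("utf-8")))

net = nn.HybridSequential()
net.add(nn.Dense(8, activation="relu"), nn.Dense(2))
net.initialize()
net.hybridize()                                        # the hook only takes effect on the CachedOp path
net.register_op_hook(mon_callback, monitor_all=False)  # outputs only

net(mx.nd.ones((1, 4)))
for name, opr_name in seen:
    print(opr_name, "->", name)
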
21 changes: 21 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -378,3 +378,24 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) {
*out = reinterpret_cast<SymbolHandle>(sym);
API_END();
}

int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
bool monitor_all) {
API_BEGIN();
CachedOpMonitorCallback callback_temp = nullptr;
std::function<void(const char *, const char *, void*)> clbk;
if (callback) {
callback_temp = callback;
clbk = [callback_temp](const char *name, const char *opr_name,
void *handle) {
callback_temp(name, opr_name, handle);
};
} else {
clbk = nullptr;
}
CachedOpPtr op = *static_cast<CachedOpPtr *>(handle);
op->RegisterOpHook(clbk, monitor_all);
API_END();
}

57 changes: 57 additions & 0 deletions src/common/utils.cc
@@ -51,5 +51,62 @@ void CastStorageDispatch<cpu>(const OpContext& ctx,
mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
}

void ExecuteMonInputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback) {
static const auto &flist_inputs =
nnvm::Op::GetAttr<nnvm::FListInputNames>("FListInputNames");
std::vector<std::string> input_names;
const nnvm::IndexedGraph::Node &inode = idx[nid];
const nnvm::Node *node = inode.source;
if (flist_inputs.count(node->op())) {
input_names = flist_inputs[node->op()](node->attrs);
} else {
for (size_t i = 0; i < node->num_inputs(); ++i) {
input_names.emplace_back("input" + std::to_string(i));
}
}

for (size_t i = 0; i < node->num_inputs(); ++i) {
const nnvm::NodeEntry &input = node->inputs[i];
if (state_arrays[idx.entry_id(input)]->is_none()) {
continue;
}
NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(input)]);
std::string name = inode.source->attrs.name + "_" + input_names[i];
monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
reinterpret_cast<void *>(cpy));
}
}

void ExecuteMonOutputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback) {
static const auto &flist_outputs =
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
std::vector<std::string> output_names;
const nnvm::IndexedGraph::Node &inode = idx[nid];
const nnvm::Node *node = inode.source;
if (flist_outputs.count(node->op())) {
output_names = flist_outputs[node->op()](node->attrs);
} else {
for (size_t i = 0; i < node->num_outputs(); ++i) {
output_names.emplace_back(std::to_string(i));
}
}

for (size_t i = 0; i < node->num_outputs(); ++i) {
if (state_arrays[idx.entry_id(nid, i)]->is_none()) {
continue;
}
NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(nid, i)]);
std::string name = inode.source->attrs.name + "_" + output_names[i];
monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
reinterpret_cast<void *>(cpy));
}
}

} // namespace common
} // namespace mxnet
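
The two helpers above name each monitored array as "<node name>_<input/output name>", falling back to "input<i>" for inputs and the bare index for outputs when an op registers no FListInputNames/FListOutputNames. A hedged sketch of a callback that buckets reports per node using that convention; the sample name fed to it is made up for illustration:

from collections import defaultdict

per_node = defaultdict(list)

def bucketing_callback(name, opr_name, arr):
    full = name.decode("utf-8")
    node, _, tensor = full.rpartition("_")      # strip the trailing input/output name
    per_node[node].append((tensor, opr_name.decode("utf-8")))

# Illustrative invocation with a made-up name in the documented format:
bucketing_callback(b"dense0_fwd_weight", b"FullyConnected", None)
print(dict(per_node))                           # {'dense0_fwd': [('weight', 'FullyConnected')]}
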
10 changes: 10 additions & 0 deletions src/common/utils.h
@@ -792,6 +792,16 @@ inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) {
}
}

void ExecuteMonInputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback);

void ExecuteMonOutputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback);

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
22 changes: 19 additions & 3 deletions src/imperative/cached_op.cc
@@ -697,6 +697,9 @@ void CachedOp::StaticRunOps(
ndinputs.emplace_back(state_arrays[idx.entry_id(j)]);
CHECK(!ndinputs.back()->is_none());
}
if (monitor_callback_ && monitor_all_) {
mxnet::common::ExecuteMonInputCallback(idx, state_arrays, i, monitor_callback_);
}
ndoutputs.clear();
ndoutputs.reserve(num_outputs);
req.clear();
@@ -735,6 +738,9 @@
default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode);
}
if (monitor_callback_) {
mxnet::common::ExecuteMonOutputCallback(idx, state_arrays, i, monitor_callback_);
}
}
}
}
@@ -883,12 +889,12 @@ OpStatePtr CachedOp::DynamicForward(
// So if it's not the inline mode, we disable recording.
RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes,
recording && inlining_);
recording && inlining_, nullptr, monitor_callback_, monitor_all_);
} else {
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states,
dispatch_modes, recording && inlining_, &shapes);
dispatch_modes, recording && inlining_, &shapes, monitor_callback_, monitor_all_);
{
auto state_ptr = GetCachedOpState(default_ctx);
auto& state = state_ptr.get_state<CachedOpState>();
@@ -1028,7 +1034,7 @@ void CachedOp::DynamicBackward(

RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
Imperative::Get()->is_recording());
Imperative::Get()->is_recording(), nullptr, monitor_callback_);

if (retain_graph) {
buff.resize(num_forward_entries);
@@ -1295,6 +1301,16 @@ void CachedOpBackward(const OpStatePtr& state_ptr,
CopyFromTo(out_bufs[i], outputs[i]);
}

/*
* Register the callback to be called when the operator is executed
*/
void CachedOp::RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
bool monitor_all) {
CHECK(callback) << "invalid callback";
monitor_callback_ = callback;
monitor_all_ = monitor_all;
}

OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
Context ctx,
const mxnet::ShapeVector& in_shapes,
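
The hook also fires on the static-allocation path (StaticRunOps above), where the Gluon layer warns that support is experimental. A hedged sketch, with the model and shapes purely illustrative:

import mxnet as mx
from mxnet.gluon import nn

def noop(name, opr_name, arr):
    pass

net = nn.Dense(2)
net.initialize()
net.hybridize(static_alloc=True, static_shape=True)   # exercises StaticRunOps; Gluon warns the hook is experimental here
net.register_op_hook(noop, monitor_all=True)          # inputs via ExecuteMonInputCallback, outputs via ExecuteMonOutputCallback

net(mx.nd.ones((1, 4)))
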
7 changes: 7 additions & 0 deletions src/imperative/cached_op.h
@@ -74,6 +74,8 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
};

class CachedOp {
using CachedOpMonCallback =
std::function<void(const char *, const char *, void *)>;
public:
CachedOp(
const nnvm::Symbol& sym,
@@ -134,6 +136,8 @@
sym.outputs = fwd_graph_.outputs;
return sym;
}
void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
bool monitor_all = false);

private:
struct GraphInfo;
@@ -203,6 +207,9 @@ class CachedOp {
std::vector<bool> save_inputs_, save_outputs_;
std::vector<OpReqType> bwd_output_reqs_;

std::function<void(const char*, const char*, NDArrayHandle)> monitor_callback_{nullptr};
bool monitor_all_{false};

std::mutex mutex_;
std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;
};