Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.7] Fix the monitor_callback invalid issue during calibration with variable input shapes #18703

Merged
merged 2 commits into from
Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ codecov:
require_ci_to_pass: yes

coverage:
status:
project: off
patch: off
precision: 2
round: down
range: "70...100"
Expand Down
9 changes: 9 additions & 0 deletions python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx):
self._aux_dict = None
self._output_dict = None
self._monitor_callback = None
self._monitor_all = None
self._ctx = copy.deepcopy(ctx)
self._grad_req = copy.deepcopy(grad_req)
self._group2ctx = copy.deepcopy(group2ctx)
Expand Down Expand Up @@ -253,6 +254,7 @@ def set_monitor_callback(self, callback, monitor_all=False):
"""
cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
self._monitor_all = monitor_all
check_call(_LIB.MXExecutorSetMonitorCallbackEX(
self.handle,
self._monitor_callback,
Expand Down Expand Up @@ -477,6 +479,13 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs):
executor.arg_arrays = arg_arrays
executor.grad_arrays = grad_arrays
executor.aux_arrays = aux_arrays
if (self._monitor_callback is not None) and (self._monitor_all is not None):
# rebind callback to the new executor if the callback is valid
check_call(_LIB.MXExecutorSetMonitorCallbackEX(
handle,
self._monitor_callback,
None,
ctypes.c_int(self._monitor_all)))
return executor

def debug_str(self):
Expand Down
53 changes: 53 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8364,6 +8364,59 @@ def get_output_names_callback(name, arr):
check_name(us_sym, ['data', 'pooling_data', 'pooling_output'])
del os.environ['MXNET_SUBGRAPH_BACKEND']

@with_seed()
def test_monitor_with_variable_input_shape():
    """Check that a monitor callback still fires when the input shape changes.

    A monitor callback is installed on the module's first executor while it is
    bound to ``dshape``. A batch whose last dimension is 4 larger than the
    bound shape is then fed through ``forward``. The test passes if the
    callback recorded at least one tensor and every recorded name is one of
    the expected input/output names of the convolution.

    The ``MXNET_SUBGRAPH_BACKEND`` environment variable is forced to ``"NONE"``
    for the duration of the test (except on Windows) and is put back to its
    previous state afterwards, even if the test fails.
    """
    output = {}

    def get_output_min_callback(name, arr):
        # The callback receives a raw C string and an opaque NDArray handle.
        name = py_str(name)
        handle = ctypes.cast(arr, NDArrayHandle)
        arr = NDArray(handle, writable=False)
        min_val = mx.ndarray.min(arr).asscalar()
        # Keep the smallest value seen so far for each monitored tensor.
        output[name] = min(output.get(name, min_val), min_val)

    def check_result(output, names):
        # At least one tensor must have been monitored ...
        assert len(output) > 0
        # ... and nothing outside the expected set may show up.
        for k, v in output.items():
            assert k in names
            assert v is not None

    env_key = 'MXNET_SUBGRAPH_BACKEND'
    is_windows = sys.platform.startswith('win')
    # Remember the caller's setting so it can be restored instead of clobbered.
    saved_backend = os.environ.get(env_key)
    if not is_windows:
        # Disable subgraph in case subgraph will replace symbol.
        # Windows doesn't support setting environment variables on the fly,
        # so the variable is left untouched there.
        os.environ[env_key] = "NONE"

    try:
        batch_size = 1
        op_name = 'conv'
        dshape = (batch_size, 3, 10, 10)
        data = mx.sym.Variable('data', shape=dshape)
        sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=1, name=op_name)

        mod = mx.module.Module(symbol=sym, label_names=None)
        mod.bind(for_training=False, data_shapes=[('data', dshape)])
        mod.init_params()
        mod._exec_group.execs[0].set_monitor_callback(get_output_min_callback, monitor_all=True)

        # Same shape as the bound one except the last dimension grows by 4.
        new_dshape = dshape[:-1] + (dshape[-1] + 4,)
        new_data = mx.nd.random.uniform(shape=new_dshape)
        new_data = mx.io.NDArrayIter(data=new_data, batch_size=batch_size)
        new_data = DummyIter(new_data)

        # Only a single batch is needed to trigger the callback.
        for batch in new_data:
            mod.forward(data_batch=batch, is_train=False)
            mx.nd.waitall()
            break

        name_list = ['data', 'conv_data', 'conv_weight', 'conv_bias', 'conv_output']
        check_result(output, name_list)
    finally:
        # Undo only what this test changed. On Windows nothing was set, so an
        # unconditional `del` would raise KeyError when the variable is absent.
        if not is_windows:
            if saved_backend is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = saved_backend

@with_seed()
@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/13915")
def test_activation():
Expand Down