Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 64f737c

Browse files
ciyongchTaoLv
andauthored
[v1.6] Fix the monitor_callback invalid issue during calibration with variable input shapes (#18632) (#18703)
* Fix the monitor_callback invalid issue during calibration with variable input shapes * retrigger CI * Add UT for monitor check and disable codecov Co-authored-by: Tao Lv <tao.a.lv@intel.com>
1 parent 61597a5 commit 64f737c

3 files changed

Lines changed: 65 additions & 0 deletions

File tree

.codecov.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ codecov:
44
require_ci_to_pass: yes
55

66
coverage:
7+
status:
8+
project: off
9+
patch: off
710
precision: 2
811
round: down
912
range: "70...100"

python/mxnet/executor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx):
7979
self._aux_dict = None
8080
self._output_dict = None
8181
self._monitor_callback = None
82+
self._monitor_all = None
8283
self._ctx = copy.deepcopy(ctx)
8384
self._grad_req = copy.deepcopy(grad_req)
8485
self._group2ctx = copy.deepcopy(group2ctx)
@@ -253,6 +254,7 @@ def set_monitor_callback(self, callback, monitor_all=False):
253254
"""
254255
cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
255256
self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
257+
self._monitor_all = monitor_all
256258
check_call(_LIB.MXExecutorSetMonitorCallbackEX(
257259
self.handle,
258260
self._monitor_callback,
@@ -477,6 +479,13 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs):
477479
executor.arg_arrays = arg_arrays
478480
executor.grad_arrays = grad_arrays
479481
executor.aux_arrays = aux_arrays
482+
if (self._monitor_callback is not None) and (self._monitor_all is not None):
483+
# rebind callback to the new executor if the callback is valid
484+
check_call(_LIB.MXExecutorSetMonitorCallbackEX(
485+
handle,
486+
self._monitor_callback,
487+
None,
488+
ctypes.c_int(self._monitor_all)))
480489
return executor
481490

482491
def debug_str(self):

tests/python/unittest/test_operator.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8364,6 +8364,59 @@ def get_output_names_callback(name, arr):
83648364
check_name(us_sym, ['data', 'pooling_data', 'pooling_output'])
83658365
del os.environ['MXNET_SUBGRAPH_BACKEND']
83668366

8367+
@with_seed()
8368+
def test_monitor_with_variable_input_shape():
8369+
output = {}
8370+
8371+
def get_output_min_callback(name, arr):
8372+
name = py_str(name)
8373+
handle = ctypes.cast(arr, NDArrayHandle)
8374+
arr = NDArray(handle, writable=False)
8375+
min_val = mx.ndarray.min(arr).asscalar()
8376+
if name in output:
8377+
output[name] = min(output[name], min_val)
8378+
else:
8379+
output[name] = min_val
8380+
8381+
def check_result(output, names):
8382+
assert len(output) > 0
8383+
for k, v in output.items():
8384+
assert k in names
8385+
assert v is not None
8386+
8387+
is_windows = sys.platform.startswith('win')
8388+
if (is_windows):
8389+
# Windows doesn't support set environment variable on the fly, so disable it for now
8390+
pass
8391+
else:
8392+
# Disable subgraph in case subgraph will replace symbol
8393+
os.environ['MXNET_SUBGRAPH_BACKEND'] = "NONE"
8394+
8395+
batch_size = 1
8396+
op_name = 'conv'
8397+
dshape = (batch_size, 3, 10, 10)
8398+
data = mx.sym.Variable('data', shape=dshape)
8399+
sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=1, name=op_name)
8400+
8401+
mod = mx.module.Module(symbol=sym, label_names=None)
8402+
mod.bind(for_training=False, data_shapes=[('data', dshape)])
8403+
mod.init_params()
8404+
mod._exec_group.execs[0].set_monitor_callback(get_output_min_callback, monitor_all=True)
8405+
8406+
new_dshape = dshape[:-1] + (dshape[-1] + 4,)
8407+
new_data = mx.nd.random.uniform(shape=new_dshape)
8408+
new_data = mx.io.NDArrayIter(data=new_data, batch_size=batch_size)
8409+
new_data = DummyIter(new_data)
8410+
8411+
for batch in new_data:
8412+
mod.forward(data_batch=batch, is_train=False)
8413+
mx.nd.waitall()
8414+
break
8415+
8416+
name_list = ['data', 'conv_data', 'conv_weight', 'conv_bias', 'conv_output']
8417+
check_result(output, name_list)
8418+
del os.environ['MXNET_SUBGRAPH_BACKEND']
8419+
83678420
@with_seed()
83688421
@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/13915")
83698422
def test_activation():

0 commit comments

Comments
 (0)