Skip to content

Commit

Permalink
[v1.9.x] Port apache#20759 from v1.x (apache#20815)
Browse files Browse the repository at this point in the history
* [v1.x] Restore quantized RNN operator from MXNet 1.6 (apache#20759)

* restore but seg fault

* Refactor & seg fault fixed

* apply formatter

* fix sanity

* Fix docs build

* anko review

* Remove copyright by contributors from touched files

* remove comments / apply formatter

* Update call to work with older mkldnn version.

Co-authored-by: bgawrych <[email protected]>
  • Loading branch information
josephevans and bgawrych authored Jan 12, 2022
1 parent 75c6373 commit 0642923
Show file tree
Hide file tree
Showing 18 changed files with 1,814 additions and 216 deletions.
1 change: 1 addition & 0 deletions docs/python_docs/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies:
- conda>=4.6.13
- pip
- python
- setuptools==49.6.0
- jupyter
- sphinx==2.4.0
- matplotlib
Expand Down
15 changes: 15 additions & 0 deletions include/mxnet/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,21 @@ using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
const size_t index,
const std::string quantize_granularity)>;

/*!
 * \brief Register a function to determine if the input of a quantized operator
 * needs to be quantized asymmetrically.
 * \param attrs the attributes of the operator node being queried
 * \param index the position of the input being queried
 * \return true if input \a index must use asymmetric quantization
 */
using FNeedAsymQuantizeInput = std::function<bool (const NodeAttrs& attrs,
const size_t index)>;

/*!
 * \brief Register a function to determine if the output of a quantized operator
 * needs to be dequantized. This is usually used for the quantized operators
 * which can produce fp32 outputs directly.
 * \param attrs the attributes of the operator node being queried
 * \param index the position of the output being queried
 * \return true if a dequantize node should NOT be inserted after output \a index
 */
using FAvoidDequantizeOutput = std::function<bool (const NodeAttrs& attrs,
const size_t index)>;

/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be calibrated. This is usually used for the quantized operators
Expand Down
23 changes: 16 additions & 7 deletions python/mxnet/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from ..ndarray import array
from ..ndarray import concat, tile

from .utils import _init_data, _has_instance, _getdata_by_idx
from .utils import _init_data, _has_instance, _getdata_by_idx, _slice_along_batch_axis

class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
"""DataDesc is used to store name, shape, type and layout
Expand Down Expand Up @@ -602,10 +602,12 @@ class NDArrayIter(DataIter):
The data name.
label_name : str, optional
The label name.
layout : str, optional
The data layout
"""
def __init__(self, data, label=None, batch_size=1, shuffle=False,
last_batch_handle='pad', data_name='data',
label_name='softmax_label'):
label_name='softmax_label', layout='NCHW'):
super(NDArrayIter, self).__init__(batch_size)

self.data = _init_data(data, allow_empty=False, default_name=data_name)
Expand All @@ -631,20 +633,27 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
# used for 'roll_over'
self._cache_data = None
self._cache_label = None
self.layout = layout

@property
def provide_data(self):
    """The name and shape of data provided by this iterator.

    Returns
    -------
    list of DataDesc
        One descriptor per data array. The batch dimension is substituted
        at the position of 'N' in ``self.layout`` (e.g. axis 0 for 'NCHW',
        axis 1 for 'TNC'), so the reported shape matches what ``next()``
        yields for any layout.
    """
    # Position of the batch dimension according to the configured layout.
    batch_axis = self.layout.find('N')
    return [
        DataDesc(k,
                 tuple(list(v.shape[:batch_axis]) +
                       [self.batch_size] + list(v.shape[batch_axis + 1:])),
                 v.dtype, layout=self.layout)
        for k, v in self.data
    ]

@property
def provide_label(self):
    """The name and shape of label provided by this iterator.

    Returns
    -------
    list of DataDesc
        One descriptor per label array. The batch dimension is substituted
        at the position of 'N' in ``self.layout``, mirroring
        ``provide_data`` so data and label descriptors stay consistent.
    """
    # Position of the batch dimension according to the configured layout.
    batch_axis = self.layout.find('N')
    return [
        DataDesc(k,
                 tuple(list(v.shape[:batch_axis]) +
                       [self.batch_size] + list(v.shape[batch_axis + 1:])),
                 v.dtype, layout=self.layout)
        for k, v in self.label
    ]

Expand Down Expand Up @@ -681,7 +690,7 @@ def next(self):
data = self.getdata()
label = self.getlabel()
# iter should stop when last batch is not complete
if data[0].shape[0] != self.batch_size:
if data[0].shape[self.layout.find('N')] != self.batch_size:
# in this case, cache it for next epoch
self._cache_data = data
self._cache_label = label
Expand All @@ -697,7 +706,7 @@ def _getdata(self, data_source, start=None, end=None):
end = data_source[0][1].shape[0] if data_source else 0
s = slice(start, end)
return [
x[1][s]
_slice_along_batch_axis(x[1], s, self.layout.find('N'))
if isinstance(x[1], (np.ndarray, NDArray)) else
# h5py (only supports indices in increasing order)
array(x[1][sorted(self.idx[s])][[
Expand All @@ -716,7 +725,7 @@ def _concat(self, first_data, second_data):
concat(
first_data[i],
second_data[i],
dim=0
dim=self.layout.find('N')
) for i in range(len(first_data))
]

Expand Down
5 changes: 5 additions & 0 deletions python/mxnet/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,8 @@ def _getdata_by_idx(data, idx):
shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))

return shuffle_data

def _slice_along_batch_axis(data, s, batch_axis):
"""Apply slice along the batch axis"""
ret = data.slice_axis(axis=batch_axis, begin=s.start, end=s.stop)
return ret
Loading

0 comments on commit 0642923

Please sign in to comment.