This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[API] Extend NumPy Array dtypes with int16, uint16, uint32, uint64 (#20478)

* extend dtypes with int16, uint16, uint32, uint64

* update operator_tune.cc

* add test suite

* fix sanity

* update test

* update ci

* fix

* fix

* fix Uint32

* extend dtypes support to tvmop
barry-jin committed Sep 10, 2021
1 parent cb83c4c commit 17088c6
Showing 21 changed files with 476 additions and 37 deletions.
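
Most of the 476 added lines land in 3rdparty/mshadow/mshadow/base.h, presumably the extended MSHADOW_TYPE_SWITCH_EXT* type-switch macros that the operator changes below rely on (that diff is not rendered). At the Python surface, a minimal sketch of what the change enables, assuming a build that includes this commit:

import numpy as onp
from mxnet import np as mnp  # requires a build that includes this commit

# The four added dtypes become usable end to end in mxnet.numpy.
for name in ('int16', 'uint16', 'uint32', 'uint64'):
    x = mnp.array([0, 1, 2], dtype=name)
    assert onp.dtype(x.dtype) == onp.dtype(name)
    assert x.asnumpy().tolist() == [0, 1, 2]  # round trip through the backend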
310 changes: 308 additions & 2 deletions 3rdparty/mshadow/mshadow/base.h

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions ci/docker/install/requirements
@@ -52,3 +52,6 @@ decorator==4.4.0
boto3==1.9.229
h5py==2.10.0
Pillow<6

# Array API Standardization requirements
hypothesis==6.14.0
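
hypothesis is the property-based testing framework underpinning the data-apis/array-api-tests suite wired into CI below. A stripped-down illustration of the style of test that suite runs (not its actual code), assuming a build that includes this commit:

import numpy as onp
from hypothesis import given, strategies as st
from mxnet import np as mnp

@given(st.lists(st.integers(min_value=0, max_value=2**16 - 1), min_size=1, max_size=16))
def test_uint16_roundtrip(values):
    # Values representable in uint16 should survive an mxnet.numpy round trip.
    x = mnp.array(values, dtype='uint16')
    assert onp.array_equal(x.asnumpy(), onp.array(values, dtype='uint16'))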
16 changes: 16 additions & 0 deletions ci/docker/runtime_functions.sh
@@ -806,6 +806,22 @@ unittest_ubuntu_python3_cpu_onednn() {
pytest --durations=50 --cov-report xml:tests_mkl.xml --verbose tests/python/mkl
}

unittest_array_api_standardization() {
set -ex
python3 -m pip install -e /work/mxnet/python --user
cd ..
git clone https://github.com/data-apis/array-api-tests.git
pushd /work/array-api-tests
export ARRAY_API_TESTS_MODULE=mxnet.numpy
# OverflowError: Python int too large to convert to C long
# when cython is enabled
export MXNET_ENABLE_CYTHON=0
export DMLC_LOG_STACK_TRACE_DEPTH=100
python3 -m pytest --durations=50 --cov-report xml:tests_api.xml --verbose \
array_api_tests/test_type_promotion.py::test_elementwise_function_two_arg_bool_type_promotion
popd
}
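
Roughly the same invocation outside CI, sketched in Python; it assumes array-api-tests has already been cloned into the working directory and that MXNet is importable:

import os
import subprocess

env = dict(os.environ,
           ARRAY_API_TESTS_MODULE='mxnet.numpy',   # array module the suite exercises
           MXNET_ENABLE_CYTHON='0',                # avoids the OverflowError noted above
           DMLC_LOG_STACK_TRACE_DEPTH='100')
subprocess.run(
    ['python3', '-m', 'pytest', '--verbose',
     'array_api_tests/test_type_promotion.py::'
     'test_elementwise_function_two_arg_bool_type_promotion'],
    cwd='array-api-tests', env=env, check=True)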

unittest_ubuntu_python3_gpu() {
set -ex
export PYTHONPATH=./python/
18 changes: 18 additions & 0 deletions ci/jenkins/Jenkins_steps.groovy
@@ -60,6 +60,12 @@ def python3_ut_onednn(docker_container_name) {
}
}

def python3_ut_array_api(docker_container_name) {
timeout(time: max_time, unit: 'MINUTES') {
utils.docker_run(docker_container_name, 'unittest_array_api_standardization', false)
}
}

// GPU test has two parts. 1) run unittest on GPU, 2) compare the results on
// both CPU and GPU
// Python 3
@@ -665,6 +671,18 @@ def test_unix_python3_cpu(lib_name) {
}]
}

def test_unix_python3_array_api(lib_name) {
return ['Python3: Array-API': {
node(NODE_LINUX_CPU) {
ws('workspace/ut-python3-cpu') {
utils.unpack_and_init(lib_name, mx_lib, true)
python3_ut_array_api('ubuntu_cpu')
utils.publish_test_coverage()
}
}
}]
}

def test_unix_python3_mkl_cpu(lib_name) {
return ['Python3: MKL-CPU': {
node(NODE_LINUX_CPU) {
1 change: 1 addition & 0 deletions ci/jenkins/Jenkinsfile_unix_cpu
@@ -46,6 +46,7 @@ core_logic: {
utils.parallel_stage('Tests', [
custom_steps.test_unix_python3_cpu('cpu'),
custom_steps.test_unix_python3_onnx_cpu('cpu'),
custom_steps.test_unix_python3_array_api('cpu'),
custom_steps.test_unix_python3_mkl_cpu('cpu_mkl'),
custom_steps.test_unix_python3_onednn_cpu('onednn_cpu'),
custom_steps.test_unix_python3_onednn_mkl_cpu('onednn_mkl_cpu'),
3 changes: 2 additions & 1 deletion contrib/tvmop/utils.py
@@ -18,7 +18,8 @@
# coding: utf-8
import tvm

AllTypes = ["float32", "float64", "float16", "uint8", "int8", "int32", "int64"]
AllTypes = ["float32", "float64", "float16", "uint8", "uint16",
"uint32", "uint64", "int8", "int16", "int32", "int64"]
RealTypes = ["float32", "float64", "float16"]


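
The names in AllTypes double as TVM and NumPy dtype strings, so every new entry has to be something NumPy itself recognizes. A quick illustrative check:

import numpy as onp

AllTypes = ["float32", "float64", "float16", "uint8", "uint16",
            "uint32", "uint64", "int8", "int16", "int32", "int64"]

for name in AllTypes:
    # onp.dtype() raises TypeError for names the frontend could not honour.
    assert onp.dtype(name).name == name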
16 changes: 16 additions & 0 deletions include/mxnet/runtime/packed_func.h
@@ -893,6 +893,14 @@ inline int String2MXNetTypeWithBool(const std::string& s) {
return mshadow::kInt64;
} else if (s == "bool") {
return mshadow::kBool;
} else if (s == "int16") {
return mshadow::kInt16;
} else if (s == "uint16") {
return mshadow::kUint16;
} else if (s == "uint32") {
return mshadow::kUint32;
} else if (s == "uint64") {
return mshadow::kUint64;
} else {
LOG(FATAL) << "unknown type " << s;
}
@@ -915,6 +923,14 @@ inline int String2MXNetType(const std::string& s) {
return mshadow::kInt32;
} else if (s == "int64") {
return mshadow::kInt64;
} else if (s == "int16") {
return mshadow::kInt16;
} else if (s == "uint16") {
return mshadow::kUint16;
} else if (s == "uint32") {
return mshadow::kUint32;
} else if (s == "uint64") {
return mshadow::kUint64;
} else {
LOG(FATAL) << "unknown type " << s;
}
8 changes: 8 additions & 0 deletions include/mxnet/tensor_blob.h
@@ -386,6 +386,10 @@ class TBlob {
case mshadow::kInt8: return DLDataType{kDLInt, 8, 1};
case mshadow::kInt64: return DLDataType{kDLInt, 64, 1};
case mshadow::kBool: return DLDataType{kDLUInt, 1, 1};
case mshadow::kInt16: return DLDataType{kDLInt, 16, 1};
case mshadow::kUint16: return DLDataType{kDLUInt, 16, 1};
case mshadow::kUint32: return DLDataType{kDLUInt, 32, 1};
case mshadow::kUint64: return DLDataType{kDLUInt, 64, 1};
default: {
LOG(FATAL) << "Unknown type_flag=" << type_flag;
return DLDataType();
@@ -413,11 +417,15 @@
switch (dldata_type.bits) {
case 1: return mshadow::kBool;
case 8: return mshadow::kUint8;
case 16: return mshadow::kUint16;
case 32: return mshadow::kUint32;
case 64: return mshadow::kUint64;
}
break;
case kDLInt:
switch (dldata_type.bits) {
case 8: return mshadow::kInt8;
case 16: return mshadow::kInt16;
case 32: return mshadow::kInt32;
case 64: return mshadow::kInt64;
}
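
The TBlob change extends the two-way mapping between mshadow type flags and DLPack's (type code, bit width) pairs. An illustrative table of just the new entries, cross-checked against NumPy's item sizes:

import numpy as onp

# (DLPack type code, bit width) for the four dtypes added to TBlob's mapping.
new_dlpack_entries = {
    'int16':  ('kDLInt', 16),
    'uint16': ('kDLUInt', 16),
    'uint32': ('kDLUInt', 32),
    'uint64': ('kDLUInt', 64),
}
for name, (_, bits) in new_dlpack_entries.items():
    assert onp.dtype(name).itemsize * 8 == bits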
8 changes: 5 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
@@ -509,9 +509,11 @@ def empty_like(prototype, dtype=None, order='C', subok=False, shape=None):  # pyl
array([[4.9e-324, 9.9e-324, 1.5e-323], # uninitialized
[2.0e-323, 2.5e-323, 3.0e-323]])
"""
dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32',
_np.int64:'int64', _np.float16:'float16', _np.float32:'float32',
_np.float64:'float64', _np.bool_:'bool_', bool:'bool', int:'int64', float:'float64'}
dtype_list = {_np.float16: 'float16', _np.float32: 'float32', _np.float64: 'float64',
float: 'float64', _np.int8: 'int8', _np.int16: 'int16', _np.int32: 'int32',
_np.int64: 'int64', int:'int64', _np.uint8: 'uint8', _np.uint16: 'uint16',
_np.uint32: 'uint32', _np.uint64: 'uint64', _np.bool: 'bool',
_np.bool_: 'bool_', bool: 'bool', None: 'None'}
if order != 'C':
raise NotImplementedError("Only support C-order at this moment")
if subok:
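
With the widened dtype_list, empty_like accepts prototypes of the new integer types. A usage sketch, assuming a build that includes this commit:

import numpy as onp
from mxnet import np as mnp

proto = mnp.array([[1, 2, 3], [4, 5, 6]], dtype='uint16')
out = mnp.empty_like(proto)   # dtype and shape are taken from the prototype
assert out.shape == proto.shape
assert onp.dtype(out.dtype) == onp.uint16
# Contents are uninitialized, exactly as in the docstring example above.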
19 changes: 16 additions & 3 deletions python/mxnet/numpy/utils.py
@@ -22,9 +22,12 @@
import numpy as onp

__all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64',
'int16', 'uint16', 'uint32', 'uint64',
'bool', 'bool_', 'pi', 'inf', 'nan', 'PZERO', 'NZERO', 'newaxis', 'finfo',
'e', 'NINF', 'PINF', 'NAN', 'NaN',
'_STR_2_DTYPE_']
'_STR_2_DTYPE_', '_DTYPE_2_STR_']

py_bool = bool

float16 = onp.float16
float32 = onp.float32
@@ -35,6 +38,10 @@
int64 = onp.int64
bool_ = onp.bool_
bool = onp.bool
int16 = onp.int16
uint16 = onp.uint16
uint32 = onp.uint32
uint64 = onp.uint64

pi = onp.pi
inf = onp.inf
@@ -50,10 +57,16 @@
newaxis = None
finfo = onp.finfo

_STR_2_DTYPE_ = {'float16': float16, 'float32': float32, 'float64':float64, 'float': float64,
'uint8': uint8, 'int8': int8, 'int32': int32, 'int64': int64, 'int': int64,
_STR_2_DTYPE_ = {'float16': float16, 'float32': float32, 'float64': float64, 'float': float64,
'int8': int8, 'int16': int16, 'int32': int32, 'int64': int64, 'int': int64,
'uint8': uint8, 'uint16': uint16, 'uint32': uint32, 'uint64': uint64,
'bool': bool, 'bool_': bool_, 'None': None}

_DTYPE_2_STR_ = {float16: 'float16', float32: 'float32', float64: 'float64', float: 'float64',
int8: 'int8', int16: 'int16', int32: 'int32', int64: 'int64', int:'int64',
uint8: 'uint8', uint16: 'uint16', uint32: 'uint32', uint64: 'uint64',
bool: 'bool', bool_: 'bool_', py_bool: 'bool', None: 'None'}

_ONP_OP_MODULES = [onp, onp.linalg, onp.random, onp.fft]


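
_DTYPE_2_STR_ is the inverse of _STR_2_DTYPE_, with py_bool kept as an alias because the module shadows the builtin bool. A small illustrative consistency check over only the newly added integer types:

import numpy as onp

str_to_dtype = {'int16': onp.int16, 'uint16': onp.uint16,
                'uint32': onp.uint32, 'uint64': onp.uint64}
dtype_to_str = {v: k for k, v in str_to_dtype.items()}

for name, dt in str_to_dtype.items():
    assert dtype_to_str[dt] == name
    assert onp.dtype(name).type is dt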
2 changes: 1 addition & 1 deletion src/ndarray/ndarray.cc
@@ -339,7 +339,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
CHECK_EQ(storage_type(), kDefaultStorage);
NDArray ret = this->Detach();
size_t length = shape_.ProdShape(1, shape_.ndim());
MSHADOW_TYPE_SWITCH_WITH_BOOL(
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(
ret.dtype(), DType, { ret.byte_offset_ += begin * length * sizeof(DType); });
ret.reuse_ = false;
ret.shape_[0] = end - begin;
4 changes: 2 additions & 2 deletions src/ndarray/ndarray_function.cc
@@ -36,7 +36,7 @@ template<>
void Copy<cpu, cpu>(const TBlob &from, TBlob *to,
Context from_ctx, Context to_ctx,
RunContext ctx) {
MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(to->type_flag_, DType, {
if (to->type_flag_ == from.type_flag_) {
if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
CHECK_LT(from.Size(), (int64_t{1} << 31) - 1) <<
@@ -48,7 +48,7 @@ void Copy<cpu, cpu>(const TBlob &from, TBlob *to,
<< " bytes, to: " << to->Size() * sizeof(DType) << " bytes.";
common::ParallelCopy(to->dptr<DType>(), from.dptr<DType>(), size);
} else {
MSHADOW_TYPE_SWITCH_WITH_BOOL(from.type_flag_, SrcDType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(from.type_flag_, SrcDType, {
to->FlatTo1D<cpu, DType>() =
mshadow::expr::tcast<DType>(from.FlatTo1D<cpu, SrcDType>());
})
35 changes: 33 additions & 2 deletions src/operator/mxnet_op.h
@@ -650,6 +650,37 @@ struct AccType<mshadow::half::half_t> {
.add_enum("bool", mshadow::kBool)


#define MXNET_ADD_ALL_TYPES_EXT \
.add_enum("float32", mshadow::kFloat32) \
.add_enum("float64", mshadow::kFloat64) \
.add_enum("float16", mshadow::kFloat16) \
.add_enum("bfloat16", mshadow::kBfloat16) \
.add_enum("uint8", mshadow::kUint8) \
.add_enum("int8", mshadow::kInt8) \
.add_enum("int32", mshadow::kInt32) \
.add_enum("int64", mshadow::kInt64) \
.add_enum("int16", mshadow::kInt16) \
.add_enum("uint16", mshadow::kUint16) \
.add_enum("uint32", mshadow::kUint32) \
.add_enum("uint64", mshadow::kUint64)


#define MXNET_ADD_ALL_TYPES_EXT_WITH_BOOL \
.add_enum("float32", mshadow::kFloat32) \
.add_enum("float64", mshadow::kFloat64) \
.add_enum("float16", mshadow::kFloat16) \
.add_enum("bfloat16", mshadow::kBfloat16) \
.add_enum("uint8", mshadow::kUint8) \
.add_enum("int8", mshadow::kInt8) \
.add_enum("int32", mshadow::kInt32) \
.add_enum("int64", mshadow::kInt64) \
.add_enum("bool", mshadow::kBool) \
.add_enum("int16", mshadow::kInt16) \
.add_enum("uint16", mshadow::kUint16) \
.add_enum("uint32", mshadow::kUint32) \
.add_enum("uint64", mshadow::kUint64)


/* \brief Compute flattened index given coordinates and shape. */
template<int ndim>
MSHADOW_XINLINE index_t ravel(const Shape<ndim>& coord, const Shape<ndim>& shape) {
@@ -768,11 +799,11 @@ template <typename xpu>
MSHADOW_CINLINE void copy(mshadow::Stream<xpu> *s, const TBlob& to, const TBlob& from) {
CHECK_EQ(from.Size(), to.Size());
CHECK_EQ(from.dev_mask(), to.dev_mask());
MSHADOW_TYPE_SWITCH_WITH_BOOL(to.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(to.type_flag_, DType, {
if (to.type_flag_ == from.type_flag_) {
mshadow::Copy(to.FlatTo1D<xpu, DType>(s), from.FlatTo1D<xpu, DType>(s), s);
} else {
MSHADOW_TYPE_SWITCH_WITH_BOOL(from.type_flag_, SrcDType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(from.type_flag_, SrcDType, {
to.FlatTo1D<xpu, DType>(s) = mshadow::expr::tcast<DType>(from.FlatTo1D<xpu, SrcDType>(s));
})
}
4 changes: 2 additions & 2 deletions src/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -161,8 +161,8 @@ struct GetBinaryBroadcastCompute {
} else {
if (req[0] == kNullOp) return;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH_WITH_BOOL(lhs.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(rhs.type_flag_, EType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(lhs.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(rhs.type_flag_, EType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
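
Because the broadcast logic kernels now switch over the extended types for both operands, comparisons on the new dtypes work end to end. A usage sketch, assuming a build that includes this commit:

import numpy as onp
from mxnet import np as mnp

a = mnp.array([[1, 2, 3]], dtype='uint16')
b = mnp.array([[2], [1]], dtype='uint16')
mask = a > b   # broadcasts to shape (2, 3) with a boolean result
assert mask.shape == (2, 3)
assert onp.dtype(mask.dtype) == onp.bool_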
2 changes: 1 addition & 1 deletion src/operator/numpy/np_init_op.h
@@ -62,7 +62,7 @@ struct NumpyEyeParam : public dmlc::Parameter<NumpyEyeParam> {
DMLC_DECLARE_FIELD(dtype)
.set_default(-1)
.add_enum("None", -1)
MXNET_ADD_ALL_TYPES
MXNET_ADD_ALL_TYPES_EXT_WITH_BOOL
.describe("Data-type of the returned array.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
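
NumpyEyeParam's dtype field now advertises the extended list (with bool), so eye can produce the new types directly. A usage sketch, assuming a build that includes this commit:

import numpy as onp
from mxnet import np as mnp

m = mnp.eye(3, dtype='uint16')
assert onp.dtype(m.dtype) == onp.uint16
assert int(m.asnumpy().sum()) == 3   # identity matrix regardless of dtype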
16 changes: 14 additions & 2 deletions src/operator/operator_tune.cc
@@ -61,6 +61,10 @@ IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(uint8_t);
IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int32_t);
IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int64_t);
IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(bool);
IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(int16_t);
IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(uint16_t);
IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(uint32_t);
IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(uint64_t);

/*!
* \brief Init variable used to facilitate registering a tunable operator during
@@ -85,7 +89,11 @@ struct static_init_var {
__macro$(__VA_ARGS__, uint8_t); \
__macro$(__VA_ARGS__, int8_t); \
__macro$(__VA_ARGS__, int32_t); \
__macro$(__VA_ARGS__, int64_t);
__macro$(__VA_ARGS__, int64_t); \
__macro$(__VA_ARGS__, int16_t); \
__macro$(__VA_ARGS__, uint16_t); \
__macro$(__VA_ARGS__, uint32_t); \
__macro$(__VA_ARGS__, uint64_t)

#define MSHADOW_MACRO_FOREACH_TYPE_WITH_BOOL(__macro$, ...) \
__macro$(__VA_ARGS__, float); \
@@ -96,7 +104,11 @@ struct static_init_var {
__macro$(__VA_ARGS__, int8_t); \
__macro$(__VA_ARGS__, int32_t); \
__macro$(__VA_ARGS__, int64_t); \
__macro$(__VA_ARGS__, bool)
__macro$(__VA_ARGS__, bool); \
__macro$(__VA_ARGS__, int16_t); \
__macro$(__VA_ARGS__, uint16_t); \
__macro$(__VA_ARGS__, uint32_t); \
__macro$(__VA_ARGS__, uint64_t)

#define IMPLEMENT_WORKLOAD_VALUE_FOR_TYPE(__op$, __typ$) \
namespace mxnet_op { \
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -329,8 +329,8 @@ void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs,
} else {
if (req[0] == kNullOp) return;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH_WITH_BOOL(lhs.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(rhs.type_flag_, EType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(lhs.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(rhs.type_flag_, EType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_binary_op.h
@@ -599,8 +599,8 @@ template<typename xpu, typename OP>
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[1].type_flag_, EType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(inputs[1].type_flag_, EType, {
const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
+ DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
if (size != 0) {
6 changes: 3 additions & 3 deletions src/operator/tensor/elemwise_unary_op.h
@@ -467,7 +467,7 @@ struct CastParam : public dmlc::Parameter<CastParam> {
int dtype;
DMLC_DECLARE_PARAMETER(CastParam) {
DMLC_DECLARE_FIELD(dtype)
MXNET_ADD_ALL_TYPES_WITH_BOOL
MXNET_ADD_ALL_TYPES_EXT_WITH_BOOL
.describe("Output data type.");
}
};
@@ -491,9 +491,9 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
using namespace mshadow::expr;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DstDType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(outputs[0].type_flag_, DstDType, {
Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(inputs[0].type_flag_, SrcDType, {
Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
if ((outputs[0].type_flag_ != inputs[0].type_flag_ ||
req[0] != kWriteInplace) && outputs[0].Size() != 0) {
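
CastCompute now dispatches over the extended dtype set on both the source and destination side, which is what backs astype for the new types. A usage sketch, assuming a build that includes this commit:

import numpy as onp
from mxnet import np as mnp

x = mnp.arange(5, dtype='int64')
y = x.astype('uint32')    # destination switch: one of the newly added types
z = y.astype('float32')   # source switch: casting back out again
assert onp.dtype(y.dtype) == onp.uint32
assert onp.dtype(z.dtype) == onp.float32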