diff --git a/python/mxnet/_ctypes/_api_internal.py b/python/mxnet/_ctypes/_api_internal.py
index 39128ed8fe9f..adda4b9db8bf 100644
--- a/python/mxnet/_ctypes/_api_internal.py
+++ b/python/mxnet/_ctypes/_api_internal.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""CachedOp APIs exposed from C++."""
+"""NDArray APIs exposed from C++."""
 
 import mxnet._ffi
 
-mxnet._ffi._init_api("cached_op", __name__)
+mxnet._ffi._init_api("ndarray", __name__)
diff --git a/python/mxnet/_ctypes/cached_op.py b/python/mxnet/_ctypes/cached_op.py
index 856eb5dbdf91..14a8c2a4df59 100644
--- a/python/mxnet/_ctypes/cached_op.py
+++ b/python/mxnet/_ctypes/cached_op.py
@@ -48,14 +48,14 @@ def __init__(self, sym, flags=(), thread_safe=False):
         self.is_np_sym = bool(isinstance(sym, _Symbol))
 
         flags = {key: str(value) for key, value in flags}
-        self.handle = CachedOpHandle(_api_internal.create(
+        self.handle = CachedOpHandle(_api_internal.cached_op_create(
             sym.handle,
             flags,
             thread_safe
         ))
 
     def __del__(self):
-        _api_internal.free(self.handle)
+        _api_internal.cached_op_free(self.handle)
 
     def get_optimized_symbol(self):
         """Get an optimized version of the symbol from the cached op.
@@ -66,7 +66,7 @@ def get_optimized_symbol(self):
             Optimized symbol from the executor.
         """
         from ..symbol import Symbol
-        sym_handle = SymbolHandle(_api_internal.get_optimized_symbol(self.handle))
+        sym_handle = SymbolHandle(_api_internal.cached_op_get_optimized_symbol(self.handle))
         ret = Symbol(sym_handle)
         return ret
 
@@ -85,7 +85,7 @@ def __call__(self, *args, **kwargs):
         type_id = default_ctx.device_typeid if default_ctx else None
         device_id = default_ctx.device_id if default_ctx else None
         out_arg = out if out is not None and not isinstance(out, NDArrayBase) else (out, )
-        output_vars = _api_internal.invoke(
+        output_vars = _api_internal.cached_op_invoke(
             self.handle,
             len(args),
             *args,
@@ -157,4 +157,4 @@ def _register_op_hook(self, callback, monitor_all=False):
         if callback:
             self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
         callback_ptr = ctypes.cast(self._monitor_callback, ctypes.c_void_p)
-        _api_internal.register_op_hook(self.handle, callback_ptr, monitor_all)
+        _api_internal.cached_op_register_op_hook(self.handle, callback_ptr, monitor_all)
diff --git a/python/mxnet/amp/amp.py b/python/mxnet/amp/amp.py
index 99272bb46bca..bc99f0cfe18a 100644
--- a/python/mxnet/amp/amp.py
+++ b/python/mxnet/amp/amp.py
@@ -58,7 +58,13 @@ def _cast_symbol_NDArray(s, dtype, is_numpy_module=False):
         amp_cast = symbol.numpy._internal.amp_cast if is_numpy_module else symbol.amp_cast
         return amp_cast(s, dtype=dtype)
     if isinstance(s, NDArray):
-        amp_cast = ndarray.numpy._internal.amp_cast if is_numpy_module else ndarray.amp_cast
+        if is_numpy_module:
+            def amp_cast(s, dtype=None):  # pylint: disable=function-redefined
+                if not isinstance(dtype, str):
+                    dtype = np.dtype(dtype).name
+                return ndarray.numpy._api_internal.amp_cast(s, dtype)
+        else:
+            amp_cast = ndarray.amp_cast
         if s.dtype != dtype and (s.dtype in float_types_gpu and s.context.device_type != 'cpu'
                                  or s.dtype in float_types_cpu and s.context.device_type == 'cpu'):
             return amp_cast(s, dtype=dtype)
@@ -106,7 +112,13 @@ def _wrap_module_functions(module, is_numpy_module, target_dtype, get_aliases, g
                            get_fun_to_wrap, target_precision_ops=None,
                            conditional_fp32_ops=None, fp32_ops=None):
-    nd_mod = ndarray.numpy._internal if is_numpy_module else ndarray
+    if is_numpy_module:
+        def amp_cast(s, dtype=None):  # pylint: disable=function-redefined
+            if not isinstance(dtype, str):
+                dtype = np.dtype(dtype).name
+            return ndarray.numpy._api_internal.amp_cast(s, dtype)
+    else:
+        amp_cast = ndarray.amp_cast
     sy_mod = symbol.numpy._internal if is_numpy_module else symbol
 
     def _ndarray_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
@@ -194,7 +206,7 @@ def _new_fun(*args, **kwargs):
                 widest_type = np.float32
                 for arr, index, arg in symbols:
                     if arg.dtype != widest_type and arg.dtype == target_dtype:
-                        arr[index] = nd_mod.amp_cast(arg, dtype=widest_type)
+                        arr[index] = amp_cast(arg, dtype=widest_type)
             else:
                 # Symbol case
                 sym_to_check = list(map(lambda x: x[2], symbols))
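[Note] The two `amp_cast` wrappers added above exist because the PackedFunc-based
`ndarray.numpy._api_internal.amp_cast` takes its dtype as a positional string, while the legacy
`ndarray.amp_cast` accepts a `dtype` keyword of any numpy-compatible type. A minimal sketch of
the normalization they perform (assuming a build that ships this FFI; the array `x` is
illustrative):

    import numpy as np
    import mxnet as mx
    from mxnet import ndarray

    x = mx.np.ones((2, 3), dtype='float32')
    dtype = np.float16                  # callers may pass a type object, not a string
    if not isinstance(dtype, str):
        dtype = np.dtype(dtype).name    # normalize to the string 'float16'
    y = ndarray.numpy._api_internal.amp_cast(x, dtype)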
diff --git a/python/mxnet/gluon/data/vision/transforms/image.py b/python/mxnet/gluon/data/vision/transforms/image.py
index 37b2a061b0df..4ab696f1944e 100644
--- a/python/mxnet/gluon/data/vision/transforms/image.py
+++ b/python/mxnet/gluon/data/vision/transforms/image.py
@@ -701,7 +701,7 @@ def hybrid_forward(self, F, x, *args):
             mat = F.np.concatenate((F.np.full((3, 1), 0.2989),
                                     F.np.full((3, 1), 0.5870),
                                     F.np.full((3, 1), 0.114)), axis=1)
-            x = F.npx.cast(x, dtype='float32')
+            x = x.astype(dtype='float32')
             gray = F.np.where(self.p < F.np.random.uniform(), x, F.np.dot(x, mat))
         else:
             mat = F.concat(F.full((3, 1), 0.2989),
diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py
index 9f50ba97efee..903c3ed9f1f4 100644
--- a/python/mxnet/ndarray/register.py
+++ b/python/mxnet/ndarray/register.py
@@ -25,6 +25,7 @@
 from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null, _is_np_op, _output_is_list  # pylint: disable=unused-import
 from ..util import use_np_shape  # pylint: disable=unused-import
+from .._ctypes import _api_internal  # pylint: disable=unused-import
 
 
 def _verify_all_np_ndarrays(op_name, func_name, args, out):
@@ -111,6 +112,17 @@ def _verify_all_legacy_ndarrays(op_name, func_name, args, out):
                         .format(op_name, func_name))
 
 
+def _np_imperative_invoke(handle, ndargs, out):
+    """PackedFunc based numpy operator invocation call"""
+    output_vars = _api_internal.invoke(handle, *ndargs, out)
+    if out is not None:
+        return out
+    if isinstance(output_vars, NDArrayBase):
+        return output_vars
+    else:
+        return list(output_vars)
+
+
 # pylint: disable=too-many-locals
 def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=False):
     """Generate function for ndarray op by handle and function op_name."""
@@ -176,62 +188,131 @@ def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=F
     code = []
     is_np_op = _is_np_op(op_name)
     output_is_list = _output_is_list(op_name)
-    doc_str_idx = 1
     if is_np_op:
-        doc_str_idx = 2
-    if arr_name:
-        code.append("""
-def %s(*%s, **kwargs):"""%(func_name, arr_name))
-        if not signature_only:
+        if arr_name:
             code.append("""
+def %s(*%s, **kwargs):"""%(func_name, arr_name))
+            if not signature_only:
+                code.append("""
     ndargs = []
     for i in {}:
         assert isinstance(i, NDArrayBase), \\
             "Positional arguments must have NDArray type, " \\
             "but got %s"%str(i)
         ndargs.append(i)""".format(arr_name))
-        if dtype_name is not None:
-            code.append("""
+            if dtype_name is not None:
+                code.append("""
     if '%s' in kwargs:
         if _np.dtype(kwargs['%s']).names:
             kwargs['%s'] = _np.dtype(kwargs['%s']).names[0]
         else:
             kwargs['%s'] = _np.dtype(kwargs['%s']).name """%(
                 dtype_name, dtype_name, dtype_name, dtype_name, dtype_name, dtype_name))
+            code.append("""
+    _ = kwargs.pop('name', None)
+    out = kwargs.pop('out', None)""")
+            if not signature_only:
+                code.append("""
+    _verify_all_np_ndarrays("{op_name}", "{func_name}", ndargs, out)
+    """.format(op_name=op_name, func_name=func_name))
+                code.append("""
+    return _imperative_invoke(%d, ndargs, kwargs.keys(), kwargs.values(), out, True, %s)"""%(
+                    handle.value, str(output_is_list)))
+            else:
+                code.append("""
+    return (0,)""")
+        else:
             code.append("""
+def %s(%s):"""%(func_name, ', '.join(signature)))
+            if not signature_only:
+                code.append("""
+    ndargs = []""")
+            # NDArray args
+            for name in ndarg_names:  # pylint: disable=redefined-argument-from-local
+                code.append("""
+    if {name} is not None:
+        assert isinstance({name}, NDArrayBase), \\
+            "Argument {name} must have NDArray type, but got %s"%str({name})
+        ndargs.append({name})""".format(name=name))
+            # kwargs
+            if not kwarg_names:
+                code.append("""
+    _verify_all_np_ndarrays("{op_name}", "{func_name}", ndargs, out)
+    """.format(op_name=op_name, func_name=func_name))
+                if not signature_only:
+                    code.append("""
+    return _np_imperative_invoke(%d, ndargs, out)"""%(handle.value))
+                else:
+                    code.append("""
+    return (0,)""")
+            else:
+                for name in kwarg_names:  # pylint: disable=redefined-argument-from-local
+                    code.append("""
+    if %s is not _Null:
+        kwargs['%s'] = %s"""%(name, name, name))
+                # dtype
+                if dtype_name is not None:
+                    code.append("""
+    if %s is not _Null and %s is not None:
+        kwargs['%s'] = _np.dtype(%s).name"""%(dtype_name, dtype_name, dtype_name, dtype_name))
+                if not signature_only:
+                    code.append("""
+    _verify_all_np_ndarrays("{op_name}", "{func_name}", ndargs, out)
+    """.format(op_name=op_name, func_name=func_name))
+                    code.append("""
+    return _imperative_invoke(%d, ndargs, kwargs.keys(), kwargs.values(), out, True, %s)"""%(
+                        handle.value, str(output_is_list)))
+                else:
+                    code.append("""
+    return (0,)""")
+    else:
+        if arr_name:
+            code.append("""
+def %s(*%s, **kwargs):"""%(func_name, arr_name))
+            if not signature_only:
+                code.append("""
+    ndargs = []
+    for i in {}:
+        assert isinstance(i, NDArrayBase), \\
+            "Positional arguments must have NDArray type, " \\
+            "but got %s"%str(i)
+        ndargs.append(i)""".format(arr_name))
+            if dtype_name is not None:
+                code.append("""
+    if '%s' in kwargs:
+        if _np.dtype(kwargs['%s']).names:
+            kwargs['%s'] = _np.dtype(kwargs['%s']).names[0]
+        else:
+            kwargs['%s'] = _np.dtype(kwargs['%s']).name """%(
+                    dtype_name, dtype_name, dtype_name, dtype_name, dtype_name, dtype_name))
             code.append("""
     _ = kwargs.pop('name', None)
     out = kwargs.pop('out', None)
     keys = list(kwargs.keys())
     vals = list(kwargs.values())""")
-    else:
-        code.append("""
-def %s(%s):"""%(func_name, ', '.join(signature)))
-        if not signature_only:
+        else:
             code.append("""
+def %s(%s):"""%(func_name, ', '.join(signature)))
+            if not signature_only:
+                code.append("""
     ndargs = []
     keys = list(kwargs.keys())
     vals = list(kwargs.values())""")
-    # NDArray args
-    for name in ndarg_names:  # pylint: disable=redefined-argument-from-local
-        code.append("""
+        # NDArray args
+        for name in ndarg_names:  # pylint: disable=redefined-argument-from-local
+            code.append("""
     if {name} is not None:
         assert isinstance({name}, NDArrayBase), \\
             "Argument {name} must have NDArray type, but got %s"%str({name})
         ndargs.append({name})""".format(name=name))
-    # kwargs
-    for name in kwarg_names:  # pylint: disable=redefined-argument-from-local
-        code.append("""
+        # kwargs
+        for name in kwarg_names:  # pylint: disable=redefined-argument-from-local
+            code.append("""
     if %s is not _Null:
         keys.append('%s')
         vals.append(%s)"""%(name, name, name))
-    # dtype
-    if dtype_name is not None:
-        if is_np_op:
-            code.append("""
-    if %s is not _Null and %s is not None:
-        keys.append('%s')
-        vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name, dtype_name))
-        else:
+        # dtype
+        if dtype_name is not None:
             code.append("""
     if %s is not _Null:
         keys.append('%s')
@@ -240,24 +321,21 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
         else:
             vals.append(_np.dtype(%s).name)
 """%(dtype_name, dtype_name, dtype_name, dtype_name, dtype_name))
-
-    verify_ndarrays_fn =\
-        _verify_all_np_ndarrays.__name__ if is_np_op else _verify_all_legacy_ndarrays.__name__
-    if not signature_only:
-        code.append("""
-    {verify_fn}("{op_name}", "{func_name}", ndargs, out)
-    """.format(verify_fn=verify_ndarrays_fn, op_name=op_name, func_name=func_name))
-        code.append("""
-    return _imperative_invoke(%d, ndargs, keys, vals, out, %s, %s)"""%(
-            handle.value, str(is_np_op), str(output_is_list)))
-    else:
-        code.append("""
+        if not signature_only:
+            code.append("""
+    _verify_all_legacy_ndarrays("{op_name}", "{func_name}", ndargs, out)
+    """.format(op_name=op_name, func_name=func_name))
+            code.append("""
+    return _imperative_invoke(%d, ndargs, keys, vals, out, False, %s)"""%(
+                handle.value, str(output_is_list)))
+        else:
+            code.append("""
     return (0,)""")
 
     doc_str_lines = _os.linesep+''.join(['    '+s if s.strip() else s
                                          for s in 'r"""{doc_str}"""'.format(doc_str=doc_str)
                                          .splitlines(True)])
-    code.insert(doc_str_idx, doc_str_lines)
+    code.insert(1, doc_str_lines)
     return ''.join(code), doc_str
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index dd6504ef8fb7..f7a7d47222fc 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -48,6 +48,7 @@
     is_np_default_dtype
 from ..context import current_context
 from ..ndarray import numpy as _mx_nd_np
+from ..ndarray.numpy import _api_internal
 from ..ndarray.numpy import _internal as _npi
 from ..ndarray.ndarray import _storage_type
 from ..dlpack import ndarray_from_numpy
@@ -1478,7 +1479,7 @@ def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True):  #
         if not copy and _np.dtype(dtype) == self.dtype:
             return self
 
-        return _npi.cast(self, dtype=dtype)
+        return _api_internal.cast(self, _np.dtype(dtype).name)
 
     def copyto(self, other):
         """Copies the value of this array to another array.
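[Note] With the `astype` change above, casting goes through the PackedFunc FFI instead of the
legacy `_npi.cast`, and the dtype is normalized to its string name on the Python side. A small
usage sketch (values illustrative):

    import numpy as _np
    import mxnet as mx

    a = mx.np.array([1, 2, 3], dtype='float32')
    b = a.astype('float16')   # dispatches to _api_internal.cast(a, 'float16')
    c = a.astype(_np.int32)   # _np.dtype(_np.int32).name == 'int32'
    assert b.dtype == _np.float16 and c.dtype == _np.int32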
diff --git a/src/api/cached_op_api.cc b/src/api/cached_op_api.cc
index 1c325d229da3..bdae26abf8b6 100644
--- a/src/api/cached_op_api.cc
+++ b/src/api/cached_op_api.cc
@@ -29,7 +29,7 @@
 
 namespace mxnet {
 
-MXNET_REGISTER_GLOBAL("cached_op.invoke")
+MXNET_REGISTER_GLOBAL("ndarray.cached_op_invoke")
 .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
   CachedOpPtr op_shared = *static_cast<CachedOpPtr*>(args[0].value().v_handle);
   // CachedOp* points to CachedOpThreadSafe object if CreateCachedOpEX
@@ -88,7 +88,7 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
   }
 });
 
-MXNET_REGISTER_GLOBAL("cached_op.create")
+MXNET_REGISTER_GLOBAL("ndarray.cached_op_create")
 .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(args[0].value().v_handle);
   Object* flags_ptr = static_cast<Object*>(args[1].value().v_handle);
@@ -110,13 +110,13 @@ MXNET_REGISTER_GLOBAL("cached_op.create")
   *ret = static_cast<CachedOpHandle>(out);
 });
 
-MXNET_REGISTER_GLOBAL("cached_op.free")
+MXNET_REGISTER_GLOBAL("ndarray.cached_op_free")
 .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
   CachedOpPtr* g = static_cast<CachedOpPtr*>(args[0].value().v_handle);
   delete g;
 });
 
-MXNET_REGISTER_GLOBAL("cached_op.get_optimized_symbol")
+MXNET_REGISTER_GLOBAL("ndarray.cached_op_get_optimized_symbol")
 .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
   auto s = new nnvm::Symbol();
   CachedOpPtr op = *static_cast<CachedOpPtr*>(args[0].value().v_handle);
@@ -124,7 +124,7 @@ MXNET_REGISTER_GLOBAL("cached_op.get_optimized_symbol")
   *ret = static_cast<SymbolHandle>(static_cast<void*>(s));
 });
 
-MXNET_REGISTER_GLOBAL("cached_op.register_op_hook")
+MXNET_REGISTER_GLOBAL("ndarray.cached_op_register_op_hook")
 .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
   CachedOpHandle handle = static_cast<CachedOpHandle>(args[0].value().v_handle);
   CachedOpMonitorCallback callback = reinterpret_cast<CachedOpMonitorCallback>(
diff --git a/src/api/operator/numpy/np_cast.cc b/src/api/operator/numpy/np_cast.cc
new file mode 100644
index 000000000000..fdbfe2001884
--- /dev/null
+++ b/src/api/operator/numpy/np_cast.cc
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_cast.cc
+ * \brief Implementation of the API of functions in src/operator/numpy/np_cast.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/tensor/amp_cast.h"
+#include "../../../operator/tensor/elemwise_unary_op.h"
+
+namespace mxnet {
+
+MXNET_REGISTER_API("_npi.amp_cast")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npi_amp_cast");
+  op::AMPCastParam param;
+  // dtype
+  if (args[1].type_code() == kNull) {
+    param.dtype = mxnet::common::GetDefaultDtype();
+  } else {
+    param.dtype = String2MXNetTypeWithBool(args[1].operator std::string());
+  }
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::AMPCastParam>(&attrs);
+  // inputs
+  NDArray* inputs[] = {args[0].operator NDArray*()};
+  int num_inputs = 1;
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+MXNET_REGISTER_API("_npi.cast")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npi_cast");
+  op::CastParam param;
+  // dtype
+  if (args[1].type_code() == kNull) {
+    param.dtype = mxnet::common::GetDefaultDtype();
+  } else {
+    param.dtype = String2MXNetTypeWithBool(args[1].operator std::string());
+  }
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::CastParam>(&attrs);
+  // inputs
+  NDArray* inputs[] = {args[0].operator NDArray*()};
+  int num_inputs = 1;
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
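[Note] Both registrations above follow the same pattern: a missing dtype (type code kNull) falls
back to the default dtype, and a string dtype is converted to the MXNet type enum via
String2MXNetTypeWithBool. From Python the calls look like this (a sketch, assuming the module
paths introduced in this patch):

    import mxnet as mx
    from mxnet.ndarray.numpy import _api_internal

    x = mx.np.ones((2,), dtype='float32')
    y = _api_internal.cast(x, 'float64')      # dtype is passed as a plain string
    z = _api_internal.amp_cast(x, 'float16')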
diff --git a/src/api/operator/numpy/np_registered_op.cc b/src/api/operator/numpy/np_registered_op.cc
new file mode 100644
index 000000000000..507e4d390489
--- /dev/null
+++ b/src/api/operator/numpy/np_registered_op.cc
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_registered_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy/np_registered_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include <vector>
+#include "../utils.h"
+
+namespace mxnet {
+
+MXNET_REGISTER_GLOBAL("ndarray.invoke")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  int args_size = args.size();
+  const nnvm::Op* op = static_cast<const nnvm::Op*>(args[0].value().v_handle);
+  int num_inputs = args_size - 2;
+  std::vector<NDArray*> ndinputs;
+  ndinputs.reserve(num_inputs);
+  for (int i = 1; i < num_inputs + 1; ++i) {
+    ndinputs.push_back(static_cast<NDArray*>(args[i]));
+  }
+  nnvm::NodeAttrs attrs;
+  attrs.op = op;
+  int out_type_code = args[args_size - 1].type_code();
+  NDArray* out = args[args_size - 1].operator mxnet::NDArray*();
+  NDArray** outputs = out == nullptr ? nullptr : &out;
+  int num_outputs = out != nullptr;
+
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, ndinputs.data(), &num_outputs, outputs);
+
+  if (out_type_code == kNull) {
+    if (num_outputs == 1) {
+      *ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
+    } else {
+      std::vector<ObjectRef> outputs_obj;
+      outputs_obj.reserve(num_outputs);
+      for (int i = 0; i < num_outputs; ++i) {
+        ObjectRef out = NDArrayHandle(ndoutputs[i]);
+        outputs_obj.push_back(out);
+        delete ndoutputs[i];
+      }
+      *ret = runtime::ADT(0, outputs_obj.begin(), outputs_obj.end());
+    }
+  } else {
+    *ret = PythonArg(args_size - 1);
+  }
+});
+
+}  // namespace mxnet
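[Note] `ndarray.invoke` is the generic entry point behind `_np_imperative_invoke` in
python/mxnet/ndarray/register.py (see above): the packed arguments are (op_handle, *ndarrays,
out). A sketch of the frontend contract, in comments since a raw op handle is not normally
constructed by hand:

    # ret = _api_internal.invoke(op_handle, *ndargs, out)
    # - out is None, one output   -> ret is a single NDArray (kNull branch, num_outputs == 1)
    # - out is None, many outputs -> ret is an ADT; the wrapper returns list(ret)
    # - out is an NDArray         -> results are written into out; PythonArg tells the
    #                                frontend to return the caller's own out object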
diff --git a/src/operator/tensor/amp_cast.h b/src/operator/tensor/amp_cast.h
index 685a05a14e4f..f31edc8b6e8b 100644
--- a/src/operator/tensor/amp_cast.h
+++ b/src/operator/tensor/amp_cast.h
@@ -26,12 +26,14 @@
 #define MXNET_OPERATOR_TENSOR_AMP_CAST_H_
 
 #include <dmlc/parameter.h>
+#include <unordered_map>
 #include <vector>
 #include "../mshadow_op.h"
 #include "../mxnet_op.h"
 #include "../elemwise_op_common.h"
 #include "../operator_common.h"
+#include "../../api/operator/op_utils.h"
 
 namespace mxnet {
 namespace op {
@@ -44,6 +46,11 @@ struct AMPCastParam : public dmlc::Parameter<AMPCastParam> {
     MXNET_ADD_ALL_TYPES
     .describe("Output data type.");
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream dtype_s;
+    dtype_s << dtype;
+    (*dict)["dtype"] = MXNetTypeWithBool2String(dtype);
+  }
 };
 
 struct AMPMultiCastParam : public dmlc::Parameter<AMPMultiCastParam> {
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 4d34c51edc3b..0cefdcd6e0b5 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -470,6 +470,11 @@ struct CastParam : public dmlc::Parameter<CastParam> {
     MXNET_ADD_ALL_TYPES_WITH_BOOL
     .describe("Output data type.");
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream dtype_s;
+    dtype_s << dtype;
+    (*dict)["dtype"] = MXNetTypeWithBool2String(dtype);
+  }
 };
 
 inline bool CastType(const nnvm::NodeAttrs& attrs,
diff --git a/tests/python/unittest/test_contrib_intgemm.py b/tests/python/unittest/test_contrib_intgemm.py
index ff1559274240..2c3559d7d526 100644
--- a/tests/python/unittest/test_contrib_intgemm.py
+++ b/tests/python/unittest/test_contrib_intgemm.py
@@ -120,11 +120,13 @@ def test_contrib_intgemm_take_weight(indices, api):
 @pytest.mark.parametrize('weight_cols', range(8, 24, 8))
 @pytest.mark.parametrize('api', [
     (mx.nd.contrib, mx.nd, mx.nd.FullyConnected, mx.nd.cast),
-    (npx, np, npx.fully_connected, npx.cast)])
+    (npx, np, npx.fully_connected, mx.nd.np._api_internal.cast)])
 def test_contrib_intgemm_multiply(data_rows, inner, weight_cols, api):
     if "intgemm_fully_connected" not in dir(mx.nd.contrib):
         return
-    contrib, top, fully_connected, cast = api
+    contrib, top, fully_connected, cast_api = api
+    def cast(a, dtype=None):
+        return cast_api(a, dtype)
     #The multiplication routine has approximations so everything is tested
     #deterministically to ensure bounds are met.
     random.seed(1)
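[Note] The shim above exists because the two parametrized cast APIs disagree on calling
convention: the legacy `mx.nd.cast` accepts `dtype` positionally or as a keyword, while the
PackedFunc-based `_api_internal.cast` is positional-only. Wrapping both behind a uniform
`cast(a, dtype=...)` keeps the test body unchanged, e.g. (illustrative):

    cast(data, dtype='int8')  # -> mx.nd.cast(data, 'int8') or _api_internal.cast(data, 'int8')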