Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUGFIX][API] fix #20447 #20454

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
30 changes: 15 additions & 15 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,7 @@ def add(x1, x2, out=None, **kwargs):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.add(x1, x2, out=out)
return _api_internal.add(x1, x2, out)
return _api_internal.add(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1063,7 +1063,7 @@ def subtract(x1, x2, out=None, **kwargs):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.subtract(x1, x2, out=out)
return _api_internal.subtract(x1, x2, out)
return _api_internal.subtract(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1099,7 +1099,7 @@ def multiply(x1, x2, out=None, **kwargs):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.multiply(x1, x2, out=out)
return _api_internal.multiply(x1, x2, out)
return _api_internal.multiply(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1138,7 +1138,7 @@ def divide(x1, x2, out=None, **kwargs):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.divide(x1, x2, out=out)
return _api_internal.true_divide(x1, x2, out)
return _api_internal.true_divide(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1179,7 +1179,7 @@ def true_divide(x1, x2, out=None):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.true_divide(x1, x2, out=out)
return _api_internal.true_divide(x1, x2, out)
return _api_internal.true_divide(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1218,7 +1218,7 @@ def floor_divide(x1, x2, out=None):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.floor_divide(x1, x2, out=out)
return _api_internal.floor_divide(x1, x2, out)
return _api_internal.floor_divide(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1247,7 +1247,7 @@ def mod(x1, x2, out=None, **kwargs):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.mod(x1, x2, out=out)
return _api_internal.mod(x1, x2, out)
return _api_internal.mod(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1437,7 +1437,7 @@ def matmul(a, b, out=None):
...
mxnet.base.MXNetError: ... : Multiplication by scalars is not allowed.
"""
return _api_internal.matmul(a, b, out)
return _api_internal.matmul(a, b, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1466,7 +1466,7 @@ def remainder(x1, x2, out=None):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.mod(x1, x2, out=out)
return _api_internal.mod(x1, x2, out)
return _api_internal.mod(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -1496,7 +1496,7 @@ def power(x1, x2, out=None, **kwargs):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.power(x1, x2, out=out)
return _api_internal.power(x1, x2, out)
return _api_internal.power(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6883,7 +6883,7 @@ def bitwise_and(x1, x2, out=None, **kwargs):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.bitwise_and(x1, x2, out=out)
return _api_internal.bitwise_and(x1, x2, out)
return _api_internal.bitwise_and(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6923,7 +6923,7 @@ def bitwise_xor(x1, x2, out=None, **kwargs):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.bitwise_xor(x1, x2, out=out)
return _api_internal.bitwise_xor(x1, x2, out)
return _api_internal.bitwise_xor(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6963,7 +6963,7 @@ def bitwise_or(x1, x2, out=None, **kwargs):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.bitwise_or(x1, x2, out=out)
return _api_internal.bitwise_or(x1, x2, out)
return _api_internal.bitwise_or(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -10089,7 +10089,7 @@ def bitwise_left_shift(x1, x2, out=None):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.left_shift(x1, x2, out=out)
return _api_internal.bitwise_left_shift(x1, x2, out)
return _api_internal.bitwise_left_shift(x1, x2, out, False)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -10128,4 +10128,4 @@ def bitwise_right_shift(x1, x2, out=None):
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.right_shift(x1, x2, out=out)
return _api_internal.bitwise_right_shift(x1, x2, out)
return _api_internal.bitwise_right_shift(x1, x2, out, False)
30 changes: 15 additions & 15 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
wrap_sort_functions
from ..device import current_device
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi
from ..ndarray.numpy import _internal as _npi, _api_internal
from ..ndarray.ndarray import _storage_type
from ..dlpack import ndarray_from_numpy, ndarray_to_dlpack_for_write, DLDeviceType,\
ndarray_from_dlpack
Expand Down Expand Up @@ -1091,7 +1091,7 @@ def __iadd__(self, other):
"""x.__iadd__(y) <=> x += y"""
if not self.writable:
raise ValueError('trying to add to a readonly ndarray')
return add(self, other, out=self)
return _api_internal.add(self, other, self, True)

@wrap_mxnp_np_ufunc
def __radd__(self, other):
Expand Down Expand Up @@ -1145,27 +1145,27 @@ def __rshift__(self, other):
@wrap_mxnp_np_ufunc
def __iand__(self, other):
"""x.__iand__(y) <=> x &= y"""
return bitwise_and(self, other, out=self)
return _api_internal.bitwise_and(self, other, self, True)

@wrap_mxnp_np_ufunc
def __ior__(self, other):
r"""x.__ior__(y) <=> x \|= y"""
return bitwise_or(self, other, out=self)
return _api_internal.bitwise_or(self, other, self, True)

@wrap_mxnp_np_ufunc
def __ixor__(self, other):
"""x.__ixor__(y) <=> x ^= y"""
return bitwise_xor(self, other, out=self)
return _api_internal.bitwise_xor(self, other, self, True)

@wrap_mxnp_np_ufunc
def __ilshift__(self, other):
"""x.__ilshift__(y) <=> x <<= y"""
return bitwise_left_shift(self, other, out=self)
return _api_internal.bitwise_left_shift(self, other, self, True)

@wrap_mxnp_np_ufunc
def __irshift__(self, other):
"""x.__irshift__(y) <=> x >>= y"""
return bitwise_right_shift(self, other, out=self)
return _api_internal.bitwise_right_shift(self, other, self, True)

@wrap_mxnp_np_ufunc
def __rlshift__(self, other):
Expand Down Expand Up @@ -1207,7 +1207,7 @@ def __isub__(self, other):
"""x.__isub__(y) <=> x -= y"""
if not self.writable:
raise ValueError('trying to subtract from a readonly ndarray')
return subtract(self, other, out=self)
return _api_internal.subtract(self, other, self, True)

@wrap_mxnp_np_ufunc
def __rsub__(self, other):
Expand All @@ -1229,7 +1229,7 @@ def __ifloordiv__(self, other):
"""x.__ifloordiv__(y) <=> x //= y"""
if not self.writable:
raise ValueError('trying to divide from a readonly ndarray')
return floor_divide(self, other, out=self)
return _api_internal.floor_divide(self, other, self, True)

@wrap_mxnp_np_ufunc
def __rfloordiv__(self, other):
Expand All @@ -1249,7 +1249,7 @@ def __imul__(self, other):
r"""x.__imul__(y) <=> x \*= y"""
if not self.writable:
raise ValueError('trying to add to a readonly ndarray')
return multiply(self, other, out=self)
return _api_internal.multiply(self, other, self, True)

@wrap_mxnp_np_ufunc
def __rmul__(self, other):
Expand All @@ -1269,7 +1269,7 @@ def __rdiv__(self, other):
@wrap_mxnp_np_ufunc
def __idiv__(self, other):
"""x.__idiv__(y) <=> x /= y"""
return divide(self, other, out=self)
return _api_internal.true_divide(self, other, self, True)

@wrap_mxnp_np_ufunc
def __truediv__(self, other):
Expand All @@ -1284,7 +1284,7 @@ def __rtruediv__(self, other):
@wrap_mxnp_np_ufunc
def __itruediv__(self, other):
"""x.__itruediv__(y) <=> x /= y"""
return divide(self, other, out=self)
return _api_internal.true_divide(self, other, self, True)

@wrap_mxnp_np_ufunc
def __mod__(self, other):
Expand All @@ -1299,7 +1299,7 @@ def __rmod__(self, other):
@wrap_mxnp_np_ufunc
def __imod__(self, other):
"""x.__imod__(y) <=> x %= y"""
return mod(self, other, out=self)
return _api_internal.mod(self, other, self, True)

@wrap_mxnp_np_ufunc
def __pow__(self, other):
Expand All @@ -1314,7 +1314,7 @@ def __rpow__(self, other):
@wrap_mxnp_np_ufunc
def __ipow__(self, other):
"""x.__ipow__(y) <=> x **= y"""
return power(self, other, out=self)
return _api_internal.power(self, other, self, True)

@wrap_mxnp_np_ufunc
def __eq__(self, other):
Expand Down Expand Up @@ -1362,7 +1362,7 @@ def __rmatmul__(self, other):
@wrap_mxnp_np_ufunc
def __imatmul__(self, other):
"""x.__imatmul__(y) <=> x @= y"""
return matmul(self, other, out=self)
return _api_internal.matmul(self, other, self, True)

def __bool__(self):
num_elements = self.size
Expand Down
40 changes: 39 additions & 1 deletion src/api/operator/ufunc_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "utils.h"
#include "../../imperative/imperative_utils.h"
#include "../../operator/tensor/elemwise_binary_scalar_op.h"
#include "../../operator/numpy/np_elemwise_broadcast_op.h"

namespace mxnet {

Expand Down Expand Up @@ -55,6 +56,31 @@ void UFuncHelper(NDArray* lhs,
}
}

void UFuncHelper(NDArray* lhs,
NDArray* rhs,
NDArray* out,
runtime::MXNetRetValue* ret,
const nnvm::Op* op,
bool in_place) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryParam param = {};
param.in_place = in_place;
attrs.op = op;
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryParam>(&attrs);
NDArray* inputs[] = {lhs, rhs};
int num_inputs = 2;
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (outputs) {
*ret = PythonArg(2);
} else {
*ret = reinterpret_cast<NDArray*>(ndoutputs[0]);
}
}

void UFuncHelper(NDArray* lhs,
int64_t rhs,
NDArray* out,
Expand Down Expand Up @@ -164,7 +190,19 @@ void UFuncHelper(runtime::MXNetArgs args,
NDArray* out = args[2].operator NDArray*();
if (args[0].type_code() == kNDArrayHandle) {
if (args[1].type_code() == kNDArrayHandle) {
UFuncHelper(args[0].operator NDArray*(), args[1].operator NDArray*(), out, ret, fn_array);
int args_size = args.size();
if (args_size == 4) {
bool in_place = args[3].operator bool();
if (in_place) {
UFuncHelper(
args[0].operator NDArray*(), args[1].operator NDArray*(), out, ret, fn_array, true);
} else {
UFuncHelper(
args[0].operator NDArray*(), args[1].operator NDArray*(), out, ret, fn_array, false);
}
} else {
UFuncHelper(args[0].operator NDArray*(), args[1].operator NDArray*(), out, ret, fn_array);
}
} else if (args[1].type_code() == kDLInt) {
UFuncHelper(args[0].operator NDArray*(), args[1].operator int64_t(), out, ret, lfn_scalar);
} else {
Expand Down
Loading