Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] FFI: random.shuffle, equal, not_equal, less_equal, greater_eq…
Browse files Browse the repository at this point in the history
…ual, less, maximum and minimum (#17896)

* add random.shuffle, equal, not_equal, less_equal, greater_equal, less, maximum and minimum ffi and benchmark

* fix the implementation of binary ops

* fix some code logic issues (non-commutative)

Co-authored-by: Hao Jin <[email protected]>
  • Loading branch information
AntiZpvoh and haojin2 authored Apr 18, 2020
1 parent 5542d03 commit dcada9b
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 10 deletions.
8 changes: 8 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ def prepare_workloads():
OpArgMngr.add_workload("tensordot", pool['2x2'], pool['2x2'], ((1, 0), (0, 1)))
OpArgMngr.add_workload("kron", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
OpArgMngr.add_workload("random.shuffle", pool['3'])
OpArgMngr.add_workload("equal", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("not_equal", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("less", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("greater_equal", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("less_equal", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("maximum", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("minimum", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("sum", pool['2x2'], axis=0, keepdims=True, out=pool['1x2'])
OpArgMngr.add_workload("std", pool['2x2'], axis=0, ddof=0, keepdims=True, out=pool['1x2'])
OpArgMngr.add_workload("var", pool['2x2'], axis=0, ddof=1, keepdims=True, out=pool['1x2'])
Expand Down
32 changes: 23 additions & 9 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4536,7 +4536,9 @@ def maximum(x1, x2, out=None, **kwargs):
-------
out : mxnet.numpy.ndarray or scalar
The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.maximum(x1, x2, out=out)
return _api_internal.maximum(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -4576,7 +4578,9 @@ def minimum(x1, x2, out=None, **kwargs):
-------
out : mxnet.numpy.ndarray or scalar
The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.minimum(x1, x2, out=out)
return _api_internal.minimum(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6612,7 +6616,9 @@ def equal(x1, x2, out=None):
>>> np.equal(1, np.ones(1))
array([ True])
"""
return _ufunc_helper(x1, x2, _npi.equal, _np.equal, _npi.equal_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.equal(x1, x2, out=out)
return _api_internal.equal(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6644,7 +6650,10 @@ def not_equal(x1, x2, out=None):
>>> np.not_equal(1, np.ones(1))
array([False])
"""
return _ufunc_helper(x1, x2, _npi.not_equal, _np.not_equal, _npi.not_equal_scalar, None, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.not_equal(x1, x2, out=out)
return _api_internal.not_equal(x1, x2, out)



@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6709,7 +6718,9 @@ def less(x1, x2, out=None):
>>> np.less(1, np.ones(1))
array([False])
"""
return _ufunc_helper(x1, x2, _npi.less, _np.less, _npi.less_scalar, _npi.greater_scalar, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.less(x1, x2, out=out)
return _api_internal.less(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6741,8 +6752,10 @@ def greater_equal(x1, x2, out=None):
>>> np.greater_equal(1, np.ones(1))
array([True])
"""
return _ufunc_helper(x1, x2, _npi.greater_equal, _np.greater_equal, _npi.greater_equal_scalar,
_npi.less_equal_scalar, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.greater_equal(x1, x2, out=out)
return _api_internal.greater_equal(x1, x2, out)



@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6774,8 +6787,9 @@ def less_equal(x1, x2, out=None):
>>> np.less_equal(1, np.ones(1))
array([True])
"""
return _ufunc_helper(x1, x2, _npi.less_equal, _np.less_equal, _npi.less_equal_scalar,
_npi.greater_equal_scalar, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.less_equal(x1, x2, out=out)
return _api_internal.less_equal(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ def shuffle(x):
[3., 4., 5.],
[0., 1., 2.]])
"""
_npi.shuffle(x, out=x)
_api_internal.shuffle(x, x)


def laplace(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None):
Expand Down
74 changes: 74 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_elemwise_broadcast_logic_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_elemwise_broadcast_logic_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../ufunc_helper.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.equal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_equal");
const nnvm::Op* op_scalar = Op::Get("_npi_equal_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.not_equal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_not_equal");
const nnvm::Op* op_scalar = Op::Get("_npi_not_equal_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.less")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_less");
const nnvm::Op* op_scalar = Op::Get("_npi_less_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_less_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.greater_equal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_greater_equal");
const nnvm::Op* op_scalar = Op::Get("_npi_greater_equal_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_greater_equal_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.less_equal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_less_equal");
const nnvm::Op* op_scalar = Op::Get("_npi_less_equal_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_less_equal_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

} // namespace mxnet
57 changes: 57 additions & 0 deletions src/api/operator/random/shuffle_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file shuffle_op.cc
* \brief Implementation of the API of functions in src/operator/random/shuffle_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/elemwise_op_common.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.shuffle")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_shuffle");
nnvm::NodeAttrs attrs;

NDArray* inputs[1];
int num_inputs = 1;

if (args[0].type_code() != kNull) {
inputs[0] = args[0].operator mxnet::NDArray *();
}

attrs.op = op;

NDArray* out = args[1].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(1);
} else {
*ret = ndoutputs[0];
}
});

} // namespace mxnet
47 changes: 47 additions & 0 deletions src/api/operator/tensor/elemwise_binary_broadcast_op_extended.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file elemwise_binary_broadcast_op_extended.cc
* \brief Implementation of the API of functions in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../ufunc_helper.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.maximum")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_maximum");
const nnvm::Op* op_scalar = Op::Get("_npi_maximum_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.minimum")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_minimum");
const nnvm::Op* op_scalar = Op::Get("_npi_minimum_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

} // namespace mxnet

0 comments on commit dcada9b

Please sign in to comment.