add ffi for full_like, binary (#17811)
Alicia1529 committed Mar 18, 2020
1 parent dfb1b88 commit 2fae7e4
Showing 10 changed files with 199 additions and 27 deletions.
16 changes: 16 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
@@ -59,6 +59,22 @@ def prepare_workloads():
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("linalg.svd", pool['3x3'])
OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1)
OpArgMngr.add_workload("subtract", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("multiply", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("mod", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("remainder", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("divide", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("true_divide", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("power", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("lcm", pool['2x2'].astype('int32'), pool['2x2'].astype('int32'))
OpArgMngr.add_workload("diff", pool['2x2'], n=1, axis=-1)
OpArgMngr.add_workload("nonzero", pool['2x2'])
OpArgMngr.add_workload("tril", pool['2x2'], k=0)
OpArgMngr.add_workload("expand_dims", pool['2x2'], axis=0)
OpArgMngr.add_workload("broadcast_to", pool['2x2'], (2, 2, 2))
OpArgMngr.add_workload("full_like", pool['2x2'], 2)
OpArgMngr.add_workload("zeros_like", pool['2x2'])
OpArgMngr.add_workload("ones_like", pool['2x2'])
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)
OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1'])
OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1])
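For context: OpArgMngr.add_workload records a dotted operator name together with the arguments to invoke it with; the harness later resolves each name against mxnet.numpy and times the call, so what is measured is mostly the Python/FFI call overhead this PR targets (MXNet executes kernels asynchronously). A minimal sketch of that pattern, with simplified names that are not the actual harness API:

import time
import mxnet.numpy as np

_workloads = []

def add_workload(name, *args, **kwargs):
    # Record a dotted op name plus the arguments to call it with.
    _workloads.append((name, args, kwargs))

def run_workloads(repeat=1000):
    for name, args, kwargs in _workloads:
        func = np
        for part in name.split('.'):   # resolves e.g. "linalg.svd" or "random.uniform"
            func = getattr(func, part)
        start = time.perf_counter()
        for _ in range(repeat):
            func(*args, **kwargs)
        per_call_us = (time.perf_counter() - start) / repeat * 1e6
        print('{}: {:.2f} us/call'.format(name, per_call_us))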
57 changes: 34 additions & 23 deletions python/mxnet/ndarray/numpy/_op.py
@@ -212,9 +212,7 @@ def zeros_like(a, dtype=None, order='C', ctx=None, out=None):
"""
if order != 'C':
raise NotImplementedError
-    if ctx is None:
-        ctx = current_context()
-    return _npi.full_like(a, fill_value=0, dtype=dtype, ctx=ctx, out=out)
+    return full_like(a, 0, dtype=dtype, order=order, ctx=ctx, out=out)


@set_module('mxnet.ndarray.numpy')
@@ -270,11 +268,7 @@ def ones_like(a, dtype=None, order='C', ctx=None, out=None):
>>> np.ones_like(y)
array([1., 1., 1.], dtype=float64)
"""
-    if order != 'C':
-        raise NotImplementedError
-    if ctx is None:
-        ctx = current_context()
-    return _npi.full_like(a, fill_value=1, dtype=dtype, ctx=ctx, out=out)
+    return full_like(a, 1, dtype=dtype, order=order, ctx=ctx, out=out)


@set_module('mxnet.ndarray.numpy')
@@ -433,11 +427,15 @@ def full_like(a, fill_value, dtype=None, order='C', ctx=None, out=None):  # pylin
"""
if order != 'C':
raise NotImplementedError
-    if ctx is None:
-        ctx = current_context()
    if isinstance(fill_value, bool):
        fill_value = int(fill_value)
-    return _npi.full_like(a, fill_value=fill_value, dtype=dtype, ctx=ctx, out=out)
+    if ctx is None:
+        ctx = str(current_context())
+    else:
+        ctx = str(ctx)
+    if dtype is not None and not isinstance(dtype, str):
+        dtype = _np.dtype(dtype).name
+    return _api_internal.full_like(a, fill_value, dtype, ctx, out)


@set_module('mxnet.ndarray.numpy')
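Note on the new full_like body: ctx and dtype now cross the FFI boundary as plain strings, so both are normalized on the Python side before the _api_internal call. A minimal usage sketch, assuming a build that includes this change:

from mxnet import np

a = np.zeros((2, 2))
b = np.full_like(a, 2)                  # dtype and ctx default from a and the current context
c = np.full_like(a, 2, dtype='int32')   # dtype may be a string or anything _np.dtype accepts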
@@ -1025,8 +1023,9 @@ def subtract(x1, x2, out=None, **kwargs):
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), not supported yet.
"""
-    return _ufunc_helper(x1, x2, _npi.subtract, _np.subtract, _npi.subtract_scalar,
-                         _npi.rsubtract_scalar, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.subtract(x1, x2, out=out)
+    return _api_internal.subtract(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -1060,7 +1059,9 @@ def multiply(x1, x2, out=None, **kwargs):
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), not supported yet.
"""
-    return _ufunc_helper(x1, x2, _npi.multiply, _np.multiply, _npi.multiply_scalar, None, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.multiply(x1, x2, out=out)
+    return _api_internal.multiply(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -1095,8 +1096,9 @@ def divide(x1, x2, out=None, **kwargs):
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), the output is of float32 type.
"""
-    return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar,
-                         _npi.rtrue_divide_scalar, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.divide(x1, x2, out=out)
+    return _api_internal.true_divide(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -1133,8 +1135,9 @@ def true_divide(x1, x2, out=None):
* If only one of the inputs is floating number type, the result is that type.
* If both inputs are of integer types (including boolean), the output is of float32 type.
"""
-    return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar,
-                         _npi.rtrue_divide_scalar, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.true_divide(x1, x2, out=out)
+    return _api_internal.true_divide(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -1161,7 +1164,9 @@ def mod(x1, x2, out=None, **kwargs):
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
"""
-    return _ufunc_helper(x1, x2, _npi.mod, _np.mod, _npi.mod_scalar, _npi.rmod_scalar, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.mod(x1, x2, out=out)
+    return _api_internal.mod(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -1349,7 +1354,9 @@ def remainder(x1, x2, out=None):
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
"""
-    return _ufunc_helper(x1, x2, _npi.mod, _np.mod, _npi.mod_scalar, _npi.rmod_scalar, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.mod(x1, x2, out=out)
+    return _api_internal.mod(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -1377,7 +1384,9 @@ def power(x1, x2, out=None, **kwargs):
The bases in x1 raised to the exponents in x2.
This is a scalar if both x1 and x2 are scalars.
"""
-    return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.power(x1, x2, out=out)
+    return _api_internal.power(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -1976,7 +1985,9 @@ def lcm(x1, x2, out=None, **kwargs):
>>> np.lcm(np.arange(6, dtype=int), 20)
array([ 0, 20, 20, 60, 20, 20], dtype=int64)
"""
-    return _ufunc_helper(x1, x2, _npi.lcm, _np.lcm, _npi.lcm_scalar, None, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.lcm(x1, x2, out=out)
+    return _api_internal.lcm(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -6658,7 +6669,7 @@ def nonzero(a):
>>> (a > 3).nonzero()
(array([1, 1, 1, 2, 2, 2], dtype=int64), array([0, 1, 2, 0, 1, 2], dtype=int64))
"""
-    out = _npi.nonzero(a).transpose()
+    out = _api_internal.nonzero(a).transpose()
return tuple([out[i] for i in range(len(out))])


3 changes: 2 additions & 1 deletion src/api/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -46,7 +46,8 @@ MXNET_REGISTER_API("_npi.broadcast_to")

int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
-  auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
+  int num_inputs = 1;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

3 changes: 2 additions & 1 deletion src/api/operator/numpy/np_diff_op.cc
@@ -43,7 +43,8 @@ MXNET_REGISTER_API("_npi.diff")

int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
-  auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
+  int num_inputs = 1;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

52 changes: 52 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_op.cc
@@ -36,4 +36,56 @@ MXNET_REGISTER_API("_npi.add")
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.subtract")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_subtract");
const nnvm::Op* op_scalar = Op::Get("_npi_subtract_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rsubtract_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.multiply")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_multiply");
const nnvm::Op* op_scalar = Op::Get("_npi_multiply_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.true_divide")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_true_divide");
const nnvm::Op* op_scalar = Op::Get("_npi_true_divide_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rtrue_divide_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.mod")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_mod");
const nnvm::Op* op_scalar = Op::Get("_npi_mod_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rmod_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.power")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_power");
const nnvm::Op* op_scalar = Op::Get("_npi_power_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rpower_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.lcm")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_lcm");
const nnvm::Op* op_scalar = Op::Get("_npi_lcm_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

} // namespace mxnet
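Each registration above binds one frontend ufunc to up to three kernels: a tensor-tensor op, a tensor-scalar op, and, for non-commutative ops, a reflected scalar op. As the author understands UFuncHelper, dispatch is by argument type, roughly as follows (op names taken from this file):

import mxnet.numpy as np

a = np.ones((2, 2))
np.subtract(a, a)   # ndarray, ndarray -> _npi_subtract
np.subtract(a, 1)   # scalar on the right -> _npi_subtract_scalar
np.subtract(1, a)   # scalar on the left  -> _npi_rsubtract_scalar
np.multiply(2, a)   # commutative, no rscalar kernel: _npi_multiply_scalar serves either order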
33 changes: 33 additions & 0 deletions src/api/operator/numpy/np_init_op.cc
@@ -21,6 +21,7 @@
* \file np_init_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_init_op.cc
*/
#include <dmlc/optional.h>
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
@@ -55,4 +56,36 @@ MXNET_REGISTER_API("_npi.zeros")
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.full_like")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_full_like");
nnvm::NodeAttrs attrs;
op::FullLikeOpParam param;
param.fill_value = args[1].operator double();
if (args[2].type_code() == kNull) {
param.dtype = dmlc::nullopt;
} else {
param.dtype = String2MXNetTypeWithBool(args[2].operator std::string());
}
attrs.parsed = std::move(param);
attrs.op = op;
if (args[3].type_code() != kNull) {
attrs.dict["ctx"] = args[3].operator std::string();
}
SetAttrDict<op::FullLikeOpParam>(&attrs);
NDArray* out = args[4].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
  if (out) {
    *ret = PythonArg(4);
  } else {
    *ret = ndoutputs[0];
  }
});

} // namespace mxnet
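The out handling in _npi.full_like follows the convention used by the other FFI registrations in this series: when the caller passes out=, the body returns PythonArg(4) so the frontend hands back the caller's own ndarray instead of wrapping a new handle. A hedged sketch of the observable behavior, assuming full_like honors out= as its Python signature indicates:

from mxnet import np

a = np.zeros((2, 2))
buf = np.empty((2, 2))
res = np.full_like(a, 2, out=buf)
assert res is buf   # the out path returns the caller's array itself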
3 changes: 2 additions & 1 deletion src/api/operator/numpy/np_matrix_op.cc
@@ -42,7 +42,8 @@ MXNET_REGISTER_API("_npi.expand_dims")

int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
-  auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
+  int num_inputs = 1;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

45 changes: 45 additions & 0 deletions src/api/operator/numpy/np_nonzero_op.cc
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_nonzero_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_nonzero_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.nonzero")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_nonzero");
nnvm::NodeAttrs attrs;

attrs.op = op;

int num_inputs = 1;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

} // namespace mxnet
3 changes: 2 additions & 1 deletion src/api/operator/numpy/np_tril_op.cc
@@ -42,7 +42,8 @@ MXNET_REGISTER_API("_npi.tril")

int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
-  auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
+  int num_inputs = 1;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

11 changes: 11 additions & 0 deletions src/operator/tensor/init_op.h
@@ -105,6 +105,17 @@ struct FullLikeOpParam : public dmlc::Parameter<FullLikeOpParam> {
MXNET_ADD_ALL_TYPES_WITH_BOOL
.describe("Target data type.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream fill_value_s, dtype_s;
fill_value_s << fill_value;
dtype_s << dtype;
(*dict)["fill_value"] = fill_value_s.str();
if (dtype.has_value()) {
(*dict)["dtype"] = MXNetTypeWithBool2String(dtype.value());
} else {
(*dict)["dtype"] = dtype_s.str();
}
}
};

/*! \brief Infer type of FullLikeOpCompute*/
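SetAttrDict mirrors the parsed parameter values back into the op's string attribute dictionary, keeping recorded attributes consistent whether the op is invoked through the legacy string frontend or the new FFI. An illustration of the entries it would produce (hypothetical values, not an MXNet API call):

# Hypothetical attrs.dict contents after SetAttrDict for a call like
# full_like(a, 2, dtype='int32'): fill_value is stream-formatted, and
# dtype uses the readable type name because the optional has a value.
attrs_dict = {'fill_value': '2', 'dtype': 'int32'}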
